Skip to content

Commit d2963ac

Browse files
authored
Merge pull request #1778 from JuliaRobotics/23Q3/dev/parametric
Refactor Parametric solve for better performance and better use of Manopt.jl
2 parents 525602e + 6f3e431 commit d2963ac

24 files changed

+943
-594
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ version = "0.34.1"
88
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
99
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
1010
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
11-
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1211
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1312
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1413
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -67,7 +66,6 @@ IncrInfrInterpolationsExt = "Interpolations"
6766
[compat]
6867
ApproxManifoldProducts = "0.7, 0.8"
6968
BSON = "0.2, 0.3"
70-
BlockArrays = "0.16"
7169
Combinatorics = "1.0"
7270
DataStructures = "0.16, 0.17, 0.18"
7371
DelimitedFiles = "1"

src/Factors/GenericFunctions.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737

3838
#::MeasurementOnTangent
3939
function distanceTangent2Point(M::SemidirectProductGroup, X, p, q)
40-
= Manifolds.compose(M, p, exp(M, identity_element(M, p), X)) #for groups
40+
= Manifolds.compose(M, p, exp(M, getPointIdentity(M), X)) #for groups
4141
# return log(M, q, q̂)
4242
return vee(M, q, log(M, q, q̂))
4343
# return distance(M, q, q̂)
@@ -96,7 +96,7 @@ end
9696

9797
# function (cf::CalcFactor{<:ManifoldFactor{<:AbstractDecoratorManifold}})(Xc, p, q)
9898
function (cf::CalcFactor{<:ManifoldFactor})(X, p, q)
99-
return distanceTangent2Point(cf.manifold, X, p, q)
99+
return distanceTangent2Point(cf.factor.M, X, p, q)
100100
end
101101

102102
## ======================================================================================
@@ -141,12 +141,20 @@ function getSample(cf::CalcFactor{<:ManifoldPrior})
141141
return point
142142
end
143143

144+
function getFactorMeasurementParametric(fac::ManifoldPrior)
145+
M = getManifold(fac)
146+
dims = manifold_dimension(M)
147+
meas = fac.p
148+
= convert(SMatrix{dims, dims}, invcov(fac.Z))
149+
meas, iΣ
150+
end
151+
144152
#TODO investigate SVector if small dims, this is slower
145153
# dim = manifold_dimension(M)
146154
# Xc = [SVector{dim}(rand(Z)) for _ in 1:N]
147155

148156
function (cf::CalcFactor{<:ManifoldPrior})(m, p)
149-
M = cf.manifold # .factor.M
157+
M = cf.factor.M
150158
# return log(M, p, m)
151159
return vee(M, p, log(M, p, m))
152160
# return distancePrior(M, m, p)

src/Factors/Mixture.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ function sampleFactor(cf::CalcFactor{<:Mixture}, N::Int = 1)
118118
## example case is old FluxModelsPose2Pose2 requiring velocity
119119
# FIXME better consolidation of when to pass down .mechanics, also see #1099 and #1094 and #1069
120120

121-
cf_ = CalcFactor(
121+
cf_ = CalcFactorNormSq(
122122
cf.factor.mechanics,
123123
0,
124124
cf._legacyParams,
@@ -133,10 +133,19 @@ function sampleFactor(cf::CalcFactor{<:Mixture}, N::Int = 1)
133133
#out memory should be right size first
134134
length(cf.factor.labels) != N ? resize!(cf.factor.labels, N) : nothing
135135
cf.factor.labels .= rand(cf.factor.diversity, N)
136+
M = cf.manifold
137+
138+
# mixture needs to be refactored so let's make it worse :-)
139+
if cf.factor.mechanics isa AbstractPrior
140+
samplef = samplePoint
141+
elseif cf.factor.mechanics isa AbstractRelative
142+
samplef = sampleTangent
143+
end
144+
136145
for i = 1:N
137146
mixComponent = cf.factor.components[cf.factor.labels[i]]
138147
# measurements relate to the factor's manifold (either tangent vector or manifold point)
139-
setPointsMani!(smpls[i], rand(mixComponent, 1))
148+
setPointsMani!(smpls, samplef(M, mixComponent), i)
140149
end
141150

142151
# TODO only does first element of meas::Tuple at this stage, see #1099

src/Factors/MsgPrior.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ end
3232
getManifold(mp::MsgPrior{<:ManifoldKernelDensity}) = mp.Z.manifold
3333
getManifold(mp::MsgPrior) = mp.M
3434

35+
#FIXME this will not work on manifolds
3536
(cfo::CalcFactor{<:MsgPrior})(z, x1) = z .- x1
3637

3738
Base.@kwdef struct PackedMsgPrior <: AbstractPackedFactor

src/IncrementalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ include("parametric/services/ConsolidateParametricRelatives.jl")
221221
include("parametric/services/ParametricCSMFunctions.jl")
222222
include("parametric/services/ParametricUtils.jl")
223223
include("parametric/services/ParametricOptim.jl")
224-
include("parametric/services/ParametricManoptDev.jl")
224+
include("parametric/services/ParametricManopt.jl")
225225
include("services/MaxMixture.jl")
226226

227227
#X-stroke

src/entities/AliasScalarSampling.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ struct AliasingScalarSampler
4646
end
4747
end
4848

49+
function sampleTangent(
50+
M::AbstractDecoratorManifold, # stand-in type to restrict to just group manifolds
51+
z::AliasingScalarSampler,
52+
p = getPointIdentity(M),
53+
)
54+
return hat(M, p, SVector{manifold_dimension(M)}(rand(z)))
55+
end
56+
4957
function rand!(ass::AliasingScalarSampler, smpls::Array{Float64})
5058
StatsBase.alias_sample!(ass.domain, ass.weights, smpls)
5159
return nothing

src/entities/CalcFactor.jl

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
abstract type AbstractMaxMixtureSolver end
33

44

5+
abstract type CalcFactor{T<:AbstractFactor} end
6+
7+
58
"""
69
$TYPEDEF
710
@@ -21,18 +24,19 @@ end
2124
2225
DevNotes
2326
- Follow the Github project in IIF to better consolidate CCW FMD CPT CF CFM
24-
27+
- TODO CalcFactorNormSq is a step towards having a dedicated structure for non-parametric solve.
28+
CalcFactorNormSq will calculate the Norm Squared of the factor.
2529
Related
2630
2731
[`CalcFactorMahalanobis`](@ref), [`CommonConvWrapper`](@ref)
2832
"""
29-
struct CalcFactor{
33+
struct CalcFactorNormSq{
3034
FT <: AbstractFactor,
3135
X,
3236
C,
3337
VT <: Tuple,
3438
M <: AbstractManifold
35-
}
39+
} <: CalcFactor{FT}
3640
""" the interface compliant user object functor containing the data and logic """
3741
factor::FT
3842
""" what is the sample (particle) id for which the residual is being calculated """
@@ -54,7 +58,15 @@ struct CalcFactor{
5458
manifold::M
5559
end
5660

57-
61+
#TODO deprecate after CalcFactor is updated to CalcFactorNormSq
62+
function CalcFactor(args...; kwargs...)
63+
Base.depwarn(
64+
"`CalcFactor` changed to an abstract type, use CalcFactorNormSq, CalcFactorMahalanobis, or CalcFactorResidual",
65+
:CalcFactor
66+
)
67+
68+
CalcFactorNormSq(args...; kwargs...)
69+
end
5870

5971
"""
6072
$TYPEDEF
@@ -65,32 +77,47 @@ Related
6577
6678
[`CalcFactor`](@ref)
6779
"""
68-
struct CalcFactorMahalanobis{N, D, L, S <: Union{Nothing, AbstractMaxMixtureSolver}}
80+
struct CalcFactorMahalanobis{
81+
FT,
82+
N,
83+
C,
84+
MEAS<:AbstractArray,
85+
D,
86+
L,
87+
S <: Union{Nothing, AbstractMaxMixtureSolver}
88+
} <: CalcFactor{FT}
6989
faclbl::Symbol
70-
calcfactor!::CalcFactor
90+
factor::FT
91+
cache::C
7192
varOrder::Vector{Symbol}
72-
meas::NTuple{N, <:AbstractArray}
93+
meas::NTuple{N, MEAS}
7394
::NTuple{N, SMatrix{D, D, Float64, L}}
7495
specialAlg::S
7596
end
7697

7798

78-
79-
80-
struct CalcFactorManopt{
99+
struct CalcFactorResidual{
100+
FT <: AbstractFactor,
101+
C,
81102
D,
82103
L,
83-
FT <: AbstractFactor,
84-
M <: AbstractManifold,
104+
P,
85105
MEAS <: AbstractArray,
86-
}
106+
N
107+
} <: CalcFactor{FT}
87108
faclbl::Symbol
88-
calcfactor!::CalcFactor{FT, Nothing, Nothing, Tuple{}, M}
89-
varOrder::Vector{Symbol}
90-
varOrderIdxs::Vector{Int}
109+
factor::FT
110+
cache::C
111+
varOrder::NTuple{N, Symbol}
112+
varOrderIdxs::NTuple{N, Int}
113+
points::P #TODO remove or not?
91114
meas::MEAS
92-
::SMatrix{D, D, Float64, L}
115+
::SMatrix{D, D, Float64, L} #TODO remove or not?
93116
sqrt_iΣ::SMatrix{D, D, Float64, L}
94117
end
95118

119+
_nvars(::CalcFactorResidual{FT, C, D, L, P, MEAS, N}) where {FT, C, D, L, P, MEAS, N} = N
120+
# _typeof_meas(::CalcFactorManopt{FT, C, D, L, MEAS, N}) where {FT, C, D, L, MEAS, N} = MEAS
121+
DFG.getDimension(::CalcFactorResidual{FT, C, D, L, P, MEAS, N}) where {FT, C, D, L, P, MEAS, N} = D
122+
96123

src/manifolds/services/ManifoldSampling.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ end
3636
function sampleTangent(
3737
M::AbstractDecoratorManifold,
3838
z::Distribution,
39-
p = identity_element(M), #getPointIdentity(M),
39+
p = getPointIdentity(M),
4040
)
41-
return hat(M, p, rand(z, 1)[:]) #TODO find something better than (z,1)[:]
41+
return hat(M, p, SVector{length(z)}(rand(z))) #TODO make sure all Distribution has length,
42+
# if this errors maybe fall back no next line
43+
# return convert(typeof(p), hat(M, p, rand(z, 1)[:])) #TODO find something better than (z,1)[:]
4244
end
4345

4446
"""

src/manifolds/services/ManifoldsExtentions.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,54 +98,54 @@ end
9898

9999
import DistributedFactorGraphs: getPointIdentity
100100

101-
function getPointIdentity(G::ProductGroup, ::Type{T} = Float64) where {T <: Real}
101+
function DFG.getPointIdentity(G::ProductGroup, ::Type{T} = Float64) where {T <: Real}
102102
M = G.manifold
103103
return ArrayPartition(map(x -> getPointIdentity(x, T), M.manifolds))
104104
end
105105

106106
# fallback
107-
function getPointIdentity(G::GroupManifold, ::Type{T} = Float64) where {T <: Real}
107+
function DFG.getPointIdentity(G::GroupManifold, ::Type{T} = Float64) where {T <: Real}
108108
return error("getPointIdentity not implemented on $G")
109109
end
110110

111-
function getPointIdentity(
111+
function DFG.getPointIdentity(
112112
@nospecialize(G::ProductManifold),
113113
::Type{T} = Float64,
114114
) where {T <: Real}
115115
return ArrayPartition(map(x -> getPointIdentity(x, T), G.manifolds))
116116
end
117117

118-
function getPointIdentity(
118+
function DFG.getPointIdentity(
119119
@nospecialize(M::PowerManifold),
120120
::Type{T} = Float64,
121121
) where {T <: Real}
122122
N = Manifolds.get_iterator(M).stop
123123
return fill(getPointIdentity(M.manifold, T), N)
124124
end
125125

126-
function getPointIdentity(M::NPowerManifold, ::Type{T} = Float64) where {T <: Real}
126+
function DFG.getPointIdentity(M::NPowerManifold, ::Type{T} = Float64) where {T <: Real}
127127
return fill(getPointIdentity(M.manifold, T), M.N)
128128
end
129129

130-
function getPointIdentity(G::SemidirectProductGroup, ::Type{T} = Float64) where {T <: Real}
130+
function DFG.getPointIdentity(G::SemidirectProductGroup, ::Type{T} = Float64) where {T <: Real}
131131
M = base_manifold(G)
132132
N, H = M.manifolds
133133
np = getPointIdentity(N, T)
134134
hp = getPointIdentity(H, T)
135135
return ArrayPartition(np, hp)
136136
end
137137

138-
function getPointIdentity(G::SpecialOrthogonal{N}, ::Type{T} = Float64) where {N, T <: Real}
138+
function DFG.getPointIdentity(G::SpecialOrthogonal{N}, ::Type{T} = Float64) where {N, T <: Real}
139139
return SMatrix{N, N, T}(I)
140140
end
141141

142-
function getPointIdentity(
142+
function DFG.getPointIdentity(
143143
G::TranslationGroup{Tuple{N}},
144144
::Type{T} = Float64,
145145
) where {N, T <: Real}
146146
return zeros(SVector{N,T})
147147
end
148148

149-
function getPointIdentity(G::RealCircleGroup, ::Type{T} = Float64) where {T <: Real}
149+
function DFG.getPointIdentity(G::RealCircleGroup, ::Type{T} = Float64) where {T <: Real}
150150
return [zero(T)] #FIXME we cannot support scalars yet
151151
end

0 commit comments

Comments
 (0)