Skip to content

Commit e07607a

Browse files
authored
Merge pull request #1713 from JuliaRobotics/23Q2/perf/bitstypes
Change points to StaticArrays
2 parents cb23c2c + 98444e7 commit e07607a

11 files changed

+186
-48
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name = "IncrementalInference"
22
uuid = "904591bb-b899-562f-9e6f-b8df64c7d480"
33
keywords = ["MM-iSAMv2", "Bayes tree", "junction tree", "Bayes network", "variable elimination", "graphical models", "SLAM", "inference", "sum-product", "belief-propagation"]
44
desc = "Implements the Multimodal-iSAMv2 algorithm."
5-
version = "0.33.0"
5+
version = "0.34.0"
66

77
[deps]
88
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
@@ -48,7 +48,7 @@ TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
4848
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
4949

5050
[compat]
51-
ApproxManifoldProducts = "0.6.3"
51+
ApproxManifoldProducts = "0.7"
5252
BSON = "0.2, 0.3"
5353
Combinatorics = "1.0"
5454
DataStructures = "0.16, 0.17, 0.18"

src/ConsolidateParametricRelatives.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@ function solveFactorParameteric(
5252
# hasp ? getPPE(vari, key).suggested : calcMean(getBelief(vari, key))
5353
pt = calcMean(getBelief(vari, key))
5454

55-
return getCoordinates(getVariableType(vari), pt)
55+
return collect(getCoordinates(getVariableType(vari), pt))
5656
end
5757

5858
# overwrite specific src values from user
5959
coordVals = _getParametric.(getVariable.(dfg, varLbls), solveKey)
60+
6061
for (srcsym, currval) in srcsym_vals
6162
coordVals[findfirst(varLbls .== srcsym)] = currval
6263
end

src/Deprecated.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
## ================================================================================================
2+
## ================================================================================================
3+
4+
# TODO maybe upstream to DFG
5+
DFG.MeanMaxPPE(solveKey::Symbol, suggested::SVector, max::SVector, mean::SVector) =
6+
DFG.MeanMaxPPE(solveKey, collect(suggested), collect(max), collect(mean))
7+
8+
19
## ================================================================================================
210
## Manifolds.jl Consolidation
311
## TODO: Still to be completed and tested.

src/Factors/Circular.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ function (cf::CalcFactor{<:CircularCircular})(X, p, q)
2727
return distanceTangent2Point(M, X, p, q)
2828
end
2929

30+
function getSample(cf::CalcFactor{<:CircularCircular})
31+
# FIXME workaround for issue with manifolds CircularGroup,
32+
return [rand(cf.factor.Z)]
33+
end
34+
3035
function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{CircularCircular})
3136
return Manifolds.RealCircleGroup()
3237
end
@@ -59,6 +64,7 @@ function getSample(cf::CalcFactor{<:PriorCircular})
5964
# JuliaManifolds/Manifolds.jl#415
6065
# no method similar(::Float64, ::Type{Float64})
6166
return samplePoint(cf.manifold, cf.factor.Z, [0.0])
67+
# return [Manifolds.sym_rem(rand(cf.factor.Z))]
6268
end
6369

6470
function (cf::CalcFactor{<:PriorCircular})(m, p)

src/ManifoldsExtentions.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,21 +132,17 @@ function getPointIdentity(G::SemidirectProductGroup, ::Type{T} = Float64) where
132132
return ArrayPartition(np, hp)
133133
end
134134

135-
#FIXME fix back to SA
136135
function getPointIdentity(G::SpecialOrthogonal{N}, ::Type{T} = Float64) where {N, T <: Real}
137-
# return SMatrix{N,N, T}(I)
138-
return Matrix{T}(I, N, N)
136+
return SMatrix{N, N, T}(I)
139137
end
140138

141139
function getPointIdentity(
142140
G::TranslationGroup{Tuple{N}},
143141
::Type{T} = Float64,
144142
) where {N, T <: Real}
145-
# return zeros(SVector{N,T})
146-
return zeros(T, N)
143+
return zeros(SVector{N,T})
147144
end
148145

149146
function getPointIdentity(G::RealCircleGroup, ::Type{T} = Float64) where {T <: Real}
150-
# return zero(T)
151-
return [zero(T)]
147+
return [zero(T)] #FIXME we cannot support scalars yet
152148
end

src/ParametricUtils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252

5353
function Base.setindex!(
5454
flatVar::FlatVariables{T},
55-
val::Vector{T},
55+
val::AbstractVector{T},
5656
vId::Symbol,
5757
) where {T <: Real}
5858
if length(val) == length(flatVar.idx[vId])
@@ -858,15 +858,15 @@ function initParametricFrom!(
858858
for v in getVariables(fg)
859859
fromvnd = getSolverData(v, fromkey)
860860
dims = getDimension(v)
861-
getSolverData(v, parkey).val[1] .= fromvnd.val[1]
862-
getSolverData(v, parkey).bw[1:dims, 1:dims] .= LinearAlgebra.I(dims)
861+
getSolverData(v, parkey).val[1] = fromvnd.val[1]
862+
getSolverData(v, parkey).bw[1:dims, 1:dims] = LinearAlgebra.I(dims)
863863
end
864864
else
865865
for var in getVariables(fg)
866866
dims = getDimension(var)
867867
μ, Σ = calcMeanCovar(var, fromkey)
868-
getSolverData(var, parkey).val[1] .= μ
869-
getSolverData(var, parkey).bw[1:dims, 1:dims] .= Σ
868+
getSolverData(var, parkey).val[1] = μ
869+
getSolverData(var, parkey).bw[1:dims, 1:dims] = Σ
870870
end
871871
end
872872
end
@@ -986,7 +986,7 @@ function autoinitParametric!(
986986

987987
vnd.initialized = true
988988
#fill in ppe as mean
989-
Xc = getCoordinates(getVariableType(xi), val)
989+
Xc = collect(getCoordinates(getVariableType(xi), val))
990990
ppe = MeanMaxPPE(:parametric, Xc, Xc, Xc)
991991
getPPEDict(xi)[:parametric] = ppe
992992

src/Variables/DefaultVariables.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,7 @@ $(TYPEDEF)
5050
Circular is a `Manifolds.Circle{ℝ}` mechanization of one rotation, with `theta in [-pi,pi)`.
5151
"""
5252
@defVariable Circular RealCircleGroup() [0.0;]
53+
#TODO This is an example of what we want working, possible issue upstream in Manifolds.jl
54+
# @defVariable Circular RealCircleGroup() Scalar(0.0)
5355

5456
#

src/services/EvalFactor.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -370,16 +370,6 @@ function evalPotentialSpecific(
370370
return ccwl.varValsAll[ccwl.varidx[]], ipc
371371
end
372372

373-
374-
#TODO workaround for supporting bitstypes, need rewrite, can do with `PowerManifoldNestedReplacing` or similar
375-
function AMP.setPointsMani!(dest::AbstractArray{T}, src::AbstractArray{U}, destIdx, srcIdx=destIdx) where {T<:AbstractArray,U<:AbstractArray}
376-
if isbitstype(T)
377-
dest[destIdx] = src[srcIdx]
378-
else
379-
setPointsMani!(dest[destIdx],src[srcIdx])
380-
end
381-
end
382-
383373
# TODO `measurement` might not be properly wired up yet
384374
# TODO consider 1051 here to inflate proposals as general behaviour
385375
function evalPotentialSpecific(
@@ -477,10 +467,12 @@ function evalPotentialSpecific(
477467

478468
setPointPartial!(
479469
mani,
480-
addEntr[m],
470+
addEntr,
481471
Msrc,
482-
ccwl.measurement[m], # FIXME, measurements are tangents=>relative or points=>priors
472+
ccwl.measurement, # FIXME, measurements are tangents=>relative or points=>priors
483473
partialCoords,
474+
m,
475+
m,
484476
asPartial,
485477
)
486478
else

src/services/FGOSUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ function manikde!(
119119
variableType::Union{InstanceType{<:InferenceVariable}, InstanceType{<:AbstractFactor}},
120120
pts::AbstractVector{P};
121121
kw...,
122-
) where {P <: Union{<:AbstractArray, <:Number, <:ProductRepr, <:Manifolds.ArrayPartition}}
122+
) where {P <: Union{<:AbstractArray, <:Number, <:Manifolds.ArrayPartition}}
123123
#
124124
M = getManifold(variableType)
125125
infoPerCoord = ones(AMP.getNumberCoords(M, pts[1]))

src/services/NumericalCalculations.jl

Lines changed: 135 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ function _solveLambdaNumeric(
9898
r = if islen1
9999
Optim.optimize((x) -> (residual .= objResX(x); sum(residual .^ 2)), u0, Optim.BFGS())
100100
else
101-
Optim.optimize((x) -> (residual .= objResX(x); sum(residual .^ 2)), u0)
101+
Optim.optimize((x) -> (residual .= objResX(x); sum(residual .^ 2)), u0, Optim.Options(;iterations=1000))
102+
end
103+
104+
if !Optim.converged(r)
105+
@warn "Optim did not converge:" r maxlog=10
102106
end
103107

104108
#
@@ -113,6 +117,20 @@ function _solveLambdaNumeric(
113117
variableType::InferenceVariable,
114118
islen1::Bool = false,
115119
) where {N_, F <: AbstractManifoldMinimize, S, T}
120+
121+
return _solveCCWNumeric_test_SA(fcttype, objResX, residual, u0, variableType, islen1)
122+
# return _solveLambdaNumeric_test_optim_manifold(fcttype, objResX, residual, u0, variableType, islen1)
123+
124+
end
125+
126+
function _solveLambdaNumeric_original(
127+
fcttype::Union{F, <:Mixture{N_, F, S, T}},
128+
objResX::Function,
129+
residual::AbstractVector{<:Real},
130+
u0,#::AbstractVector{<:Real},
131+
variableType::InferenceVariable,
132+
islen1::Bool = false,
133+
) where {N_, F <: AbstractManifoldMinimize, S, T}
116134
#
117135
M = getManifold(variableType) #fcttype.M
118136
# the variable is a manifold point, we are working on the tangent plane in optim for now.
@@ -153,6 +171,93 @@ function _solveLambdaNumeric(
153171
return exp(M, ϵ, hat(M, ϵ, r.minimizer))
154172
end
155173

174+
# 1.355700 seconds (11.78 M allocations: 557.677 MiB, 6.96% gc time)
175+
function _solveCCWNumeric_test_SA(
176+
fcttype::Union{F, <:Mixture{N_, F, S, T}},
177+
objResX::Function,
178+
residual::AbstractVector{<:Real},
179+
u0,#::AbstractVector{<:Real},
180+
variableType::InferenceVariable,
181+
islen1::Bool = false,
182+
) where {N_, F <: AbstractManifoldMinimize, S, T}
183+
#
184+
M = getManifold(variableType) #fcttype.M
185+
# the variable is a manifold point, we are working on the tangent plane in optim for now.
186+
#
187+
#TODO this is not general to all manifolds, should work for lie groups.
188+
# ϵ = identity_element(M, u0)
189+
ϵ = getPointIdentity(variableType)
190+
191+
X0c = zero(MVector{getDimension(M),Float64})
192+
X0c .= vee(M, u0, log(M, ϵ, u0))
193+
194+
#TODO check performance
195+
function cost(Xc)
196+
X = hat(M, ϵ, Xc)
197+
p = exp(M, ϵ, X)
198+
residual = objResX(p)
199+
return sum(residual .^ 2)
200+
end
201+
202+
alg = islen1 ? Optim.BFGS() : Optim.NelderMead()
203+
204+
r = Optim.optimize(cost, X0c, alg)
205+
if !Optim.converged(r)
206+
# TODO find good way for a solve to store diagnostics about number of failed converges etc.
207+
@warn "Optim did not converge (maxlog=10):" r maxlog=10
208+
end
209+
return exp(M, ϵ, hat(M, ϵ, r.minimizer))
210+
end
211+
212+
# sloooowwww and does not always converge, unusable slow with gradient
213+
# NelderMead 5.513693 seconds (38.60 M allocations: 1.613 GiB, 6.62% gc time)
214+
function _solveLambdaNumeric_test_optim_manifold(
215+
fcttype::Union{F, <:Mixture{N_, F, S, T}},
216+
objResX::Function,
217+
residual::AbstractVector{<:Real},
218+
u0,#::AbstractVector{<:Real},
219+
variableType::InferenceVariable,
220+
islen1::Bool = false,
221+
) where {N_, F <: AbstractManifoldMinimize, S, T}
222+
#
223+
M = getManifold(variableType) #fcttype.M
224+
# the variable is a manifold point, we are working on the tangent plane in optim for now.
225+
#
226+
#TODO this is not general to all manifolds, should work for lie groups.
227+
ϵ = getPointIdentity(variableType)
228+
229+
function cost(p)
230+
residual = objResX(p)
231+
return sum(residual .^ 2)
232+
end
233+
234+
alg = islen1 ? Optim.BFGS(;manifold=ManifoldWrapper(M)) : Optim.NelderMead(;manifold=ManifoldWrapper(M))
235+
# alg = Optim.ConjugateGradient(; manifold=ManifoldWrapper(M))
236+
# alg = Optim.BFGS(; manifold=ManifoldWrapper(M))
237+
238+
# r_backend = ManifoldDiff.TangentDiffBackend(
239+
# ManifoldDiff.FiniteDifferencesBackend()
240+
# )
241+
242+
# ## finitediff gradient
243+
# function costgrad_FD!(X,p)
244+
# copyto!(X, ManifoldDiff.gradient(M, cost, p, r_backend))
245+
# X
246+
# end
247+
248+
u0_m = allocate(M, u0)
249+
u0_m .= u0
250+
# r = Optim.optimize(cost, costgrad_FD!, u0_m, alg)
251+
r = Optim.optimize(cost, u0_m, alg)
252+
253+
if !Optim.converged(r)
254+
@warn "Optim did not converge:" r maxlog=10
255+
end
256+
257+
return r.minimizer
258+
# return exp(M, ϵ, hat(M, ϵ, r.minimizer))
259+
end
260+
156261
#TODO Consolidate with _solveLambdaNumeric, see #1374
157262
function _solveLambdaNumericMeas(
158263
fcttype::Union{F, <:Mixture{N_, F, S, T}},
@@ -231,7 +336,7 @@ DevNotes
231336
function _buildCalcFactorLambdaSample(
232337
ccwl::CommonConvWrapper,
233338
smpid::Integer,
234-
target = view(ccwl.varValsAll[ccwl.varidx[]][smpid], ccwl.partialDims),
339+
target,
235340
measurement_ = ccwl.measurement;
236341
# fmd_::FactorMetadata = cpt_.factormetadata;
237342
_slack = nothing,
@@ -316,20 +421,39 @@ function _solveCCWNumeric!(
316421
islen1 = length(ccwl.partialDims) == 1 || ccwl.partial
317422
# islen1 = length(cpt_.X[:, smpid]) == 1 || ccwl.partial
318423

424+
if ccwl.partial
425+
target = view(ccwl.varValsAll[ccwl.varidx[]][smpid], ccwl.partialDims)
426+
else
427+
target = ccwl.varValsAll[ccwl.varidx[]][smpid];
428+
end
319429
# build the pre-objective function for this sample's hypothesis selection
320-
unrollHypo!, target = _buildCalcFactorLambdaSample(ccwl, smpid; _slack = _slack)
430+
unrollHypo!, _ = _buildCalcFactorLambdaSample(
431+
ccwl,
432+
smpid,
433+
target;
434+
_slack = _slack
435+
)
321436

322437
# broadcast updates original view memory location
323438
## using CalcFactor legacy path inside (::CalcFactor)
324-
_hypoObj = (x) -> (target .= x; unrollHypo!())
439+
440+
# _hypoObj = (x) -> (target[] = x; unrollHypo!())
441+
function _hypoObj(x)
442+
copyto!(target, x)
443+
return unrollHypo!()
444+
end
325445

326446
# TODO small off-manifold perturbation is a numerical workaround only, make on-manifold requires RoME.jl #244
327447
# use all element dimensions : ==> 1:ccwl.xDim
328-
target .+= _perturbIfNecessary(getFactorType(ccwl), length(target), perturb)
448+
# target .+= _perturbIfNecessary(getFactorType(ccwl), length(target), perturb)
329449

330450
sfidx = ccwl.varidx[]
331451
# do the parameter search over defined decision variables using Minimization
332-
X = ccwl.varValsAll[sfidx][smpid][ccwl.partialDims]
452+
if ccwl.partial
453+
X = collect(view(ccwl.varValsAll[sfidx][smpid], ccwl.partialDims))
454+
else
455+
X = ccwl.varValsAll[sfidx][smpid][ccwl.partialDims]
456+
end
333457
retval = _solveLambdaNumeric(
334458
getFactorType(ccwl),
335459
_hypoObj,
@@ -345,7 +469,11 @@ function _solveCCWNumeric!(
345469
end
346470

347471
# insert result back at the correct variable element location
348-
ccwl.varValsAll[sfidx][smpid][ccwl.partialDims] .= retval
472+
if ccwl.partial
473+
ccwl.varValsAll[sfidx][smpid][ccwl.partialDims] .= retval
474+
else
475+
copyto!(ccwl.varValsAll[sfidx][smpid], retval)
476+
end
349477

350478
return nothing
351479
end

0 commit comments

Comments
 (0)