Skip to content

Commit 48c5ea5

Browse files
committed
WIP Change points to StaticArrays
1 parent bdb9abe commit 48c5ea5

File tree

9 files changed

+164
-42
lines changed

9 files changed

+164
-42
lines changed

src/ConsolidateParametricRelatives.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@ function solveFactorParameteric(
5252
# hasp ? getPPE(vari, key).suggested : calcMean(getBelief(vari, key))
5353
pt = calcMean(getBelief(vari, key))
5454

55-
return getCoordinates(getVariableType(vari), pt)
55+
return collect(getCoordinates(getVariableType(vari), pt))
5656
end
5757

5858
# overwrite specific src values from user
5959
coordVals = _getParametric.(getVariable.(dfg, varLbls), solveKey)
60+
6061
for (srcsym, currval) in srcsym_vals
6162
coordVals[findfirst(varLbls .== srcsym)] = currval
6263
end

src/Deprecated.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
## ================================================================================================
2+
## ================================================================================================
3+
4+
# TODO maybe upstream to DFG
5+
DFG.MeanMaxPPE(solveKey::Symbol, suggested::SVector, max::SVector, mean::SVector) =
6+
DFG.MeanMaxPPE(solveKey, collect(suggested), collect(max), collect(mean))
7+
8+
19
## ================================================================================================
210
## Manifolds.jl Consolidation
311
## TODO: Still to be completed and tested.

src/Factors/Circular.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ function (cf::CalcFactor{<:CircularCircular})(X, p, q)
2727
return distanceTangent2Point(M, X, p, q)
2828
end
2929

30+
function getSample(cf::CalcFactor{<:CircularCircular})
31+
# FIXME workaround for issue with manifolds CircularGroup,
32+
return [rand(cf.factor.Z)]
33+
end
34+
3035
function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{CircularCircular})
3136
return Manifolds.RealCircleGroup()
3237
end
@@ -59,6 +64,7 @@ function getSample(cf::CalcFactor{<:PriorCircular})
5964
# JuliaManifolds/Manifolds.jl#415
6065
# no method similar(::Float64, ::Type{Float64})
6166
return samplePoint(cf.manifold, cf.factor.Z, [0.0])
67+
# return [Manifolds.sym_rem(rand(cf.factor.Z))]
6268
end
6369

6470
function (cf::CalcFactor{<:PriorCircular})(m, p)

src/ManifoldsExtentions.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,21 +132,17 @@ function getPointIdentity(G::SemidirectProductGroup, ::Type{T} = Float64) where
132132
return ArrayPartition(np, hp)
133133
end
134134

135-
#FIXME fix back to SA
136135
function getPointIdentity(G::SpecialOrthogonal{N}, ::Type{T} = Float64) where {N, T <: Real}
137-
# return SMatrix{N,N, T}(I)
138-
return Matrix{T}(I, N, N)
136+
return SMatrix{N, N, T}(I)
139137
end
140138

141139
function getPointIdentity(
142140
G::TranslationGroup{Tuple{N}},
143141
::Type{T} = Float64,
144142
) where {N, T <: Real}
145-
# return zeros(SVector{N,T})
146-
return zeros(T, N)
143+
return zeros(SVector{N,T})
147144
end
148145

149146
function getPointIdentity(G::RealCircleGroup, ::Type{T} = Float64) where {T <: Real}
150-
# return zero(T)
151-
return [zero(T)]
147+
return zero(T)
152148
end

src/ParametricUtils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252

5353
function Base.setindex!(
5454
flatVar::FlatVariables{T},
55-
val::Vector{T},
55+
val::AbstractVector{T},
5656
vId::Symbol,
5757
) where {T <: Real}
5858
if length(val) == length(flatVar.idx[vId])
@@ -858,15 +858,15 @@ function initParametricFrom!(
858858
for v in getVariables(fg)
859859
fromvnd = getSolverData(v, fromkey)
860860
dims = getDimension(v)
861-
getSolverData(v, parkey).val[1] .= fromvnd.val[1]
862-
getSolverData(v, parkey).bw[1:dims, 1:dims] .= LinearAlgebra.I(dims)
861+
getSolverData(v, parkey).val[1] = fromvnd.val[1]
862+
getSolverData(v, parkey).bw[1:dims, 1:dims] = LinearAlgebra.I(dims)
863863
end
864864
else
865865
for var in getVariables(fg)
866866
dims = getDimension(var)
867867
μ, Σ = calcMeanCovar(var, fromkey)
868-
getSolverData(var, parkey).val[1] .= μ
869-
getSolverData(var, parkey).bw[1:dims, 1:dims] .= Σ
868+
getSolverData(var, parkey).val[1] = μ
869+
getSolverData(var, parkey).bw[1:dims, 1:dims] = Σ
870870
end
871871
end
872872
end
@@ -986,7 +986,7 @@ function autoinitParametric!(
986986

987987
vnd.initialized = true
988988
#fill in ppe as mean
989-
Xc = getCoordinates(getVariableType(xi), val)
989+
Xc = collect(getCoordinates(getVariableType(xi), val))
990990
ppe = MeanMaxPPE(:parametric, Xc, Xc, Xc)
991991
getPPEDict(xi)[:parametric] = ppe
992992

src/Variables/DefaultVariables.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,7 @@ $(TYPEDEF)
5050
Circular is a `Manifolds.Circle{ℝ}` mechanization of one rotation, with `theta in [-pi,pi)`.
5151
"""
5252
@defVariable Circular RealCircleGroup() [0.0;]
53+
#TODO This is an example of what we want working, possible issue upstream in Manifolds.jl
54+
# @defVariable Circular RealCircleGroup() Scalar(0.0)
5355

5456
#

src/services/EvalFactor.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -370,16 +370,6 @@ function evalPotentialSpecific(
370370
return ccwl.varValsAll[ccwl.varidx[]], ipc
371371
end
372372

373-
374-
#TODO workaround for supporting bitstypes, need rewrite, can do with `PowerManifoldNestedReplacing` or similar
375-
function AMP.setPointsMani!(dest::AbstractArray{T}, src::AbstractArray{U}, destIdx, srcIdx=destIdx) where {T<:AbstractArray,U<:AbstractArray}
376-
if isbitstype(T)
377-
dest[destIdx] = src[srcIdx]
378-
else
379-
setPointsMani!(dest[destIdx],src[srcIdx])
380-
end
381-
end
382-
383373
# TODO `measurement` might not be properly wired up yet
384374
# TODO consider 1051 here to inflate proposals as general behaviour
385375
function evalPotentialSpecific(
@@ -477,10 +467,12 @@ function evalPotentialSpecific(
477467

478468
setPointPartial!(
479469
mani,
480-
addEntr[m],
470+
addEntr,
481471
Msrc,
482-
ccwl.measurement[m], # FIXME, measurements are tangents=>relative or points=>priors
472+
ccwl.measurement, # FIXME, measurements are tangents=>relative or points=>priors
483473
partialCoords,
474+
m,
475+
m,
484476
asPartial,
485477
)
486478
else

src/services/NumericalCalculations.jl

Lines changed: 120 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ function _solveLambdaNumeric(
9898
r = if islen1
9999
Optim.optimize((x) -> (residual .= objResX(x); sum(residual .^ 2)), u0, Optim.BFGS())
100100
else
101-
Optim.optimize((x) -> (residual .= objResX(x); sum(residual .^ 2)), u0)
101+
Optim.optimize((x) -> (residual .= objResX(x); sum(residual .^ 2)), u0, Optim.Options(;iterations=1000))
102+
end
103+
104+
if !Optim.converged(r)
105+
@warn "Optim did not converge:" r maxlog=10
102106
end
103107

104108
#
@@ -113,6 +117,20 @@ function _solveLambdaNumeric(
113117
variableType::InferenceVariable,
114118
islen1::Bool = false,
115119
) where {N_, F <: AbstractManifoldMinimize, S, T}
120+
121+
return _solveCCWNumeric_test_SA(fcttype, objResX, residual, u0, variableType, islen1)
122+
# return _solveLambdaNumeric_test_optim_manifold(fcttype, objResX, residual, u0, variableType, islen1)
123+
124+
end
125+
126+
function _solveLambdaNumeric_original(
127+
fcttype::Union{F, <:Mixture{N_, F, S, T}},
128+
objResX::Function,
129+
residual::AbstractVector{<:Real},
130+
u0,#::AbstractVector{<:Real},
131+
variableType::InferenceVariable,
132+
islen1::Bool = false,
133+
) where {N_, F <: AbstractManifoldMinimize, S, T}
116134
#
117135
M = getManifold(variableType) #fcttype.M
118136
# the variable is a manifold point, we are working on the tangent plane in optim for now.
@@ -153,6 +171,93 @@ function _solveLambdaNumeric(
153171
return exp(M, ϵ, hat(M, ϵ, r.minimizer))
154172
end
155173

174+
# 1.355700 seconds (11.78 M allocations: 557.677 MiB, 6.96% gc time)
175+
function _solveCCWNumeric_test_SA(
176+
fcttype::Union{F, <:Mixture{N_, F, S, T}},
177+
objResX::Function,
178+
residual::AbstractVector{<:Real},
179+
u0,#::AbstractVector{<:Real},
180+
variableType::InferenceVariable,
181+
islen1::Bool = false,
182+
) where {N_, F <: AbstractManifoldMinimize, S, T}
183+
#
184+
M = getManifold(variableType) #fcttype.M
185+
# the variable is a manifold point, we are working on the tangent plane in optim for now.
186+
#
187+
#TODO this is not general to all manifolds, should work for lie groups.
188+
# ϵ = identity_element(M, u0)
189+
ϵ = getPointIdentity(variableType)
190+
191+
X0c = zero(MVector{getDimension(M),Float64})
192+
X0c .= vee(M, u0, log(M, ϵ, u0))
193+
194+
#TODO check performance
195+
function cost(Xc)
196+
X = hat(M, ϵ, Xc)
197+
p = exp(M, ϵ, X)
198+
residual = objResX(p)
199+
return sum(residual .^ 2)
200+
end
201+
202+
alg = islen1 ? Optim.BFGS() : Optim.NelderMead()
203+
204+
r = Optim.optimize(cost, X0c, alg)
205+
if !Optim.converged(r)
206+
# TODO find good way for a solve to store diagnostics about number of failed converges etc.
207+
@warn "Optim did not converge (maxlog=10):" r maxlog=10
208+
end
209+
return exp(M, ϵ, hat(M, ϵ, r.minimizer))
210+
end
211+
212+
# sloooowwww and does not always converge, unusable slow with gradient
213+
# NelderMead 5.513693 seconds (38.60 M allocations: 1.613 GiB, 6.62% gc time)
214+
function _solveLambdaNumeric_test_optim_manifold(
215+
fcttype::Union{F, <:Mixture{N_, F, S, T}},
216+
objResX::Function,
217+
residual::AbstractVector{<:Real},
218+
u0,#::AbstractVector{<:Real},
219+
variableType::InferenceVariable,
220+
islen1::Bool = false,
221+
) where {N_, F <: AbstractManifoldMinimize, S, T}
222+
#
223+
M = getManifold(variableType) #fcttype.M
224+
# the variable is a manifold point, we are working on the tangent plane in optim for now.
225+
#
226+
#TODO this is not general to all manifolds, should work for lie groups.
227+
ϵ = getPointIdentity(variableType)
228+
229+
function cost(p)
230+
residual = objResX(p)
231+
return sum(residual .^ 2)
232+
end
233+
234+
alg = islen1 ? Optim.BFGS(;manifold=ManifoldWrapper(M)) : Optim.NelderMead(;manifold=ManifoldWrapper(M))
235+
# alg = Optim.ConjugateGradient(; manifold=ManifoldWrapper(M))
236+
# alg = Optim.BFGS(; manifold=ManifoldWrapper(M))
237+
238+
# r_backend = ManifoldDiff.TangentDiffBackend(
239+
# ManifoldDiff.FiniteDifferencesBackend()
240+
# )
241+
242+
# ## finitediff gradient
243+
# function costgrad_FD!(X,p)
244+
# copyto!(X, ManifoldDiff.gradient(M, cost, p, r_backend))
245+
# X
246+
# end
247+
248+
u0_m = allocate(M, u0)
249+
u0_m .= u0
250+
# r = Optim.optimize(cost, costgrad_FD!, u0_m, alg)
251+
r = Optim.optimize(cost, u0_m, alg)
252+
253+
if !Optim.converged(r)
254+
@warn "Optim did not converge:" r maxlog=10
255+
end
256+
257+
return r.minimizer
258+
# return exp(M, ϵ, hat(M, ϵ, r.minimizer))
259+
end
260+
156261
#TODO Consolidate with _solveLambdaNumeric, see #1374
157262
function _solveLambdaNumericMeas(
158263
fcttype::Union{F, <:Mixture{N_, F, S, T}},
@@ -317,15 +422,25 @@ function _solveCCWNumeric!(
317422
# islen1 = length(cpt_.X[:, smpid]) == 1 || ccwl.partial
318423

319424
# build the pre-objective function for this sample's hypothesis selection
320-
unrollHypo!, target = _buildCalcFactorLambdaSample(ccwl, smpid; _slack = _slack)
425+
unrollHypo!, target = _buildCalcFactorLambdaSample(
426+
ccwl,
427+
smpid,
428+
view(ccwl.varValsAll[ccwl.varidx[]], smpid);
429+
_slack = _slack
430+
)
321431

322432
# broadcast updates original view memory location
323433
## using CalcFactor legacy path inside (::CalcFactor)
324-
_hypoObj = (x) -> (target .= x; unrollHypo!())
434+
435+
# _hypoObj = (x) -> (target[] = x; unrollHypo!())
436+
function _hypoObj(x)
437+
target[] = x
438+
return unrollHypo!()
439+
end
325440

326441
# TODO small off-manifold perturbation is a numerical workaround only, make on-manifold requires RoME.jl #244
327442
# use all element dimensions : ==> 1:ccwl.xDim
328-
target .+= _perturbIfNecessary(getFactorType(ccwl), length(target), perturb)
443+
# target .+= _perturbIfNecessary(getFactorType(ccwl), length(target), perturb)
329444

330445
sfidx = ccwl.varidx[]
331446
# do the parameter search over defined decision variables using Minimization
@@ -345,7 +460,7 @@ function _solveCCWNumeric!(
345460
end
346461

347462
# insert result back at the correct variable element location
348-
ccwl.varValsAll[sfidx][smpid][ccwl.partialDims] .= retval
463+
copyto!(ccwl.varValsAll[sfidx][smpid][ccwl.partialDims], retval)
349464

350465
return nothing
351466
end

test/testSpecialEuclidean2Mani.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ import Rotations as _Rot
99

1010
## define new local variable types for testing
1111

12-
@defVariable TranslationGroup2 TranslationGroup(2) [0.0, 0.0]
12+
@defVariable TranslationGroup2 TranslationGroup(2) @SVector[0.0, 0.0]
1313

14-
# @defVariable SpecialEuclidean2 SpecialEuclidean(2) ArrayPartition(@MVector([0.0,0.0]), @MMatrix([1.0 0.0; 0.0 1.0]))
15-
@defVariable SpecialEuclidean2 SpecialEuclidean(2) ArrayPartition([0.0,0.0], [1.0 0.0; 0.0 1.0])
14+
@defVariable SpecialEuclidean2 SpecialEuclidean(2) ArrayPartition(@SVector([0.0,0.0]), @SMatrix([1.0 0.0; 0.0 1.0]))
15+
# @defVariable SpecialEuclidean2 SpecialEuclidean(2) ArrayPartition([0.0,0.0], [1.0 0.0; 0.0 1.0])
1616

1717
##
1818

@@ -22,11 +22,12 @@ import Rotations as _Rot
2222
M = getManifold(SpecialEuclidean2)
2323
@test M == SpecialEuclidean(2)
2424
pT = getPointType(SpecialEuclidean2)
25-
@test pT == ArrayPartition{Float64,Tuple{Vector{Float64}, Matrix{Float64}}}
25+
# @test pT == ArrayPartition{Float64,Tuple{Vector{Float64}, Matrix{Float64}}}
2626
# @test pT == ArrayPartition{Tuple{MVector{2, Float64}, MMatrix{2, 2, Float64, 4}}}
27+
@test pT == ArrayPartition{Float64, Tuple{SVector{2, Float64}, SMatrix{2, 2, Float64, 4}}}
2728
= getPointIdentity(SpecialEuclidean2)
2829
# @test_broken pϵ == ArrayPartition(@MVector([0.0,0.0]), @MMatrix([1.0 0.0; 0.0 1.0]))
29-
@test all(isapprox.(pϵ,ArrayPartition([0.0,0.0], [1.0 0.0; 0.0 1.0])))
30+
@test all(isapprox.(pϵ,ArrayPartition(SA[0.0,0.0], SA[1.0 0.0; 0.0 1.0])))
3031

3132
@test is_point(getManifold(SpecialEuclidean2), getPointIdentity(SpecialEuclidean2))
3233

@@ -38,6 +39,7 @@ v0 = addVariable!(fg, :x0, SpecialEuclidean2)
3839
# mp = ManifoldPrior(SpecialEuclidean(2), ArrayPartition(@MVector([0.0,0.0]), @MMatrix([1.0 0.0; 0.0 1.0])), MvNormal([0.01, 0.01, 0.01]))
3940
# mp = ManifoldPrior(SpecialEuclidean(2), ArrayPartition(@MVector([0.0,0.0]), @MMatrix([1.0 0.0; 0.0 1.0])), MvNormal(Diagonal(abs2.([0.01, 0.01, 0.01]))))
4041
mp = ManifoldPrior(SpecialEuclidean(2), ArrayPartition([0.0,0.0], [1.0 0.0; 0.0 1.]), MvNormal(Diagonal(abs2.([0.01, 0.01, 0.01]))))
42+
mp = ManifoldPrior(SpecialEuclidean(2), ArrayPartition(SA[0.0,0.0], SA[1.0 0.0; 0.0 1.]), MvNormal(Diagonal(abs2.(SA[0.01, 0.01, 0.01]))))
4143
p = addFactor!(fg, [:x0], mp)
4244

4345

@@ -47,18 +49,18 @@ doautoinit!(fg, :x0)
4749

4850
##
4951
vnd = getVariableSolverData(fg, :x0)
50-
@test all(isapprox.(mean(vnd.val), ArrayPartition([0.0,0.0], [1.0 0.0; 0.0 1.0]), atol=0.1))
52+
@test all(isapprox.(mean(vnd.val), ArrayPartition(SA[0.0,0.0], SA[1.0 0.0; 0.0 1.0]), atol=0.1))
5153
@test all(is_point.(Ref(M), vnd.val))
5254

5355
##
5456
v1 = addVariable!(fg, :x1, SpecialEuclidean2)
55-
mf = ManifoldFactor(SpecialEuclidean(2), MvNormal([1,2,pi/4], [0.01,0.01,0.01]))
57+
mf = ManifoldFactor(SpecialEuclidean(2), MvNormal(SA[1,2,pi/4], SA[0.01,0.01,0.01]))
5658
f = addFactor!(fg, [:x0, :x1], mf)
5759

5860
doautoinit!(fg, :x1)
5961

6062
vnd = getVariableSolverData(fg, :x1)
61-
@test all(isapprox(M, mean(M,vnd.val), ArrayPartition([1.0,2.0], [0.7071 -0.7071; 0.7071 0.7071]), atol=0.1))
63+
@test all(isapprox(M, mean(M,vnd.val), ArrayPartition(SA[1.0,2.0], SA[0.7071 -0.7071; 0.7071 0.7071]), atol=0.1))
6264
@test all(is_point.(Ref(M), vnd.val))
6365

6466
##
@@ -67,15 +69,15 @@ solveTree!(fg; smtasks, verbose=true) #, recordcliqs=ls(fg))
6769
# hists = fetchCliqHistoryAll!(smtasks);
6870

6971
vnd = getVariableSolverData(fg, :x0)
70-
@test all(isapprox.(mean(vnd.val), ArrayPartition([0.0,0.0], [1.0 0.0; 0.0 1.0]), atol=0.1))
72+
@test all(isapprox.(mean(vnd.val), ArrayPartition(SA[0.0,0.0], SA[1.0 0.0; 0.0 1.0]), atol=0.1))
7173
@test all(is_point.(Ref(M), vnd.val))
7274

7375
vnd = getVariableSolverData(fg, :x1)
74-
@test all(isapprox.(mean(vnd.val), ArrayPartition([1.0,2.0], [0.7071 -0.7071; 0.7071 0.7071]), atol=0.1))
76+
@test all(isapprox.(mean(vnd.val), ArrayPartition(SA[1.0,2.0], SA[0.7071 -0.7071; 0.7071 0.7071]), atol=0.1))
7577
@test all(is_point.(Ref(M), vnd.val))
7678

7779
v1 = addVariable!(fg, :x2, SpecialEuclidean2)
78-
mf = ManifoldFactor(SpecialEuclidean(2), MvNormal([1,2,pi/4], [0.01,0.01,0.01]))
80+
mf = ManifoldFactor(SpecialEuclidean(2), MvNormal(SA[1,2,pi/4], SA[0.01,0.01,0.01]))
7981
f = addFactor!(fg, [:x1, :x2], mf)
8082

8183
##

0 commit comments

Comments
 (0)