Skip to content

Commit c49e4b0

Browse files
authored
Swop parametric optim for manopt and better convolution performance (#1784)
* swop parametric optim for manopt * improve non-parametric performance * Standardise to use getManifold * covariance calc try catch on hessian inverse
1 parent f1707b0 commit c49e4b0

File tree

11 files changed

+94
-68
lines changed

11 files changed

+94
-68
lines changed

benchmark/runbenchmarks.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ fg.solverParams.graphinit=false;
3535
addVariable!(fg, :x0, Pose2);
3636
addFactor!(fg, [:x0], PriorPose2(MvNormal([0.0,0,0], diagm([0.1,0.1,0.01].^2))));
3737

38-
r = @timed IIF.solveGraphParametric!(fg; init=false);
38+
r = @timed IIF.solveGraphParametric!(fg; init=false, is_sparse=false);
3939

4040
timed = [r];
4141

@@ -44,7 +44,7 @@ for i = 1:14
4444
to = Symbol("x",i)
4545
addVariable!(fg, to, Pose2)
4646
addFactor!(fg, [fr,to], Pose2Pose2(MvNormal([10.0,0,pi/3], diagm([0.5,0.5,0.05].^2))))
47-
r = @timed IIF.solveGraphParametric!(fg; init=false);
47+
r = @timed IIF.solveGraphParametric!(fg; init=false, is_sparse=false);
4848
push!(timed, r)
4949
end
5050

@@ -53,7 +53,7 @@ addVariable!(fg, :l1, RoME.Point2, tags=[:LANDMARK;]);
5353
addFactor!(fg, [:x0; :l1], Pose2Point2BearingRange(Normal(0.0,0.1), Normal(20.0, 1.0)));
5454
addFactor!(fg, [:x6; :l1], Pose2Point2BearingRange(Normal(0.0,0.1), Normal(20.0, 1.0)));
5555

56-
r = @timed IIF.solveGraphParametric!(fg; init=false);
56+
r = @timed IIF.solveGraphParametric!(fg; init=false, is_sparse=false);
5757
push!(timed, r);
5858

5959
getproperty.(timed, :time)

src/Factors/Circular.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ DFG.getManifold(::CircularCircular) = RealCircleGroup()
2323

2424
function (cf::CalcFactor{<:CircularCircular})(X, p, q)
2525
#
26-
M = cf.manifold # getManifold(cf.factor)
26+
M = getManifold(cf)
2727
return distanceTangent2Point(M, X, p, q)
2828
end
2929

@@ -68,7 +68,7 @@ function getSample(cf::CalcFactor{<:PriorCircular})
6868
end
6969

7070
function (cf::CalcFactor{<:PriorCircular})(m, p)
71-
M = cf.manifold # getManifold(cf.factor)
71+
M = getManifold(cf)
7272
Xc = vee(M, p, log(M, p, m))
7373
return Xc
7474
end

src/Factors/GenericFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ DFG.getManifold(f::ManifoldPriorPartial) = f.M
243243

244244
function getSample(cf::CalcFactor{<:ManifoldPriorPartial})
245245
Z = cf.factor.Z
246-
M = cf.manifold # getManifold(cf.factor)
246+
M = getManifold(cf)
247247
partial = collect(cf.factor.partial)
248248

249249
return (samplePointPartial(M, Z, partial),)

src/manifolds/services/ManifoldSampling.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ See also: [`getMeasurementParametric`](@ref)
119119
function getSample end
120120

121121
function getSample(cf::CalcFactor{<:AbstractPrior})
122-
M = cf.manifold # getManifold(cf.factor)
122+
M = getManifold(cf)
123123
if hasfield(typeof(cf.factor), :Z)
124124
X = samplePoint(M, cf.factor.Z)
125125
else
@@ -132,7 +132,7 @@ function getSample(cf::CalcFactor{<:AbstractPrior})
132132
end
133133

134134
function getSample(cf::CalcFactor{<:AbstractRelative})
135-
M = cf.manifold # getManifold(cf.factor)
135+
M =getManifold(cf)
136136
if hasfield(typeof(cf.factor), :Z)
137137
X = sampleTangent(M, cf.factor.Z)
138138
else

src/parametric/services/ParametricCSMFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
2626
# store the cliqSubFg for later debugging
2727
_dbgCSMSaveSubFG(csmc, "fg_beforeupsolve")
2828

29-
vardict, result, varIds, Σ = solveGraphParametric(csmc.cliqSubFg)
29+
vardict, result, varIds, Σ = solveGraphParametricOptim(csmc.cliqSubFg)
3030

3131
logCSM(csmc, "$(csmc.cliq.id) vars $(keys(varIds))")
3232
# @info "$(csmc.cliq.id) Σ $(Σ)"

src/parametric/services/ParametricManopt.jl

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,13 @@ function covarianceFiniteDiff(M, jacF!::JacF_RLM!, p0)
297297
H = FiniteDiff.finite_difference_hessian(costf, X0)
298298

299299
# inv(H)
300-
Σ = Matrix(H) \ Matrix{eltype(H)}(I, size(H)...)
301-
# sqrt.(diag(Σ))
300+
Σ = try
301+
Matrix(H) \ Matrix{eltype(H)}(I, size(H)...)
302+
catch ex #TODO only catch correct exception and try with pinv as fallback in certain cases.
303+
@warn "Hessian inverse failed" ex
304+
# Σ = pinv(H)
305+
nothing
306+
end
302307
return Σ
303308
end
304309

@@ -490,27 +495,28 @@ end
490495
# new2 0.010764 seconds (34.61 k allocations: 3.111 MiB)
491496
# dense J 0.022079 seconds (283.54 k allocations: 18.146 MiB)
492497

493-
function autoinitParametricManopt!(
498+
function autoinitParametric!(
494499
fg,
495500
varorderIds = getInitOrderParametric(fg);
496501
reinit = false,
497502
kwargs...
498503
)
499504
@showprogress for vIdx in varorderIds
500-
autoinitParametricManopt!(fg, vIdx; reinit, kwargs...)
505+
autoinitParametric!(fg, vIdx; reinit, kwargs...)
501506
end
502507
return nothing
503508
end
504509

505-
function autoinitParametricManopt!(dfg::AbstractDFG, initme::Symbol; kwargs...)
506-
return autoinitParametricManopt!(dfg, getVariable(dfg, initme); kwargs...)
510+
function autoinitParametric!(dfg::AbstractDFG, initme::Symbol; kwargs...)
511+
return autoinitParametric!(dfg, getVariable(dfg, initme); kwargs...)
507512
end
508513

509-
function autoinitParametricManopt!(
514+
function autoinitParametric!(
510515
dfg::AbstractDFG,
511516
xi::DFGVariable;
512517
solveKey = :parametric,
513518
reinit::Bool = false,
519+
perturb_point::Bool=false,
514520
kwargs...,
515521
)
516522
#
@@ -528,12 +534,27 @@ function autoinitParametricManopt!(
528534
return isInitialized(dfg, vl, solveKey)
529535
end
530536

537+
vnd::VariableNodeData = getSolverData(xi, solveKey)
538+
539+
if perturb_point
540+
_M = getManifold(xi)
541+
p = vnd.val[1]
542+
vnd.val[1] = exp(
543+
_M,
544+
p,
545+
get_vector(
546+
_M,
547+
p,
548+
randn(manifold_dimension(_M))*10^-6,
549+
DefaultOrthogonalBasis()
550+
)
551+
)
552+
end
531553
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(dfg, [initme], initfrom; kwargs...)
532-
554+
533555
val = lm_r[1]
534-
vnd::VariableNodeData = getSolverData(xi, solveKey)
535556
vnd.val[1] = val
536-
557+
537558
!isnothing(Σ) && vnd.bw .= Σ
538559

539560
# updateSolverDataParametric!(vnd, val, Σ)
@@ -555,6 +576,7 @@ end
555576

556577

557578
##
579+
solveGraphParametric(args...; kwargs...) = solve_RLM(args...; kwargs...)
558580

559581
function DFG.solveGraphParametric!(
560582
fg::AbstractDFG,
@@ -578,7 +600,7 @@ function DFG.solveGraphParametric!(
578600

579601
updateParametricSolution!(fg, M, v, r, Σ)
580602

581-
return v,r, Σ
603+
return M, v, r, Σ
582604
end
583605

584606

src/parametric/services/ParametricUtils.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ function _getComponentsCovar(@nospecialize(PM::NPowerManifold), Σ::AbstractMatr
530530
return subsigmas
531531
end
532532

533-
function solveGraphParametric(
533+
function solveGraphParametricOptim(
534534
fg::AbstractDFG;
535535
verbose::Bool = false,
536536
computeCovariance::Bool = true,
@@ -828,8 +828,7 @@ end
828828
$SIGNATURES
829829
Add parametric solver to fg, batch solve using [`solveGraphParametric`](@ref) and update fg.
830830
"""
831-
function DFG.solveGraphParametric!(
832-
::Val{:Optim},
831+
function solveGraphParametricOptim!(
833832
fg::AbstractDFG;
834833
init::Bool = true,
835834
solveKey::Symbol = :parametric, # FIXME, moot since only :parametric used for parametric solves
@@ -845,7 +844,7 @@ function DFG.solveGraphParametric!(
845844
initParametricFrom!(fg, initSolveKey; parkey=solveKey)
846845
end
847846

848-
vardict, result, varIds, Σ = solveGraphParametric(fg; verbose, kwargs...)
847+
vardict, result, varIds, Σ = solveGraphParametricOptim(fg; verbose, kwargs...)
849848

850849
updateParametricSolution!(fg, vardict)
851850

@@ -970,7 +969,7 @@ function getInitOrderParametric(fg; startIdx::Symbol = lsfPriors(fg)[1])
970969
return order
971970
end
972971

973-
function autoinitParametric!(
972+
function autoinitParametricOptim!(
974973
fg,
975974
varorderIds = getInitOrderParametric(fg);
976975
reinit = false,
@@ -979,16 +978,16 @@ function autoinitParametric!(
979978
kwargs...
980979
)
981980
@showprogress for vIdx in varorderIds
982-
autoinitParametric!(fg, vIdx; reinit, algorithm, algorithmkwargs, kwargs...)
981+
autoinitParametricOptim!(fg, vIdx; reinit, algorithm, algorithmkwargs, kwargs...)
983982
end
984983
return nothing
985984
end
986985

987-
function autoinitParametric!(dfg::AbstractDFG, initme::Symbol; kwargs...)
988-
return autoinitParametric!(dfg, getVariable(dfg, initme); kwargs...)
986+
function autoinitParametricOptim!(dfg::AbstractDFG, initme::Symbol; kwargs...)
987+
return autoinitParametricOptim!(dfg, getVariable(dfg, initme); kwargs...)
989988
end
990989

991-
function autoinitParametric!(
990+
function autoinitParametricOptim!(
992991
dfg::AbstractDFG,
993992
xi::DFGVariable;
994993
solveKey = :parametric,

src/services/CalcFactor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ end
192192

193193
# the same as legacy, getManifold(ccwl.usrfnc!)
194194
getManifold(ccwl::CommonConvWrapper) = ccwl.manifold
195-
getManifold(cf::CalcFactor) = cf.manifold
195+
getManifold(cf::CalcFactor) = getManifold(cf.factor)
196196

197197
function _resizePointsVector!(
198198
vecP::AbstractVector{P},

src/services/NumericalCalculations.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ function _solveLambdaNumeric_original(
133133
return exp(M, ϵ, hat(M, ϵ, r.minimizer))
134134
end
135135

136+
function cost_optim(M, objResX, Xc)
137+
ϵ = getPointIdentity(M)
138+
X = get_vector(M, ϵ, Xc, DefaultOrthogonalBasis())
139+
p = exp(M, ϵ, X)
140+
residual = objResX(p)
141+
return sum(x->x^2, residual) #TODO maybe move this to CalcFactorNormSq
142+
end
143+
136144
# 1.355700 seconds (11.78 M allocations: 557.677 MiB, 6.96% gc time)
137145
function _solveCCWNumeric_test_SA(
138146
fcttype::Union{F, <:Mixture{N_, F, S, T}},
@@ -153,18 +161,9 @@ function _solveCCWNumeric_test_SA(
153161
X0c = zero(MVector{getDimension(M),Float64})
154162
X0c .= vee(M, u0, log(M, ϵ, u0))
155163

156-
#TODO check performance
157-
function cost(Xc)
158-
X = hat(M, ϵ, Xc)
159-
p = exp(M, ϵ, X)
160-
residual = objResX(p)
161-
# return sum(residual .^ 2)
162-
return sum(abs2, residual) #TODO maybe move this to CalcFactorNormSq
163-
end
164-
165164
alg = islen1 ? Optim.BFGS() : Optim.NelderMead()
166165

167-
r = Optim.optimize(cost, X0c, alg)
166+
r = Optim.optimize(x->cost_optim(M, objResX, x), X0c, alg)
168167
if !Optim.converged(r)
169168
# TODO find good way for a solve to store diagnostics about number of failed converges etc.
170169
@warn "Optim did not converge (maxlog=10):" r maxlog=10

src/services/VariableStatistics.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@ end
2020

2121
#TODO check performance and FIXME on makemutalbe might not be needed any more
2222
function calcStdBasicSpread(vartype::InferenceVariable, ptsArr::AbstractVector) # {P}) where {P}
23-
_makemutable(s) = s
24-
_makemutable(s::StaticArray{Tuple{S},T,N}) where {S,T,N} = MArray{Tuple{S},T,N,S}(s)
25-
_makemutable(s::SMatrix{N,N,T,D}) where {N,T,D} = MMatrix{N,N,T,D}(s)
23+
# _makemutable(s) = s
24+
# _makemutable(s::StaticArray{Tuple{S},T,N}) where {S,T,N} = MArray{Tuple{S},T,N,S}(s)
25+
# _makemutable(s::SMatrix{N,N,T,D}) where {N,T,D} = MMatrix{N,N,T,D}(s)
2626

2727
# FIXME, silly conversion since Manifolds.std internally replicates eltype ptsArr which doesn't work on StaticArrays
28-
σ = std(vartype, _makemutable.(ptsArr))
28+
# σ = std(vartype, _makemutable.(ptsArr))
29+
30+
μ = mean(vartype, ptsArr, GeodesicInterpolation())
31+
σ = std(vartype, ptsArr, μ)
2932

3033
#if no std yet, set to 1
3134
msst = 1e-10 < σ ? σ : 1.0

0 commit comments

Comments
 (0)