Skip to content

Commit 180d89d

Browse files
committed
Basic Riemannian Levenberg-Marquardt solve
1 parent c95beec commit 180d89d

File tree

7 files changed

+328
-10
lines changed

7 files changed

+328
-10
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
1717
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1818
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
20+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
2021
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2122
FunctionalStateMachine = "3e9e306e-7e3c-11e9-12d2-8f8f67a2f951"
2223
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
@@ -26,6 +27,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2627
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
2728
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
2829
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
30+
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
2931
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
3032
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
3133
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"

src/Factors/GenericFunctions.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,14 @@ function getSample(cf::CalcFactor{<:ManifoldFactor{M, Z}}) where {M, Z}
8585
else
8686
ret = rand(cf.factor.Z)
8787
end
88+
# ret = sampleTangent(M, cf.factor.Z)
8889
#return coordinates as we do not know the point here #TODO separate Lie group
8990
return ret
9091
end
9192

9293
# function (cf::CalcFactor{<:ManifoldFactor{<:AbstractDecoratorManifold}})(Xc, p, q)
93-
function (cf::CalcFactor{<:ManifoldFactor})(Xc, p, q)
94-
# function (cf::ManifoldFactor)(X, p, q)
95-
M = cf.manifold # .factor.M
96-
# M = cf.M
97-
X = hat(M, p, Xc)
98-
return distanceTangent2Point(M, X, p, q)
94+
function (cf::CalcFactor{<:ManifoldFactor})(X, p, q)
95+
return distanceTangent2Point(cf.manifold, X, p, q)
9996
end
10097

10198
## ======================================================================================

src/IncrementalInference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ include("CliqueStateMachine/services/CliqStateMachineUtils.jl")
213213
#EXPERIMENTAL parametric
214214
include("ParametricCSMFunctions.jl")
215215
include("ParametricUtils.jl")
216+
include("ParametricManoptDev.jl")
216217
include("services/MaxMixture.jl")
217218

218219
#X-stroke

src/ManifoldsExtentions.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ function Manifolds.exp!(M::NPowerManifold, q, p, X)
8888
return q
8989
end
9090

91+
function Manifolds.allocate_result(M::NPowerManifold, f, x...)
92+
if length(x) == 0
93+
return [Manifolds.allocate_result(M.manifold, f) for _ in Manifolds.get_iterator(M)]
94+
else
95+
return copy(x[1])
96+
end
97+
end
98+
99+
function Manifolds.allocate_result(::NPowerManifold, ::typeof(get_vector), p, X)
100+
return copy(p)
101+
end
102+
91103
## ================================================================================================
92104
## ArrayPartition getPointIdentity (identity_element)
93105
## ================================================================================================

src/ParametricManoptDev.jl

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
using Manopt
2+
using FiniteDiff
3+
# using ForwardDiff
4+
# using Zygote
5+
6+
##
7+
8+
struct CalcFactorManopt{
9+
D,
10+
L,
11+
FT <: AbstractFactor,
12+
M <: AbstractManifold,
13+
MEAS <: AbstractArray,
14+
}
15+
faclbl::Symbol
16+
calcfactor!::CalcFactor{FT, Nothing, Nothing, Tuple{}, M}
17+
varOrder::Vector{Symbol}
18+
varOrderIdxs::Vector{Int}
19+
meas::MEAS
20+
::SMatrix{D, D, Float64, L}
21+
end
22+
23+
function CalcFactorManopt(fg, fct::DFGFactor, varIntLabel)
24+
fac_func = getFactorType(fct)
25+
varOrder = getVariableOrder(fct)
26+
27+
varOrderIdxs = getindex.(Ref(varIntLabel), varOrder)
28+
29+
M = getManifold(getFactorType(fct))
30+
31+
dims = manifold_dimension(M)
32+
ϵ = getPointIdentity(M)
33+
34+
_meas, _iΣ = getMeasurementParametric(fct)
35+
if fac_func isa ManifoldPrior
36+
meas = _meas # already a point on M
37+
elseif fac_func isa AbstractPrior
38+
X = get_vector(M, ϵ, _meas, DefaultOrthogonalBasis())
39+
meas = exp(M, ϵ, X) # convert to point on M
40+
else
41+
# its a relative factor so should be a tangent vector
42+
meas = convert(typeof(ϵ), get_vector(M, ϵ, _meas, DefaultOrthogonalBasis()))
43+
end
44+
45+
# make sure its an SMatrix
46+
= convert(SMatrix{dims, dims}, _iΣ)
47+
48+
# cache = preambleCache(fg, getVariable.(fg, varOrder), getFactorType(fct))
49+
50+
calcf = CalcFactor(
51+
getFactorMechanics(fac_func),
52+
0,
53+
nothing,
54+
true,
55+
nothing,#cache,
56+
(), #DFGVariable[],
57+
0,
58+
getManifold(fac_func),
59+
)
60+
return CalcFactorManopt(fct.label, calcf, varOrder, varOrderIdxs, meas, iΣ)
61+
end
62+
63+
function (cfm::CalcFactorManopt)(p)
64+
meas = cfm.meas
65+
idx = cfm.varOrderIdxs
66+
return cfm.calcfactor!(meas, p[idx]...)
67+
end
68+
69+
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
70+
struct CostF_RLM!{T}
71+
points::Vector{T}
72+
costfuns::Vector{<:CalcFactorManopt}
73+
end
74+
75+
function CostF_RLM!(costfuns::Vector{<:CalcFactorManopt}, frontals_p::Vector{T}, separators_p::Vector{T}) where T
76+
points::Vector{T} = vcat(frontals_p, separators_p)
77+
return CostF_RLM!(points, costfuns)
78+
end
79+
80+
function (cfm::CostF_RLM!)(M::AbstractManifold, x, p::Vector{T}) where T
81+
cfm.points[1:length(p)] .= p
82+
return x .= mapreduce(f -> f(cfm.points), vcat, cfm.costfuns)
83+
end
84+
85+
# jacobian of function for Riemannian Levenberg-Marquardt
86+
struct JacF_RLM!{CF, T}
87+
costF!::CF
88+
X0::Vector{Float64}
89+
X::T
90+
q::T
91+
res::Vector{Float64}
92+
end
93+
94+
function JacF_RLM!(M, costF!; basis_domain::AbstractBasis = DefaultOrthogonalBasis())
95+
96+
p = costF!.points
97+
98+
res = mapreduce(f -> f(p), vcat, costF!.costfuns)
99+
100+
X0 = zeros(manifold_dimension(M))
101+
102+
X = get_vector(M, p, X0, basis_domain)
103+
104+
q = exp(M, p, X)
105+
106+
# J = FiniteDiff.finite_difference_jacobian(
107+
# Xc -> costF!(M, res, exp!(M, q, p, get_vector!(M, X, p, Xc, basis_domain))),
108+
# X0,
109+
# )
110+
111+
return JacF_RLM!(costF!, X0, X, q, res)
112+
113+
end
114+
115+
function (jacF!::JacF_RLM!)(
116+
M::AbstractManifold,
117+
J,
118+
p;
119+
basis_domain::AbstractBasis = DefaultOrthogonalBasis(),
120+
)
121+
122+
X0 = jacF!.X0
123+
X = jacF!.X
124+
q = jacF!.q
125+
126+
fill!(X0, 0)
127+
128+
# J .= FiniteDiff.finite_difference_jacobian(
129+
# Xc -> jacF!.costF!(M, jacF!.res, exp!(M, q, p, get_vector!(M, X, p, Xc, basis_domain))),
130+
# X0,
131+
# )
132+
FiniteDiff.finite_difference_jacobian!(
133+
J,
134+
(res,Xc) -> jacF!.costF!(M, res, exp!(M, q, p, get_vector!(M, X, p, Xc, basis_domain))),
135+
X0,
136+
)
137+
return J
138+
end
139+
140+
141+
function solve_RLM(
142+
fg,
143+
frontals::Vector{Symbol} = ls(fg),
144+
separators::Vector{Symbol} = setdiff(ls(fg), frontals);
145+
)
146+
@error "#FIXME, use covariances" maxlog=1
147+
148+
# get the subgraph formed by all frontals, separators and fully connected factors
149+
varlabels = union(frontals, separators)
150+
faclabels = sortDFG(setdiff(getNeighborhood(fg, varlabels, 1), varlabels))
151+
152+
filter!(faclabels) do fl
153+
return issubset(getVariableOrder(fg, fl), varlabels)
154+
end
155+
156+
facs = getFactor.(fg, faclabels)
157+
158+
# so the subgraph consists of varlabels(frontals + separators) and faclabels
159+
160+
varIntLabel = OrderedDict(zip(varlabels, collect(1:length(varlabels))))
161+
162+
# varIntLabel_frontals = filter(p->first(p) in frontals, varIntLabel)
163+
# varIntLabel_separators = filter(p->first(p) in separators, varIntLabel)
164+
165+
calcfacs = CalcFactorManopt.(fg, facs, Ref(varIntLabel))
166+
167+
168+
# get the manifold and variable types
169+
frontal_vars = getVariable.(fg, frontals)
170+
vartypes, vartypecount, vartypeslist = getVariableTypesCount(frontal_vars)
171+
172+
PMs = map(vartypes) do vartype
173+
N = vartypecount[vartype]
174+
G = getManifold(vartype)
175+
return IIF.NPowerManifold(G, N)
176+
end
177+
M = ProductManifold(PMs...)
178+
179+
#
180+
#FIXME
181+
@assert length(M.manifolds) == 1 "#FIXME, this only works with 1 manifold type component"
182+
MM = M.manifolds[1]
183+
184+
# inital values and separators from fg
185+
fro_p = first.(getVal.(frontal_vars, solveKey = :parametric))
186+
sep_p::Vector{eltype(fro_p)} = first.(getVal.(fg, separators, solveKey = :parametric))
187+
188+
189+
if false
190+
# non-in-place version updated bolow to in-place
191+
fullp::Vector{eltype(fro_p)} = [fro_p; sep_p]
192+
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
193+
function costF_RLM(M::AbstractManifold, p::Vector{T}) where T
194+
fullp[1:length(p)] .= p
195+
return Vector(mapreduce(f -> f(fullp), vcat, calcfacs))
196+
end
197+
198+
# jacobian of function for Riemannian Levenberg-Marquardt
199+
function jacF_RLM(
200+
M::AbstractManifold,
201+
p;
202+
basis_domain::AbstractBasis = DefaultOrthogonalBasis(),
203+
)
204+
X0 = zeros(manifold_dimension(M))
205+
J = FiniteDiff.finite_difference_jacobian(
206+
x -> costF_RLM(M, exp(M, p, get_vector(M, p, x, basis_domain))),
207+
X0,
208+
)
209+
# J = ForwardDiff.jacobian(
210+
# x -> costF_RLM(M, exp(M, p, get_vector(M, p, x, basis_domain))),
211+
# X0,
212+
# )
213+
return J
214+
end
215+
216+
# 0.296639 seconds (2.46 M allocations: 164.722 MiB, 12.83% gc time)
217+
p0 = deepcopy(fro_p)
218+
lm_r = LevenbergMarquardt(MM, costF_RLM, jacF_RLM, p0)
219+
220+
# 81.185117 seconds (647.20 M allocations: 41.680 GiB, 8.61% gc time)
221+
else
222+
# 74.420872 seconds (567.70 M allocations: 34.698 GiB, 8.30% gc time, 0.66% compilation time)
223+
#cost and jacobian functions
224+
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
225+
costF! = CostF_RLM!(calcfacs, fro_p, sep_p)
226+
# jacobian of function for Riemannian Levenberg-Marquardt
227+
jacF! = JacF_RLM!(MM, costF!)
228+
229+
num_components = length(jacF!.res)
230+
231+
p0 = deepcopy(fro_p)
232+
lm_r = LevenbergMarquardt(MM, costF!, jacF!, p0, num_components; evaluation=InplaceEvaluation())
233+
end
234+
235+
return vartypeslist, lm_r
236+
end
237+
238+
function autoinitParametricManopt!(
239+
fg,
240+
varorderIds = getInitOrderParametric(fg);
241+
reinit = false,
242+
)
243+
@showprogress for vIdx in varorderIds
244+
autoinitParametricManopt!(fg, vIdx; reinit)
245+
end
246+
return nothing
247+
end
248+
249+
function autoinitParametricManopt!(dfg::AbstractDFG, initme::Symbol; kwargs...)
250+
return autoinitParametricManopt!(dfg, getVariable(dfg, initme); kwargs...)
251+
end
252+
253+
function autoinitParametricManopt!(
254+
dfg::AbstractDFG,
255+
xi::DFGVariable;
256+
solveKey = :parametric,
257+
reinit::Bool = false,
258+
kwargs...,
259+
)
260+
#
261+
262+
initme = getLabel(xi)
263+
vnd = getSolverData(xi, solveKey)
264+
# don't initialize a variable more than once
265+
if reinit || !isInitialized(xi, solveKey)
266+
267+
# frontals - initme
268+
# separators - inifrom
269+
270+
initfrom = ls2(dfg, initme)
271+
filter!(initfrom) do vl
272+
return isInitialized(dfg, vl, solveKey)
273+
end
274+
275+
vartypeslist, lm_r = solve_RLM(dfg, [initme], initfrom)
276+
277+
val = lm_r[1]
278+
vnd::VariableNodeData = getSolverData(xi, solveKey)
279+
vnd.val[1] = val
280+
281+
282+
# val = lm_r[1]
283+
# cov = ...
284+
# updateSolverDataParametric!(vnd, val, cov)
285+
286+
vnd.initialized = true
287+
#fill in ppe as mean
288+
Xc::Vector{Float64} = collect(getCoordinates(getVariableType(xi), val))
289+
ppe = MeanMaxPPE(:parametric, Xc, Xc, Xc)
290+
getPPEDict(xi)[:parametric] = ppe
291+
292+
result = vartypeslist, lm_r
293+
294+
else
295+
result = nothing
296+
end
297+
298+
return result#isInitialized(xi, solveKey)
299+
end

0 commit comments

Comments
 (0)