Skip to content

Commit eaeba94

Browse files
committed
solve_RLM_sparse wip
1 parent ebc4c85 commit eaeba94

File tree

3 files changed

+144
-11
lines changed

3 files changed

+144
-11
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ version = "0.34.0"
88
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
99
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
1010
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
11+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1112
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1213
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1314
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -40,6 +41,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4041
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
4142
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
4243
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
44+
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
4345
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4446
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4547
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

src/ManifoldsExtentions.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ function Manifolds.exp!(M::NPowerManifold, q, p, X)
8888
return q
8989
end
9090

91+
function Manifolds.compose!(M::NPowerManifold, x, p, q)
92+
rep_size = representation_size(M.manifold)
93+
for i in Manifolds.get_iterator(M)
94+
x[i] = compose(
95+
M.manifold,
96+
Manifolds._read(M, rep_size, p, i),
97+
Manifolds._read(M, rep_size, q, i),
98+
)
99+
end
100+
return x
101+
end
102+
91103
function Manifolds.allocate_result(M::NPowerManifold, f, x...)
92104
if length(x) == 0
93105
return [Manifolds.allocate_result(M.manifold, f) for _ in Manifolds.get_iterator(M)]

src/ParametricManoptDev.jl

Lines changed: 130 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
using Manopt
22
using FiniteDiff
3+
using SparseDiffTools
4+
using BlockArrays
5+
using SparseArrays
6+
37
# using ForwardDiff
48
# using Zygote
59

@@ -18,9 +22,10 @@ struct CalcFactorManopt{
1822
varOrderIdxs::Vector{Int}
1923
meas::MEAS
2024
::SMatrix{D, D, Float64, L}
25+
sqrt_iΣ::SMatrix{D, D, Float64, L}
2126
end
2227

23-
function CalcFactorManopt(fg, fct::DFGFactor, varIntLabel)
28+
function CalcFactorManopt(fct::DFGFactor, varIntLabel)
2429
fac_func = getFactorType(fct)
2530
varOrder = getVariableOrder(fct)
2631

@@ -45,6 +50,7 @@ function CalcFactorManopt(fg, fct::DFGFactor, varIntLabel)
4550
# make sure its an SMatrix
4651
= convert(SMatrix{dims, dims}, _iΣ)
4752

53+
sqrt_iΣ = convert(SMatrix{dims, dims}, sqrt(iΣ))
4854
# cache = preambleCache(fg, getVariable.(fg, varOrder), getFactorType(fct))
4955

5056
calcf = CalcFactor(
@@ -57,13 +63,14 @@ function CalcFactorManopt(fg, fct::DFGFactor, varIntLabel)
5763
0,
5864
getManifold(fac_func),
5965
)
60-
return CalcFactorManopt(fct.label, calcf, varOrder, varOrderIdxs, meas, iΣ)
66+
return CalcFactorManopt(fct.label, calcf, varOrder, varOrderIdxs, meas, iΣ, sqrt_iΣ)
6167
end
6268

6369
function (cfm::CalcFactorManopt)(p)
6470
meas = cfm.meas
6571
idx = cfm.varOrderIdxs
6672
return cfm.calcfactor!(meas, p[idx]...)
73+
# return cfm.sqrt_iΣ * cfm.calcfactor!(meas, p[idx]...)
6774
end
6875

6976
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
@@ -95,7 +102,7 @@ function JacF_RLM!(M, costF!; basis_domain::AbstractBasis = DefaultOrthogonalBas
95102

96103
p = costF!.points
97104

98-
res = mapreduce(f -> f(p), vcat, costF!.costfuns)
105+
res = Vector(mapreduce(f -> f(p), vcat, costF!.costfuns))
99106

100107
X0 = zeros(manifold_dimension(M))
101108

@@ -125,25 +132,52 @@ function (jacF!::JacF_RLM!)(
125132

126133
fill!(X0, 0)
127134

128-
# J .= FiniteDiff.finite_difference_jacobian(
129-
# Xc -> jacF!.costF!(M, jacF!.res, exp!(M, q, p, get_vector!(M, X, p, Xc, basis_domain))),
130-
# X0,
131-
# )
135+
# TODO maybe move to struct
136+
colorvec = matrix_colors(J)
137+
138+
# ϵ = getPointIdentity(M)
139+
# function jaccost(res, Xc)
140+
# exp!(M, q, ϵ, get_vector!(M, X, p, Xc, basis_domain))
141+
# compose!(M, q, p, q)
142+
# jacF!.costF!(M, res, q)
143+
# end
144+
132145
FiniteDiff.finite_difference_jacobian!(
133146
J,
134147
(res,Xc) -> jacF!.costF!(M, res, exp!(M, q, p, get_vector!(M, X, p, Xc, basis_domain))),
135-
X0,
148+
X0;
149+
colorvec
136150
)
151+
# @warn "1" Matrix(J)[1:3,1:3]
152+
@warn "2" Matrix(J)[4:6,1:6]
137153
return J
138154
end
139155

156+
function getSparsityPattern(fg)
157+
biadj = getBiadjacencyMatrix(fg)
158+
159+
vdims = getDimension.(getVariable.(fg, biadj.varLabels))
160+
fdims = getDimension.(getFactor.(fg, biadj.facLabels))
161+
162+
sm = map(eachindex(biadj.B)) do i
163+
vdim = vdims[i[2]]
164+
fdim = fdims[i[1]]
165+
if biadj.B[i] > 0
166+
trues(fdim,vdim)
167+
else
168+
falses(fdim,vdim)
169+
end
170+
end
171+
172+
return SparseMatrixCSC(mortar(sm))
173+
174+
end
140175

141176
function solve_RLM(
142177
fg,
143178
frontals::Vector{Symbol} = ls(fg),
144179
separators::Vector{Symbol} = setdiff(ls(fg), frontals);
145180
)
146-
@error "#FIXME, use covariances" maxlog=1
147181

148182
# get the subgraph formed by all frontals, separators and fully connected factors
149183
varlabels = union(frontals, separators)
@@ -162,7 +196,7 @@ function solve_RLM(
162196
# varIntLabel_frontals = filter(p->first(p) in frontals, varIntLabel)
163197
# varIntLabel_separators = filter(p->first(p) in separators, varIntLabel)
164198

165-
calcfacs = CalcFactorManopt.(fg, facs, Ref(varIntLabel))
199+
calcfacs = CalcFactorManopt.(facs, Ref(varIntLabel))
166200

167201

168202
# get the manifold and variable types
@@ -194,7 +228,92 @@ function solve_RLM(
194228
num_components = length(jacF!.res)
195229

196230
p0 = deepcopy(fro_p)
197-
lm_r = LevenbergMarquardt(MM, costF!, jacF!, p0, num_components; evaluation=InplaceEvaluation())
231+
232+
initial_residual_values = zeros(num_components)
233+
initial_jacF = zeros(num_components, manifold_dimension(MM))
234+
#
235+
#HEX solve
236+
# sparse J 0.025235 seconds (133.65 k allocations: 9.964 MiB
237+
# dense J 0.022079 seconds (283.54 k allocations: 18.146 MiB)
238+
239+
lm_r = LevenbergMarquardt(
240+
MM,
241+
costF!,
242+
jacF!,
243+
p0,
244+
num_components;
245+
evaluation=InplaceEvaluation(),
246+
initial_residual_values,
247+
initial_jacF,
248+
)
249+
250+
return vartypeslist, lm_r
251+
end
252+
253+
function solve_RLM_sparse(fg)
254+
255+
# get the subgraph formed by all frontals, separators and fully connected factors
256+
varlabels = ls(fg)
257+
faclabels = lsf(fg)
258+
259+
facs = getFactor.(fg, faclabels)
260+
261+
# so the subgraph consists of varlabels(frontals + separators) and faclabels
262+
263+
varIntLabel = OrderedDict(zip(varlabels, collect(1:length(varlabels))))
264+
265+
calcfacs = CalcFactorManopt.(facs, Ref(varIntLabel))
266+
267+
# get the manifold and variable types
268+
vars = getVariable.(fg, varlabels)
269+
vartypes, vartypecount, vartypeslist = getVariableTypesCount(vars)
270+
271+
PMs = map(vartypes) do vartype
272+
N = vartypecount[vartype]
273+
G = getManifold(vartype)
274+
return IIF.NPowerManifold(G, N)
275+
end
276+
M = ProductManifold(PMs...)
277+
278+
#
279+
#FIXME
280+
@assert length(M.manifolds) == 1 "#FIXME, this only works with 1 manifold type component"
281+
MM = M.manifolds[1]
282+
283+
# inital values and separators from fg
284+
fro_p = first.(getVal.(vars, solveKey = :parametric))
285+
sep_p = eltype(fro_p)[]
286+
287+
#cost and jacobian functions
288+
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
289+
costF! = CostF_RLM!(calcfacs, fro_p, sep_p)
290+
# jacobian of function for Riemannian Levenberg-Marquardt
291+
jacF! = JacF_RLM!(MM, costF!)
292+
293+
num_components = length(jacF!.res)
294+
295+
p0 = deepcopy(fro_p)
296+
297+
initial_residual_values = zeros(num_components)
298+
initial_jacF = Float64.(getSparsityPattern(fg))
299+
300+
#HEX solve
301+
# sparse J 0.025235 seconds (133.65 k allocations: 9.964 MiB
302+
# dense J 0.022079 seconds (283.54 k allocations: 18.146 MiB)
303+
304+
# 9.125818 seconds (86.35 M allocations: 6.412 GiB, 14.34% gc time)
305+
# 0.841720 seconds (7.96 M allocations: 751.825 MiB)
306+
307+
lm_r = LevenbergMarquardt(
308+
MM,
309+
costF!,
310+
jacF!,
311+
p0,
312+
num_components;
313+
evaluation=InplaceEvaluation(),
314+
initial_residual_values,
315+
initial_jacF,
316+
)
198317

199318
return vartypeslist, lm_r
200319
end

0 commit comments

Comments
 (0)