|
| 1 | +using Manopt |
| 2 | +using FiniteDiff |
| 3 | +# using ForwardDiff |
| 4 | +# using Zygote |
| 5 | + |
| 6 | +## |
| 7 | + |
| 8 | +struct CalcFactorManopt{ |
| 9 | + D, |
| 10 | + L, |
| 11 | + FT <: AbstractFactor, |
| 12 | + M <: AbstractManifold, |
| 13 | + MEAS <: AbstractArray, |
| 14 | +} |
| 15 | + faclbl::Symbol |
| 16 | + calcfactor!::CalcFactor{FT, Nothing, Nothing, Tuple{}, M} |
| 17 | + varOrder::Vector{Symbol} |
| 18 | + varOrderIdxs::Vector{Int} |
| 19 | + meas::MEAS |
| 20 | + iΣ::SMatrix{D, D, Float64, L} |
| 21 | +end |
| 22 | + |
| 23 | +function CalcFactorManopt(fg, fct::DFGFactor, varIntLabel) |
| 24 | + fac_func = getFactorType(fct) |
| 25 | + varOrder = getVariableOrder(fct) |
| 26 | + |
| 27 | + varOrderIdxs = getindex.(Ref(varIntLabel), varOrder) |
| 28 | + |
| 29 | + M = getManifold(getFactorType(fct)) |
| 30 | + |
| 31 | + dims = manifold_dimension(M) |
| 32 | + ϵ = getPointIdentity(M) |
| 33 | + |
| 34 | + _meas, _iΣ = getMeasurementParametric(fct) |
| 35 | + if fac_func isa ManifoldPrior |
| 36 | + meas = _meas # already a point on M |
| 37 | + elseif fac_func isa AbstractPrior |
| 38 | + X = get_vector(M, ϵ, _meas, DefaultOrthogonalBasis()) |
| 39 | + meas = exp(M, ϵ, X) # convert to point on M |
| 40 | + else |
| 41 | + # its a relative factor so should be a tangent vector |
| 42 | + meas = convert(typeof(ϵ), get_vector(M, ϵ, _meas, DefaultOrthogonalBasis())) |
| 43 | + end |
| 44 | + |
| 45 | + # make sure its an SMatrix |
| 46 | + iΣ = convert(SMatrix{dims, dims}, _iΣ) |
| 47 | + |
| 48 | + # cache = preambleCache(fg, getVariable.(fg, varOrder), getFactorType(fct)) |
| 49 | + |
| 50 | + calcf = CalcFactor( |
| 51 | + getFactorMechanics(fac_func), |
| 52 | + 0, |
| 53 | + nothing, |
| 54 | + true, |
| 55 | + nothing,#cache, |
| 56 | + (), #DFGVariable[], |
| 57 | + 0, |
| 58 | + getManifold(fac_func), |
| 59 | + ) |
| 60 | + return CalcFactorManopt(fct.label, calcf, varOrder, varOrderIdxs, meas, iΣ) |
| 61 | +end |
| 62 | + |
| 63 | +function (cfm::CalcFactorManopt)(p) |
| 64 | + meas = cfm.meas |
| 65 | + idx = cfm.varOrderIdxs |
| 66 | + return cfm.calcfactor!(meas, p[idx]...) |
| 67 | +end |
| 68 | + |
| 69 | +# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt |
| 70 | +struct CostF_RLM!{T} |
| 71 | + points::Vector{T} |
| 72 | + costfuns::Vector{<:CalcFactorManopt} |
| 73 | +end |
| 74 | + |
| 75 | +function CostF_RLM!(costfuns::Vector{<:CalcFactorManopt}, frontals_p::Vector{T}, separators_p::Vector{T}) where T |
| 76 | + points::Vector{T} = vcat(frontals_p, separators_p) |
| 77 | + return CostF_RLM!(points, costfuns) |
| 78 | +end |
| 79 | + |
| 80 | +function (cfm::CostF_RLM!)(M::AbstractManifold, x, p::Vector{T}) where T |
| 81 | + cfm.points[1:length(p)] .= p |
| 82 | + return x .= mapreduce(f -> f(cfm.points), vcat, cfm.costfuns) |
| 83 | +end |
| 84 | + |
| 85 | +# jacobian of function for Riemannian Levenberg-Marquardt |
| 86 | +struct JacF_RLM!{CF, T} |
| 87 | + costF!::CF |
| 88 | + X0::Vector{Float64} |
| 89 | + X::T |
| 90 | + q::T |
| 91 | + res::Vector{Float64} |
| 92 | +end |
| 93 | + |
| 94 | +function JacF_RLM!(M, costF!; basis_domain::AbstractBasis = DefaultOrthogonalBasis()) |
| 95 | + |
| 96 | + p = costF!.points |
| 97 | + |
| 98 | + res = mapreduce(f -> f(p), vcat, costF!.costfuns) |
| 99 | + |
| 100 | + X0 = zeros(manifold_dimension(M)) |
| 101 | + |
| 102 | + X = get_vector(M, p, X0, basis_domain) |
| 103 | + |
| 104 | + q = exp(M, p, X) |
| 105 | + |
| 106 | + # J = FiniteDiff.finite_difference_jacobian( |
| 107 | + # Xc -> costF!(M, res, exp!(M, q, p, get_vector!(M, X, p, Xc, basis_domain))), |
| 108 | + # X0, |
| 109 | + # ) |
| 110 | + |
| 111 | + return JacF_RLM!(costF!, X0, X, q, res) |
| 112 | + |
| 113 | +end |
| 114 | + |
| 115 | +function (jacF!::JacF_RLM!)( |
| 116 | + M::AbstractManifold, |
| 117 | + J, |
| 118 | + p; |
| 119 | + basis_domain::AbstractBasis = DefaultOrthogonalBasis(), |
| 120 | +) |
| 121 | + |
| 122 | + X0 = jacF!.X0 |
| 123 | + X = jacF!.X |
| 124 | + q = jacF!.q |
| 125 | + |
| 126 | + fill!(X0, 0) |
| 127 | + |
| 128 | + # J .= FiniteDiff.finite_difference_jacobian( |
| 129 | + # Xc -> jacF!.costF!(M, jacF!.res, exp!(M, q, p, get_vector!(M, X, p, Xc, basis_domain))), |
| 130 | + # X0, |
| 131 | + # ) |
| 132 | + FiniteDiff.finite_difference_jacobian!( |
| 133 | + J, |
| 134 | + (res,Xc) -> jacF!.costF!(M, res, exp!(M, q, p, get_vector!(M, X, p, Xc, basis_domain))), |
| 135 | + X0, |
| 136 | + ) |
| 137 | + return J |
| 138 | +end |
| 139 | + |
| 140 | + |
| 141 | +function solve_RLM( |
| 142 | + fg, |
| 143 | + frontals::Vector{Symbol} = ls(fg), |
| 144 | + separators::Vector{Symbol} = setdiff(ls(fg), frontals); |
| 145 | +) |
| 146 | + @error "#FIXME, use covariances" maxlog=1 |
| 147 | + |
| 148 | + # get the subgraph formed by all frontals, separators and fully connected factors |
| 149 | + varlabels = union(frontals, separators) |
| 150 | + faclabels = sortDFG(setdiff(getNeighborhood(fg, varlabels, 1), varlabels)) |
| 151 | + |
| 152 | + filter!(faclabels) do fl |
| 153 | + return issubset(getVariableOrder(fg, fl), varlabels) |
| 154 | + end |
| 155 | + |
| 156 | + facs = getFactor.(fg, faclabels) |
| 157 | + |
| 158 | + # so the subgraph consists of varlabels(frontals + separators) and faclabels |
| 159 | + |
| 160 | + varIntLabel = OrderedDict(zip(varlabels, collect(1:length(varlabels)))) |
| 161 | + |
| 162 | + # varIntLabel_frontals = filter(p->first(p) in frontals, varIntLabel) |
| 163 | + # varIntLabel_separators = filter(p->first(p) in separators, varIntLabel) |
| 164 | + |
| 165 | + calcfacs = CalcFactorManopt.(fg, facs, Ref(varIntLabel)) |
| 166 | + |
| 167 | + |
| 168 | + # get the manifold and variable types |
| 169 | + frontal_vars = getVariable.(fg, frontals) |
| 170 | + vartypes, vartypecount, vartypeslist = getVariableTypesCount(frontal_vars) |
| 171 | + |
| 172 | + PMs = map(vartypes) do vartype |
| 173 | + N = vartypecount[vartype] |
| 174 | + G = getManifold(vartype) |
| 175 | + return IIF.NPowerManifold(G, N) |
| 176 | + end |
| 177 | + M = ProductManifold(PMs...) |
| 178 | + |
| 179 | + # |
| 180 | + #FIXME |
| 181 | + @assert length(M.manifolds) == 1 "#FIXME, this only works with 1 manifold type component" |
| 182 | + MM = M.manifolds[1] |
| 183 | + |
| 184 | + # inital values and separators from fg |
| 185 | + fro_p = first.(getVal.(frontal_vars, solveKey = :parametric)) |
| 186 | + sep_p::Vector{eltype(fro_p)} = first.(getVal.(fg, separators, solveKey = :parametric)) |
| 187 | + |
| 188 | + |
| 189 | + if false |
| 190 | + # non-in-place version updated bolow to in-place |
| 191 | + fullp::Vector{eltype(fro_p)} = [fro_p; sep_p] |
| 192 | + # cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt |
| 193 | + function costF_RLM(M::AbstractManifold, p::Vector{T}) where T |
| 194 | + fullp[1:length(p)] .= p |
| 195 | + return Vector(mapreduce(f -> f(fullp), vcat, calcfacs)) |
| 196 | + end |
| 197 | + |
| 198 | + # jacobian of function for Riemannian Levenberg-Marquardt |
| 199 | + function jacF_RLM( |
| 200 | + M::AbstractManifold, |
| 201 | + p; |
| 202 | + basis_domain::AbstractBasis = DefaultOrthogonalBasis(), |
| 203 | + ) |
| 204 | + X0 = zeros(manifold_dimension(M)) |
| 205 | + J = FiniteDiff.finite_difference_jacobian( |
| 206 | + x -> costF_RLM(M, exp(M, p, get_vector(M, p, x, basis_domain))), |
| 207 | + X0, |
| 208 | + ) |
| 209 | + # J = ForwardDiff.jacobian( |
| 210 | + # x -> costF_RLM(M, exp(M, p, get_vector(M, p, x, basis_domain))), |
| 211 | + # X0, |
| 212 | + # ) |
| 213 | + return J |
| 214 | + end |
| 215 | + |
| 216 | + # 0.296639 seconds (2.46 M allocations: 164.722 MiB, 12.83% gc time) |
| 217 | + p0 = deepcopy(fro_p) |
| 218 | + lm_r = LevenbergMarquardt(MM, costF_RLM, jacF_RLM, p0) |
| 219 | + |
| 220 | + # 81.185117 seconds (647.20 M allocations: 41.680 GiB, 8.61% gc time) |
| 221 | + else |
| 222 | + # 74.420872 seconds (567.70 M allocations: 34.698 GiB, 8.30% gc time, 0.66% compilation time) |
| 223 | + #cost and jacobian functions |
| 224 | + # cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt |
| 225 | + costF! = CostF_RLM!(calcfacs, fro_p, sep_p) |
| 226 | + # jacobian of function for Riemannian Levenberg-Marquardt |
| 227 | + jacF! = JacF_RLM!(MM, costF!) |
| 228 | + |
| 229 | + num_components = length(jacF!.res) |
| 230 | + |
| 231 | + p0 = deepcopy(fro_p) |
| 232 | + lm_r = LevenbergMarquardt(MM, costF!, jacF!, p0, num_components; evaluation=InplaceEvaluation()) |
| 233 | + end |
| 234 | + |
| 235 | + return vartypeslist, lm_r |
| 236 | +end |
| 237 | + |
| 238 | +function autoinitParametricManopt!( |
| 239 | + fg, |
| 240 | + varorderIds = getInitOrderParametric(fg); |
| 241 | + reinit = false, |
| 242 | +) |
| 243 | + @showprogress for vIdx in varorderIds |
| 244 | + autoinitParametricManopt!(fg, vIdx; reinit) |
| 245 | + end |
| 246 | + return nothing |
| 247 | +end |
| 248 | + |
| 249 | +function autoinitParametricManopt!(dfg::AbstractDFG, initme::Symbol; kwargs...) |
| 250 | + return autoinitParametricManopt!(dfg, getVariable(dfg, initme); kwargs...) |
| 251 | +end |
| 252 | + |
| 253 | +function autoinitParametricManopt!( |
| 254 | + dfg::AbstractDFG, |
| 255 | + xi::DFGVariable; |
| 256 | + solveKey = :parametric, |
| 257 | + reinit::Bool = false, |
| 258 | + kwargs..., |
| 259 | +) |
| 260 | + # |
| 261 | + |
| 262 | + initme = getLabel(xi) |
| 263 | + vnd = getSolverData(xi, solveKey) |
| 264 | + # don't initialize a variable more than once |
| 265 | + if reinit || !isInitialized(xi, solveKey) |
| 266 | + |
| 267 | + # frontals - initme |
| 268 | + # separators - inifrom |
| 269 | + |
| 270 | + initfrom = ls2(dfg, initme) |
| 271 | + filter!(initfrom) do vl |
| 272 | + return isInitialized(dfg, vl, solveKey) |
| 273 | + end |
| 274 | + |
| 275 | + vartypeslist, lm_r = solve_RLM(dfg, [initme], initfrom) |
| 276 | + |
| 277 | + val = lm_r[1] |
| 278 | + vnd::VariableNodeData = getSolverData(xi, solveKey) |
| 279 | + vnd.val[1] = val |
| 280 | + |
| 281 | + |
| 282 | + # val = lm_r[1] |
| 283 | + # cov = ... |
| 284 | + # updateSolverDataParametric!(vnd, val, cov) |
| 285 | + |
| 286 | + vnd.initialized = true |
| 287 | + #fill in ppe as mean |
| 288 | + Xc::Vector{Float64} = collect(getCoordinates(getVariableType(xi), val)) |
| 289 | + ppe = MeanMaxPPE(:parametric, Xc, Xc, Xc) |
| 290 | + getPPEDict(xi)[:parametric] = ppe |
| 291 | + |
| 292 | + result = vartypeslist, lm_r |
| 293 | + |
| 294 | + else |
| 295 | + result = nothing |
| 296 | + end |
| 297 | + |
| 298 | + return result#isInitialized(xi, solveKey) |
| 299 | +end |
0 commit comments