Skip to content

Commit c6a9b0b

Browse files
committed
No Parametric Mixtures for now
1 parent 808da77 commit c6a9b0b

File tree

2 files changed

+171
-152
lines changed

2 files changed

+171
-152
lines changed

src/parametric/services/ParametricUtils.jl

Lines changed: 125 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
# ================================================================================================
2-
## FactorOperationalMemory for parametric, TODO move back to FactorOperationalMemory.jl
3-
## ================================================================================================
4-
5-
6-
# struct CalcFactorMahalanobis{CF<:CalcFactor, S<:Union{Nothing,AbstractMaxMixtureSolver}, N}
7-
# calcfactor!::CF
8-
# varOrder::Vector{Symbol}
9-
# meas::NTuple{N, <:AbstractArray}
10-
# iΣ::NTuple{N, Matrix{Float64}}
11-
# specialAlg::S
12-
# end
13-
141
# ================================================================================================
152
## FlatVariables - used for packing variables for optimization
163
## ================================================================================================
@@ -56,8 +43,7 @@ end
5643
$SIGNATURES
5744
5845
Returns the parametric measurement for a factor as a tuple (measurement, inverse covariance) for parametric inference (assuming Gaussian).
59-
Defaults to find the parametric measurement at field `Z`, fields `Zij` and `z` are deprecated for standardization.
60-
46+
Defaults to find the parametric measurement at field `Z`.
6147
Notes
6248
- Users should overload this method should their factor not default to `.Z<:ParametricType`.
6349
- First design choice was to restrict this function to returning coordinates
@@ -99,7 +85,6 @@ function getMeasurementParametric(s::AbstractFactor)
9985
if hasfield(typeof(s), :Z)
10086
Z = s.Z
10187
else
102-
@warn "getMeasurementParametric falls back to using field `.Z` by default. Extend it for more complex factors."
10388
error(
10489
"getMeasurementParametric(::$(typeof(s))) not defined, please add it, or use non-parametric, or open an issue for help.",
10590
)
@@ -111,6 +96,33 @@ end
11196
getMeasurementParametric(fct::DFGFactor) = getMeasurementParametric(getFactorType(fct))
11297
getMeasurementParametric(dfg::AbstractDFG, flb::Symbol) = getMeasurementParametric(getFactor(dfg, flb))
11398

99+
# maybe rename getMeasurementParametric to something like getNormalDistributionParams or getMeanCov
100+
101+
# default to point on manifold
102+
function getFactorMeasurementParametric(fac::AbstractPrior)
103+
M = getManifold(fac)
104+
ϵ = getPointIdentity(M)
105+
dims = manifold_dimension(M)
106+
Xc, iΣ = getMeasurementParametric(fac)
107+
X = get_vector(M, ϵ, Xc, DefaultOrthogonalBasis())
108+
meas = convert(typeof(ϵ), exp(M, ϵ, X))
109+
= convert(SMatrix{dims, dims}, iΣ)
110+
meas, iΣ
111+
end
112+
# default to point on tangent vector
113+
function getFactorMeasurementParametric(fac::AbstractRelative)
114+
M = getManifold(fac)
115+
ϵ = getPointIdentity(M)
116+
dims = manifold_dimension(M)
117+
Xc, iΣ = getMeasurementParametric(fac)
118+
measX = convert(typeof(ϵ), get_vector(M, ϵ, Xc, DefaultOrthogonalBasis()))
119+
= convert(SMatrix{dims, dims}, iΣ)
120+
measX, iΣ
121+
end
122+
123+
getFactorMeasurementParametric(fct::DFGFactor) = getFactorMeasurementParametric(getFactorType(fct))
124+
getFactorMeasurementParametric(dfg::AbstractDFG, flb::Symbol) = getFactorMeasurementParametric(getFactor(dfg, flb))
125+
114126
## ================================================================================================
115127
## Parametric solve with Mahalanobis distance - CalcFactor
116128
## ================================================================================================
@@ -124,41 +136,18 @@ function CalcFactorMahalanobis(fg, fct::DFGFactor)
124136
varOrder = getVariableOrder(fct)
125137

126138
# NOTE, use getMeasurementParametric on DFGFactor{<:CCW} to allow special cases like OAS factors
127-
_meas, _iΣ = getMeasurementParametric(fct) # fac_func
128-
M = getManifold(getFactorType(fct))
129-
dims = manifold_dimension(M)
130-
ϵ = getPointIdentity(M)
131-
132-
_measX = if typeof(_meas) <: Tuple
133-
# TODO perhaps better consolidate manifold prior
134-
map(m -> hat(M, ϵ, m), _meas)
135-
elseif fac_func isa ManifoldPrior
136-
(_meas,)
137-
else
138-
(convert(typeof(ϵ), get_vector(M, ϵ, _meas, DefaultOrthogonalBasis())),)
139-
end
140-
141-
meas = fac_func isa AbstractPrior ? map(X -> exp(M, ϵ, X), _measX) : _measX
142-
143-
= convert.(SMatrix{dims, dims}, typeof(_iΣ) <: Tuple ? _iΣ : (_iΣ,))
139+
_meas, _iΣ = getFactorMeasurementParametric(fct) # fac_func
140+
141+
# make sure its a tuple TODO Fix with mixture rework #1504
142+
meas = typeof(_meas) <: Tuple ? _meas : (_meas,)
143+
= typeof(_iΣ) <: Tuple ? _iΣ : (_iΣ,)
144144

145145
cache = preambleCache(fg, getVariable.(fg, varOrder), getFactorType(fct))
146146

147-
calcf = CalcFactor(
148-
getFactorMechanics(fac_func),
149-
0,
150-
nothing,
151-
true,
152-
cache,
153-
(), #DFGVariable[],
154-
0,
155-
getManifold(_getCCW(fct)) # getManifold(fac_func)
156-
)
157-
158147
multihypo = getSolverData(fct).multihypo
159148
nullhypo = getSolverData(fct).nullhypo
160149

161-
# FIXME, type instability, use dispatch instead of if-else
150+
# FIXME, type instability
162151
if length(multihypo) > 0
163152
special = MaxMultihypo(multihypo)
164153
elseif nullhypo > 0
@@ -169,16 +158,22 @@ function CalcFactorMahalanobis(fg, fct::DFGFactor)
169158
special = nothing
170159
end
171160

172-
return CalcFactorMahalanobis(fct.label, calcf, varOrder, meas, iΣ, special)
161+
return CalcFactorMahalanobis(fct.label, getFactorMechanics(fac_func), cache, varOrder, meas, iΣ, special)
173162
end
174163

175164
# This is where the actual parametric calculation happens, CalcFactor equivalent for parametric
176-
@inline function (cfp::CalcFactorMahalanobis{1, D, L, Nothing})(variables...) where {D, L}# AbstractArray{T} where T <: Real
177-
# call the user function
178-
res = cfp.calcfactor!(cfp.meas..., variables...)
179-
# 1/2*log(1/( sqrt(det(Σ)*(2pi)^k) )) ## k = dim(μ)
180-
return res' * cfp.iΣ[1] * res
181-
end
165+
# function (cfp::CalcFactorMahalanobis{FT, 1, C, MEAS, D, L, Nothing})(variables...) where {FT, C, MEAS, D, L, Nothing}# AbstractArray{T} where T <: Real
166+
# # call the user function
167+
# res = cfp.calcfactor!(cfp.meas..., variables...)
168+
# # 1/2*log(1/( sqrt(det(Σ)*(2pi)^k) )) ## k = dim(μ)
169+
# return res' * cfp.iΣ[1] * res
170+
# end
171+
172+
# function (cfm::CalcFactorMahalanobis)(variables...)
173+
# meas = cfm.meas
174+
# points = map(idx->p[idx], cfm.varOrderIdxs)
175+
# return cfm.sqrt_iΣ * cfm(meas, points...)
176+
# end
182177

183178
function calcFactorMahalanobisDict(fg)
184179
calcFactors = OrderedDict{Symbol, CalcFactorMahalanobis}()
@@ -190,20 +185,39 @@ function calcFactorMahalanobisDict(fg)
190185
return calcFactors
191186
end
192187

193-
# Base.eltype(::Type{<:CalcFactorMahalanobis}) = CalcFactorMahalanobis
188+
function getFactorTypesCount(facs::Vector{<:DFGFactor})
189+
typedict = OrderedDict{DataType, Int}()
190+
alltypes = OrderedDict{DataType, Vector{Symbol}}()
191+
for f in facs
192+
facType = typeof(getFactorType(f))
193+
cnt = get!(typedict, facType, 0)
194+
typedict[facType] = cnt + 1
194195

195-
# function calcFactorMahalanobisArray(fg)
196-
# cfps = map(getFactors(fg)) do fct
197-
# CalcFactorMahalanobis(fg, fct)
198-
# end
199-
# types = collect(Set(typeof.(cfps)))
200-
# cfparr = ArrayPartition(map(x->Vector{x}(), types)...)
201-
# for cfp in cfps
202-
# idx = findfirst(==(typeof(cfp)), types)
203-
# push!(cfparr.x[idx], cfp)
204-
# end
205-
# return cfparr
206-
# end
196+
dt = get!(alltypes, facType, Symbol[])
197+
push!(dt, f.label)
198+
end
199+
#TODO tuple or vector?
200+
# vartypes = tuple(keys(typedict)...)
201+
factypes::Vector{DataType} = collect(keys(typedict))
202+
return factypes, typedict, alltypes
203+
end
204+
205+
function calcFactorMahalanobisVec(fg)
206+
factypes, typedict, alltypes = getFactorTypesCount(getFactors(fg))
207+
208+
# skip non-numeric prior (MetaPrior)
209+
#TODO test... remove MetaPrior{T} something like this
210+
metaPriorKeys = filter(k->contains(string(k), "MetaPrior"), collect(keys(alltypes)))
211+
delete!.(Ref(alltypes), metaPriorKeys)
212+
213+
parts = map(values(alltypes)) do labels
214+
map(getFactor.(fg, labels)) do fct
215+
CalcFactorMahalanobis(fg, fct)
216+
end
217+
end
218+
parts_tuple = (parts...,)
219+
return ArrayPartition{CalcFactorMahalanobis, typeof(parts_tuple)}(parts_tuple)
220+
end
207221

208222
## ================================================================================================
209223
## ================================================================================================
@@ -265,8 +279,10 @@ function getVariableTypesCount(vars::Vector{<:DFGVariable})
265279
return vartypes, typedict, alltypes
266280
end
267281

268-
function buildGraphSolveManifold(fg::AbstractDFG)
269-
vartypes, vartypecount, vartypeslist = getVariableTypesCount(fg)
282+
buildGraphSolveManifold(fg::AbstractDFG) = buildGraphSolveManifold(getVariables(fg))
283+
284+
function buildGraphSolveManifold(vars::Vector{<:DFGVariable})
285+
vartypes, vartypecount, vartypeslist = getVariableTypesCount(vars)
270286

271287
PMs = map(vartypes) do vartype
272288
N = vartypecount[vartype]
@@ -294,34 +310,32 @@ function GraphSolveBuffers(@nospecialize(M), ::Type{T}) where {T}
294310
return GraphSolveBuffers(ϵ, p, X, Xc)
295311
end
296312

297-
struct GraphSolveContainer
313+
struct GraphSolveContainer{CFT}
298314
M::AbstractManifold # ProductManifold or ProductGroup
299315
buffers::OrderedDict{DataType, GraphSolveBuffers}
300316
varTypes::Vector{DataType}
301317
varTypesIds::OrderedDict{DataType, Vector{Symbol}}
302-
cfdict::OrderedDict{Symbol, CalcFactorMahalanobis}
303318
varOrderDict::OrderedDict{Symbol, Tuple{Int, Vararg{Int}}}
304-
# cfarr::AbstractVector # TODO maybe <: AbstractVector(CalcFactorMahalanobis)
319+
cfv::ArrayPartition{CalcFactorMahalanobis, CFT}
305320
end
306321

307322
function GraphSolveContainer(fg)
308323
M, varTypes, varTypesIds = buildGraphSolveManifold(fg)
309324
varTypesIndexes = ArrayPartition(values(varTypesIds)...)
310325
buffs = OrderedDict{DataType, GraphSolveBuffers}()
311-
cfd = calcFactorMahalanobisDict(fg)
326+
cfvec = calcFactorMahalanobisVec(fg)
312327

313328
varOrderDict = OrderedDict{Symbol, Tuple{Int, Vararg{Int}}}()
314-
for (fid, cfp) in cfd
329+
for cfp in cfvec
330+
fid = cfp.faclbl
315331
varOrder = cfp.varOrder
316332
var_idx = map(varOrder) do v
317333
return findfirst(==(v), varTypesIndexes)
318334
end
319335
varOrderDict[fid] = tuple(var_idx...)
320336
end
321337

322-
# cfarr = calcFactorMahalanobisArray(fg)
323-
# return GraphSolveContainer(M, buffs, varTypes, varTypesIds, cfd, varOrderDict, cfarr)
324-
return GraphSolveContainer(M, buffs, varTypes, varTypesIds, cfd, varOrderDict)
338+
return GraphSolveContainer(M, buffs, varTypes, varTypesIds, varOrderDict, cfvec)
325339
end
326340

327341
function getGraphSolveCache!(gsc::GraphSolveContainer, ::Type{T}) where {T <: Real}
@@ -348,11 +362,15 @@ function _toPoints2!(
348362
end
349363

350364
function cost_cfp(
351-
@nospecialize(cfp::CalcFactorMahalanobis),
352-
@nospecialize(p::AbstractArray),
365+
cfp::CalcFactorMahalanobis,
366+
p::AbstractArray{T},
353367
vi::NTuple{N, Int},
354-
) where N
355-
cfp(map(v->p[v],vi)...)
368+
) where {T,N}
369+
# cfp(map(v->p[v],vi)...)
370+
res = cfp(cfp.meas..., map(v->p[v],vi)...)
371+
# 1/2*log(1/( sqrt(det(Σ)*(2pi)^k) )) ## k = dim(μ)
372+
return res' * cfp.iΣ[1] * res
373+
356374
end
357375
# function cost_cfp(
358376
# @nospecialize(cfp::CalcFactorMahalanobis),
@@ -403,17 +421,19 @@ function (gsc::GraphSolveContainer)(Xc::Vector{T}) where {T <: Real}
403421
#
404422
buffs = getGraphSolveCache!(gsc, T)
405423

406-
cfdict = gsc.cfdict
407424
varOrderDict = gsc.varOrderDict
408425

409426
M = gsc.M
410427

411428
p = _toPoints2!(M, buffs, Xc)
412-
413-
obj = mapreduce(+, cfdict) do (fid, cfp)
414-
varOrder_idx = varOrderDict[fid]
415-
# call the user function
416-
return cost_cfp(cfp, p, varOrder_idx)
429+
430+
obj = mapreduce(+, eachindex(gsc.cfv)) do i
431+
cfp = gsc.cfv[i]
432+
varOrder_idx = varOrderDict[cfp.faclbl]
433+
# # call the user function
434+
cost::T = cost_cfp(cfp, p, varOrder_idx)
435+
436+
return cost
417437
end
418438

419439
return obj / 2
@@ -507,6 +527,7 @@ function solveGraphParametric(
507527
autodiff = :forward,
508528
algorithm = Optim.BFGS,
509529
algorithmkwargs = (), # add manifold to overwrite computed one
530+
# algorithmkwargs = (linesearch=Optim.BackTracking(),), # add manifold to overwrite computed one
510531
options = Optim.Options(;
511532
allow_f_increases = true,
512533
time_limit = 100,
@@ -539,22 +560,7 @@ function solveGraphParametric(
539560

540561
#optim setup and solve
541562
alg = algorithm(; algorithmkwargs...)
542-
# alg = NewtonTrustRegion(;
543-
# initial_delta = 1.0,
544-
# delta_hat = 100.0,
545-
# eta = 0.1,
546-
# rho_lower = 0.25,
547-
# rho_upper = 0.75
548-
# )
549-
# alg = LBFGS(;
550-
# m = 10,
551-
# alphaguess = LineSearches.InitialStatic(),
552-
# linesearch = LineSearches.HagerZhang(),
553-
# P = nothing,
554-
# precondprep = (P, x) -> nothing,
555-
# manifold = Flat(),
556-
# scaleinvH0::Bool = true && (typeof(P) <: Nothing)
557-
# )
563+
558564
tdtotalCost = Optim.TwiceDifferentiable(gsc, initValues; autodiff = autodiff)
559565

560566
result = Optim.optimize(tdtotalCost, initValues, alg, options)
@@ -609,10 +615,10 @@ function _totalCost(fg, cfdict::OrderedDict{Symbol, <:CalcFactorMahalanobis}, fl
609615
]
610616

611617
# call the user function
612-
retval = cfp(Xparams...)
613-
618+
# retval = cfp(Xparams...)
619+
res = cfp(cfp.meas..., Xparams...)
614620
# 1/2*log(1/( sqrt(det(Σ)*(2pi)^k) )) ## k = dim(μ)
615-
obj += 1 / 2 * retval
621+
obj += 1 / 2 * res' * cfp.iΣ[1] * res
616622
end
617623

618624
return obj
@@ -890,7 +896,7 @@ end
890896
$SIGNATURES
891897
Update the fg from solution in vardict and add MeanMaxPPE (all just mean). Usefull for plotting
892898
"""
893-
function updateParametricSolution!(sfg, vardict; solveKey::Symbol = :parametric)
899+
function updateParametricSolution!(sfg, vardict::AbstractDict; solveKey::Symbol = :parametric)
894900
for (v, val) in vardict
895901
vnd = getSolverData(getVariable(sfg, v), solveKey)
896902
# Update the variable node data value and covariance
@@ -902,6 +908,18 @@ function updateParametricSolution!(sfg, vardict; solveKey::Symbol = :parametric)
902908
end
903909
end
904910

911+
function updateParametricSolution!(sfg, labels::AbstractArray{Symbol}, vals; solveKey::Symbol = :parametric)
912+
for (v, val) in zip(labels, vals)
913+
vnd = getSolverData(getVariable(sfg, v), solveKey)
914+
# Update the variable node data value and covariance
915+
updateSolverDataParametric!(vnd, val, vnd.bw)#FIXME add cov
916+
#fill in ppe as mean
917+
Xc = collect(getCoordinates(getVariableType(sfg, v), val))
918+
ppe = MeanMaxPPE(solveKey, Xc, Xc, Xc)
919+
getPPEDict(getVariable(sfg, v))[solveKey] = ppe
920+
end
921+
end
922+
905923
function createMvNormal(val, cov)
906924
#TODO do something better for properly formed covariance, but for now just a hack...FIXME
907925
if all(diag(cov) .> 0.001) && isapprox(cov, transpose(cov); rtol = 1e-4)
@@ -939,9 +957,10 @@ function autoinitParametric!(
939957
reinit = false,
940958
algorithm = Optim.NelderMead,
941959
algorithmkwargs = (initial_simplex = Optim.AffineSimplexer(0.025, 0.1),),
960+
kwargs...
942961
)
943962
@showprogress for vIdx in varorderIds
944-
autoinitParametric!(fg, vIdx; reinit, algorithm, algorithmkwargs)
963+
autoinitParametric!(fg, vIdx; reinit, algorithm, algorithmkwargs, kwargs...)
945964
end
946965
return nothing
947966
end

0 commit comments

Comments
 (0)