Skip to content

Commit abb504c

Browse files
authored
Update val/bw for DFGv1 State refactor (#1904)
1 parent a1cfcf4 commit abb504c

15 files changed

+142
-136
lines changed

IncrementalInference/src/IncrementalInference.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ import ApproxManifoldProducts: isPartial
110110
import ApproxManifoldProducts: _update!
111111
import DistributedFactorGraphs: addVariable!, addFactor!, ls, lsf, isInitialized
112112
import DistributedFactorGraphs: compare
113-
import DistributedFactorGraphs: rebuildFactorCache!
114113
import DistributedFactorGraphs: getDimension, getManifold, getPointType, getPointIdentity
115114
import DistributedFactorGraphs: getPoint, getCoordinates
116115
import DistributedFactorGraphs: getStateKind

IncrementalInference/src/entities/BeliefTypes.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,24 @@ function TreeBelief(
7878
end
7979

8080
function TreeBelief(vnd::State, solvDim::Real = 0)
81+
TreeBelief(DFG.getDensityKind(vnd), vnd, solvDim)
82+
end
83+
84+
function TreeBelief(::DFG.GaussianDensityKind, vnd::State, solvDim::Real = 0)
85+
return TreeBelief(
86+
DFG.refMeans(vnd),
87+
DFG.refCovariances(vnd)[1],
88+
vnd.observability,
89+
getStateKind(vnd),
90+
getManifold(vnd),
91+
solvDim,
92+
)
93+
end
94+
95+
function TreeBelief(::DFG.NonparametricDensityKind, vnd::State, solvDim::Real = 0)
8196
return TreeBelief(
82-
vnd.val,
83-
vnd.bw,
97+
DFG.refPoints(vnd),
98+
DFG.refBandwidth(vnd),
8499
vnd.observability,
85100
getStateKind(vnd),
86101
getManifold(vnd),
@@ -93,9 +108,9 @@ function TreeBelief(vari::VariableCompute, solveKey::Symbol = :default; solvable
93108
end
94109
#
95110

96-
getStateKind(tb::TreeBelief) = tb.variableType
111+
DFG.getStateKind(tb::TreeBelief) = tb.variableType
97112

98-
getManifold(treeb::TreeBelief) = getManifold(treeb.variableType)
113+
DFG.getManifold(treeb::TreeBelief) = getManifold(treeb.variableType)
99114

100115
function compare(t1::TreeBelief, t2::TreeBelief)
101116
TP = true

IncrementalInference/src/parametric/services/ParametricCSMFunctions.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
3838
vnd = getState(getVariable(csmc.cliqSubFg, v), :parametric)
3939
# fill in the variable node data value
4040
logCSM(csmc, "$(csmc.cliq.id) up: updating $v : $val")
41-
vnd.val[1] = val.val
41+
DFG.refMeans(vnd)[1] = val.val
4242
#calculate and fill in covariance
4343
#TODO rather broadcast than make new memory
44-
vnd.bw = val.cov
44+
DFG.refCovariances(vnd)[1] = val.cov
4545
end
4646
# elseif length(lsfPriors(csmc.cliqSubFg)) == 0 #FIXME
4747
# @error "Par-3, clique $(csmc.cliq.id) failed to converge in upsolve, but ignoring since no priors" result
@@ -118,8 +118,8 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
118118
#TODO maybe combine variable and factor in new prior?
119119
vnd = getState(getVariable(csmc.cliqSubFg, msym), :parametric)
120120
logCSM(csmc, "$(csmc.cliq.id): Updating separator $msym from message $(belief.val)")
121-
vnd.val .= belief.val
122-
vnd.bw .= belief.bw
121+
DFG.refMeans(vnd)[1] = belief.val[1] #FIXME 🦨 shares data structure in belief
122+
DFG.refCovariances(vnd)[1] = belief.bw
123123
end
124124
end
125125
end
@@ -146,8 +146,8 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
146146
logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel = Logging.Info)
147147
vnd = getState(getVariable(csmc.cliqSubFg, v), :parametric)
148148
#Update subfg variables
149-
vnd.val[1] = val.val
150-
vnd.bw .= val.cov
149+
DFG.refMeans(vnd)[1] = val.val
150+
DFG.refCovariances(vnd)[1] = val.cov
151151
end
152152
else
153153
@error "Par-5, clique $(csmc.cliq.id) failed to converge in down solve" result

IncrementalInference/src/parametric/services/ParametricManopt.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ function solve_RLM(
330330

331331
#Can use varIntLabel (because its an OrderedDict), but varLabelsAP makes the ArrayPartition.
332332
p0 = map(varlabelsAP) do label
333-
getVal(fg, label; solveKey)[1]
333+
DFG.refMeans(getState(fg, label, solveKey))[1]
334334
end
335335

336336
# create an ArrayPartition{CalcFactorResidual} for faclabels
@@ -415,7 +415,7 @@ function solve_RLM_conditional(
415415

416416
# get the subgraph formed by all frontals, separators and fully connected factors
417417
varlabels = union(frontals, separators)
418-
faclabels = sortDFG(setdiff(listNeighborhood(fg, varlabels, 1), varlabels))
418+
_, faclabels = listNeighborhood(fg, varlabels, 1)
419419

420420
filter!(faclabels) do fl
421421
return issubset(getVariableOrder(fg, fl), varlabels)
@@ -440,7 +440,7 @@ function solve_RLM_conditional(
440440
all_varlabelsAP = ArrayPartition((frontal_varlabelsAP.x..., separator_varlabelsAP.x...))
441441

442442
all_points = map(all_varlabelsAP) do label
443-
getVal(fg, label; solveKey)[1]
443+
DFG.refMeans(getState(fg, label, solveKey))[1]
444444
end
445445

446446
p0 = ArrayPartition(all_points.x[1:length(frontal_varlabelsAP.x)])
@@ -547,12 +547,10 @@ function autoinitParametric!(
547547
return false
548548
end
549549

550-
vnd::State = getState(xi, solveKey)
551-
552550
if perturb_point
553551
_M = getManifold(xi)
554-
p = vnd.val[1]
555-
vnd.val[1] = exp(
552+
p = DFG.refMeans(vnd)[1]
553+
DFG.refMeans(vnd)[1] = exp(
556554
_M,
557555
p,
558556
get_vector(
@@ -566,9 +564,9 @@ function autoinitParametric!(
566564
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(dfg, [initme], initfrom; solveKey, kwargs...)
567565

568566
val = lm_r[1]
569-
vnd.val[1] = val
567+
DFG.refMeans(vnd)[1] = val
570568

571-
!isnothing(Σ) && (vnd.bw .= Σ)
569+
!isnothing(Σ) && (DFG.refCovariances(vnd)[1] .= Σ)
572570

573571
# updateSolverDataParametric!(vnd, val, Σ)
574572

IncrementalInference/src/parametric/services/ParametricUtils.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ function initPoints!(p, gsc, fg::AbstractDFG, solveKey = :parametric)
500500
for (i, vartype) in enumerate(gsc.varTypes)
501501
varIds = gsc.varTypesIds[vartype]
502502
for (j, vId) in enumerate(varIds)
503-
p[gsc.M, i][j] = getState(fg, vId, solveKey).val[1]
503+
p[gsc.M, i][j] = DFG.refMeans(getState(fg, vId, solveKey))[1]
504504
end
505505
end
506506
end
@@ -687,7 +687,7 @@ function solveConditionalsParametric(
687687
flatvar = FlatVariables(fg, varIds)
688688

689689
for vId in varIds
690-
p = getState(fg, vId, solvekey).val[1]
690+
p = DFG.refMeans(getState(fg, vId, solvekey))[1]
691691
flatvar[vId] = getCoordinates(getStateKind(fg, vId), p)
692692
end
693693
initValues = flatvar.X
@@ -825,9 +825,9 @@ function updateSolverDataParametric!(
825825
cov::AbstractMatrix,
826826
)
827827
# fill in the variable node data value
828-
vnd.val[1] = val
828+
DFG.refMeans(vnd)[1] = val
829829
#calculate and fill in covariance
830-
vnd.bw .= cov
830+
DFG.refCovariances(vnd)[1] .= cov
831831
return vnd
832832
end
833833

@@ -888,15 +888,15 @@ function initParametricFrom!(
888888
for v in getVariables(fg)
889889
fromvnd = getState(v, fromkey)
890890
dims = getDimension(v)
891-
getState(v, parkey).val[1] = fromvnd.val[1]
892-
getState(v, parkey).bw[1:dims, 1:dims] = LinearAlgebra.I(dims)
891+
DFG.refMeans(getState(v, parkey))[1] = DFG.refMeans(fromvnd)[1]
892+
DFG.refCovariances(getState(v, parkey))[1] = LinearAlgebra.I(dims)
893893
end
894894
else
895895
for var in getVariables(fg)
896896
dims = getDimension(var)
897897
μ, Σ = calcMeanCovar(var, fromkey)
898-
getState(var, parkey).val[1] = μ
899-
getState(var, parkey).bw[1:dims, 1:dims] = Σ
898+
DFG.refMeans(getState(var, parkey))[1] = μ
899+
DFG.refCovariances(getState(var, parkey))[1] = Σ
900900
end
901901
end
902902
end
@@ -963,10 +963,12 @@ function createMvNormal(v::VariableCompute, key = :parametric)
963963
if key == :parametric
964964
vnd = getState(v, :parametric)
965965
dims = getDimension(vnd)
966-
return createMvNormal(vnd.val[1:dims, 1], vnd.bw[1:dims, 1:dims])
966+
val = DFG.refMeans(vnd)[1]
967+
cov = DFG.refCovariances(vnd)[1:dims, 1:dims]
968+
return createMvNormal(val, cov)
967969
else
968970
@warn "Trying MvNormal Fit"
969-
return fit(MvNormal, getState(v, key).val)
971+
return fit(MvNormal, DFG.refPoints(getState(v, key)))
970972
end
971973
end
972974

IncrementalInference/src/services/FactorGraph.jl

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,15 @@ reshapeVec2Mat(vec::Vector, rows::Int) = reshape(vec, rows, round(Int, length(ve
5353
5454
Fetch the variable marginal joint sampled points. Use [`getBelief`](@ref) to retrieve the full Belief object.
5555
"""
56-
getVal(v::VariableCompute; solveKey::Symbol = :default) = v.states[solveKey].val
56+
#FIXME replace with refPoints
57+
getVal(v::VariableCompute; solveKey::Symbol = :default) = DFG.refPoints(getState(v, solveKey))
5758
function getVal(v::VariableCompute, idx::Int; solveKey::Symbol = :default)
58-
return v.states[solveKey].val[:, idx]
59+
return DFG.refPoints(getState(v, solveKey))[idx]
5960
end
60-
getVal(vnd::State) = vnd.val
61-
getVal(vnd::State, idx::Int) = vnd.val[:, idx]
61+
getVal(vnd::State) = DFG.refPoints(vnd)
62+
getVal(vnd::State, idx::Int) = DFG.refPoints(vnd)[idx]
6263
function getVal(dfg::AbstractDFG, lbl::Symbol; solveKey::Symbol = :default)
63-
return getVariable(dfg, lbl).states[solveKey].val
64+
return DFG.refPoints(getVariable(dfg, lbl).states[solveKey])
6465
end
6566

6667
"""
@@ -73,15 +74,15 @@ function getNumPts(v::VariableCompute; solveKey::Symbol = :default)::Int
7374
end
7475

7576
function AMP.getBW(vnd::State)
76-
return vnd.bw
77+
return DFG.refBandwidth(vnd)
7778
end
7879

7980
# setVal! assumes you will update values to database separate, this used for local graph mods only
8081
function getBWVal(v::VariableCompute; solveKey::Symbol = :default)
81-
return getState(v, solveKey).bw
82+
return DFG.refBandwidth(getState(v, solveKey))
8283
end
8384
function setBW!(vd::State, bw::Array{Float64, 2}; solveKey::Symbol = :default)
84-
vd.bw = bw
85+
DFG.refBandwidth(vd) .= bw
8586
return nothing
8687
end
8788
function setBW!(v::VariableCompute, bw::Array{Float64, 2}; solveKey::Symbol = :default)
@@ -90,7 +91,9 @@ function setBW!(v::VariableCompute, bw::Array{Float64, 2}; solveKey::Symbol = :d
9091
end
9192

9293
function setVal!(vd::State, val::AbstractVector{P}) where {P}
93-
vd.val = val
94+
points = DFG.refPoints(vd)
95+
resize!(points, length(val))
96+
points .= val
9497
return nothing
9598
end
9699
function setVal!(
@@ -393,20 +396,13 @@ function DefaultNodeDataParametric(
393396
# dims, false, :_null, Symbol[], variableType, true, 0.0, false, dontmargin)
394397
else
395398
ϵ = getPointIdentity(variableType)
396-
return State(solveKey, variableType;
397-
val=[ϵ],
398-
bw=zeros(dims, dims),
399-
# Symbol[],
400-
# false,
401-
# :_null,
402-
# Symbol[],
403-
initialized=false,
404-
observability=zeros(dims),
405-
marginalized=false,
406-
# dontmargin,
407-
# 0,
408-
# 0,
399+
belief = DFG.BeliefRepresentation(
400+
DFG.GaussianDensityKind(),
401+
variableType;
402+
means = [ϵ],
403+
covariances = [zeros(dims, dims)],
409404
)
405+
return State(solveKey, variableType; belief)
410406
end
411407
end
412408

@@ -471,26 +467,20 @@ function setDefaultNodeData!(
471467
#
472468
(val, bw)
473469
end
470+
471+
belief = DFG.BeliefRepresentation(
472+
DFG.NonparametricDensityKind(),
473+
varType;
474+
points = val,
475+
bandwidth = bw,
476+
)
474477
# make and set the new solverData
475478
mergeState!(
476479
v,
477480
State(solveKey, varType;
478-
# id=nothing,
479-
val,
480-
bw,
481-
# Symbol[],
482-
# sp,
483-
# dims,
484-
# false,
485-
# :_null,
486-
# Symbol[],
481+
belief,
487482
initialized=isinit,
488-
observability=zeros(getDimension(v)),
489483
marginalized=false,
490-
# dontmargin,
491-
# 0,
492-
# 0,
493-
494484
)
495485
)
496486
return nothing
@@ -569,7 +559,7 @@ addVariable!(fg, :x0, Pose2)
569559
function DFG.addVariable!(
570560
dfg::AbstractDFG,
571561
label::Symbol,
572-
statetype::Union{T, Type{T}};
562+
statekind::Union{T, Type{T}};
573563
tags::Union{Set{Symbol}, Vector{Symbol}} = Set{Symbol}(),
574564
timestamp::Union{TimeDateZone, ZonedDateTime} = DFG.now_tdz(),
575565
solvable::Int = 1,
@@ -600,7 +590,7 @@ function DFG.addVariable!(
600590
tags = union(Set(tags), [:VARIABLE])
601591
v = VariableDFG(
602592
label,
603-
statetype;
593+
statekind;
604594
tags,
605595
bloblets,
606596
blobentries,

IncrementalInference/src/services/GraphInit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,8 @@ function initVariable!(
349349
if solveKey == :parametric
350350
μ, iΣ = getMeasurementParametric(samplable_belief)
351351
vnd = getState(variable, solveKey)
352-
vnd.val[1] = getPoint(getStateKind(variable), μ)
353-
vnd.bw .= inv(iΣ)
352+
DFG.refMeans(vnd)[1] = getPoint(getStateKind(variable), μ)
353+
DFG.refCovariances(vnd)[1] .= inv(iΣ)
354354
vnd.initialized = true
355355
else
356356
points = [samplePoint(M, samplable_belief) for _ = 1:N]

IncrementalInference/src/services/VariableStatistics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737

3838
#TODO consolidate
3939
function calcMeanCovar(vari::VariableCompute, solvekey = :default)
40-
pts = getState(vari, solvekey).val
40+
pts = DFG.refPoints(getState(vari, solvekey))
4141
μ = mean(getManifold(vari), pts)
4242
Σ = cov(getStateKind(vari), pts)
4343
return μ, Σ

0 commit comments

Comments
 (0)