Skip to content

Commit 006e0aa

Browse files
authored
Update and fix VariableState and solveKey (#1876)
1 parent 8a9107b commit 006e0aa

File tree

13 files changed

+55
-62
lines changed

13 files changed

+55
-62
lines changed

IncrementalInference/src/Deprecated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function solveGraphParametric2(
120120
flatvar = FlatVariables(fg, varIds)
121121

122122
for vId in varIds
123-
p = getVariableSolverData(fg, vId, solvekey).val[1]
123+
p = getVariableState(fg, vId, solvekey).val[1]
124124
flatvar[vId] = getCoordinates(getVariableType(fg, vId), p)
125125
end
126126

IncrementalInference/src/Serialization/services/DispatchPackedConversions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function DFG.rebuildFactorCache!(
143143

144144
# factor__
145145
# else
146-
# setSolverData!(factor, new_solverData)
146+
# mergeVariableState!(factor, new_solverData)
147147
# DFG.setCache!(factor, solvercache)
148148
# # We're not updating here because we don't want
149149
# # to solve cloud in loop, we want to make sure this flow works:

IncrementalInference/src/parametric/services/ParametricUtils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ function initPoints!(p, gsc, fg::AbstractDFG, solveKey = :parametric)
489489
for (i, vartype) in enumerate(gsc.varTypes)
490490
varIds = gsc.varTypesIds[vartype]
491491
for (j, vId) in enumerate(varIds)
492-
p[gsc.M, i][j] = getVariableSolverData(fg, vId, solveKey).val[1]
492+
p[gsc.M, i][j] = getVariableState(fg, vId, solveKey).val[1]
493493
end
494494
end
495495
end
@@ -674,7 +674,7 @@ function solveConditionalsParametric(
674674
flatvar = FlatVariables(fg, varIds)
675675

676676
for vId in varIds
677-
p = getVariableSolverData(fg, vId, solvekey).val[1]
677+
p = getVariableState(fg, vId, solvekey).val[1]
678678
flatvar[vId] = getCoordinates(getVariableType(fg, vId), p)
679679
end
680680
initValues = flatvar.X

IncrementalInference/src/services/BayesNet.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,20 @@ function addBayesNetVerts!(dfg::AbstractDFG, elimOrder::Array{Symbol, 1})
7676
#
7777
for pId in elimOrder
7878
vert = DFG.getVariable(dfg, pId)
79-
if getVariableState(vert).BayesNetVertID === nothing ||
80-
getVariableState(vert).BayesNetVertID == :_null # Special serialization case of nothing
79+
if getVariableState(vert, :default).BayesNetVertID === nothing ||
80+
getVariableState(vert, :default).BayesNetVertID == :_null # Special serialization case of nothing
8181
@debug "[AddBayesNetVerts] Assigning $pId.data.BayesNetVertID = $pId"
82-
getVariableState(vert).BayesNetVertID = pId
82+
getVariableState(vert, :default).BayesNetVertID = pId
8383
else
84-
@warn "addBayesNetVerts -- Something is wrong, variable '$pId' should not have an existing Bayes net reference to '$(getVariableState(vert).BayesNetVertID)'"
84+
@warn "addBayesNetVerts -- Something is wrong, variable '$pId' should not have an existing Bayes net reference to '$(getVariableState(vert, :default).BayesNetVertID)'"
8585
end
8686
end
8787
end
8888

8989
function addConditional!(dfg::AbstractDFG, vertId::Symbol, Si::Vector{Symbol})
9090
#
9191
bnv = DFG.getVariable(dfg, vertId)
92-
bnvd = getVariableState(bnv)
92+
bnvd = getVariableState(bnv, :default)
9393
bnvd.separator = Si
9494
for s in Si
9595
push!(bnvd.BayesNetOutVertIDs, s)
@@ -188,7 +188,7 @@ function buildBayesNet!(dfg::AbstractDFG, elimorder::Vector{Symbol}; solvable::I
188188
end
189189

190190
# mark variable
191-
getVariableState(vert).eliminated = true
191+
getVariableState(vert, :default).eliminated = true
192192

193193
# TODO -- remove links from current vertex to any marginals
194194
rmVarFromMarg(dfg, vert, gm)

IncrementalInference/src/services/FGOSUtils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,19 +168,19 @@ joinLogPath(dfg::AbstractDFG, str...) = joinLogPath(getSolverParams(dfg), str...
168168
169169
Set variable(s) `sym` of factor graph to be marginalized -- i.e. not be updated by inference computation.
170170
"""
171-
function setfreeze!(dfg::AbstractDFG, sym::Symbol)
171+
function setfreeze!(dfg::AbstractDFG, sym::Symbol, solveKey::Symbol = :default)
172172
if !isInitialized(dfg, sym)
173173
@warn "Vertex $(sym) is not initialized, and won't be frozen at this time."
174174
return nothing
175175
end
176176
vert = DFG.getVariable(dfg, sym)
177-
data = getVariableState(vert)
177+
data = getVariableState(vert, solveKey)
178178
data.ismargin = true
179179
return nothing
180180
end
181-
function setfreeze!(dfg::AbstractDFG, syms::Vector{Symbol})
181+
function setfreeze!(dfg::AbstractDFG, syms::Vector{Symbol}, solveKey::Symbol = :default)
182182
for sym in syms
183-
setfreeze!(dfg, sym)
183+
setfreeze!(dfg, sym, solveKey)
184184
end
185185
end
186186

@@ -357,10 +357,10 @@ Reset initialization flag on all variables in `::AbstractDFG`.
357357
Notes
358358
- Numerical values remain, but inference will overwrite since init flags are now `false`.
359359
"""
360-
function resetVariableAllInitializations!(fgl::AbstractDFG)
360+
function resetVariableAllInitializations!(fgl::AbstractDFG, solveKey::Symbol = :default)
361361
vsyms = ls(fgl)
362362
for sym in vsyms
363-
setVariableInitialized!(getVariable(fgl, sym), :false)
363+
setVariableInitialized!(getVariable(fgl, sym), solveKey, false)
364364
end
365365
return nothing
366366
end

IncrementalInference/src/services/FactorGraph.jl

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -302,12 +302,11 @@ end
302302
Set variable initialized status.
303303
"""
304304
function setVariableInitialized!(varid::VariableNodeData, status::Bool)
305-
#
306305
return varid.initialized = status
307306
end
308-
#TODO why no solveKey
309-
function setVariableInitialized!(vari::DFGVariable, status::Bool)
310-
return setVariableInitialized!(getVariableState(vari), status)
307+
308+
function setVariableInitialized!(vari::DFGVariable, solveKey::Symbol, status::Bool)
309+
return setVariableInitialized!(getVariableState(vari, solveKey), status)
311310
end
312311

313312
"""
@@ -348,7 +347,7 @@ end
348347
349348
Reset the solve state of a variable to uninitialized/unsolved state.
350349
"""
351-
function resetVariable!(varid::VariableNodeData; solveKey::Symbol = :default)::Nothing
350+
function resetVariable!(varid::VariableNodeData)
352351
#
353352
val = getBelief(varid)
354353
pts = AMP.getPoints(val)
@@ -363,17 +362,12 @@ function resetVariable!(varid::VariableNodeData; solveKey::Symbol = :default)::N
363362
return nothing
364363
end
365364

366-
function resetVariable!(vari::DFGVariable; solveKey::Symbol = :default)
367-
return resetVariable!(getVariableState(vari); solveKey = solveKey)
365+
function resetVariable!(vari::DFGVariable, solveKey::Symbol = :default)
366+
return resetVariable!(getVariableState(vari, solveKey))
368367
end
369368

370-
function resetVariable!(
371-
dfg::G,
372-
sym::Symbol;
373-
solveKey::Symbol = :default,
374-
)::Nothing where {G <: AbstractDFG}
375-
#
376-
return resetVariable!(getVariable(dfg, sym); solveKey = solveKey)
369+
function resetVariable!(dfg::AbstractDFG, sym::Symbol, solveKey::Symbol = :default)
370+
return resetVariable!(getVariableState(dfg, sym, solveKey))
377371
end
378372

379373
# return VariableNodeData
@@ -438,7 +432,7 @@ function setDefaultNodeDataParametric!(
438432
kwargs...,
439433
)
440434
vnd = DefaultNodeDataParametric(0, variableType |> getDimension, variableType; solveKey, kwargs...)
441-
setSolverData!(v, vnd, solveKey)
435+
mergeVariableState!(v, vnd)
442436
nothing
443437
end
444438

@@ -485,7 +479,7 @@ function setDefaultNodeData!(
485479
(val, bw)
486480
end
487481
# make and set the new solverData
488-
setSolverData!(
482+
mergeVariableState!(
489483
v,
490484
VariableNodeData(varType;
491485
id=nothing,
@@ -504,8 +498,7 @@ function setDefaultNodeData!(
504498
# 0,
505499
# 0,
506500
solveKey,
507-
),
508-
solveKey,
501+
)
509502
)
510503
return nothing
511504
end
@@ -556,7 +549,7 @@ function setVariableRefence!(
556549
)
557550
#
558551
# set the value in the DFGVariable
559-
return setSolverData!(var, vnd, refKey)
552+
return mergeVariableState!(var, vnd)
560553
end
561554

562555
# get instance from variableType

IncrementalInference/src/services/GraphInit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ function doautoinit!(
179179
# Update the estimates (longer DFG function used so cloud is also updated)
180180
setVariablePosteriorEstimates!(dfg, xi.label, solveKey)
181181
# Update the data in the event that it's not local
182-
# TODO perhaps usecopy=false
183-
updateVariableSolverData!(dfg, xi, solveKey, true; warn_if_absent = false)
182+
# TODO perhaps use merge, but keeping to deepcopy as update variant used was set to copy.
183+
DFG.copytoVariableState!(dfg, xi.label, solveKey, getVariableState(xi, solveKey))
184184
# deepcopy graphinit value, see IIF #612
185185
DFG.copytoVariableState!(
186186
dfg,

IncrementalInference/src/services/JunctionTreeUtils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ function identifyFirstEliminatedSeparator(
403403
dfg::AbstractDFG,
404404
elimorder::Vector{Symbol},
405405
firvert::DFGVariable,
406-
Sj = getVariableState(firvert).separator,
406+
Sj = getVariableState(firvert, :default).separator,
407407
)::DFGVariable
408408
#
409409
firstelim = (2^(Sys.WORD_SIZE - 1) - 1)
@@ -440,7 +440,7 @@ function newPotential(
440440
) where {G <: AbstractDFG}
441441
firvert = DFG.getVariable(dfg, var)
442442
# no parent
443-
if (length(getVariableState(firvert).separator) == 0)
443+
if (length(getVariableState(firvert, :default).separator) == 0)
444444
# if (length(getCliques(tree)) == 0)
445445
# create new root
446446
addClique!(tree, dfg, var)
@@ -451,7 +451,7 @@ function newPotential(
451451
# end
452452
else
453453
# find parent clique Cp that containts the first eliminated variable of Sj as frontal
454-
Sj = getVariableState(firvert).separator
454+
Sj = getVariableState(firvert, :default).separator
455455
felbl = identifyFirstEliminatedSeparator(dfg, elimorder, firvert, Sj).label
456456
# get clique id of first eliminated frontal
457457
CpID = tree.frontals[felbl]
@@ -888,7 +888,7 @@ can be constructed.
888888
"""
889889
function resetFactorGraphNewTree!(dfg::AbstractDFG)
890890
for v in DFG.getVariables(dfg)
891-
resetData!(getVariableState(v))
891+
resetData!(getVariableState(v, :default))
892892
end
893893
for f in DFG.getFactors(dfg)
894894
resetData!(DFG.getFactorState(f))

IncrementalInference/src/services/TreeMessageUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function resetCliqSolve!(
3636
cda = getCliqueData(cliq)
3737
vars = getCliqVarIdsAll(cliq)
3838
for varis in vars
39-
resetVariable!(dfg, varis; solveKey = solveKey)
39+
resetVariable!(dfg, varis, solveKey)
4040
end
4141
# TODO remove once consolidation with upMsgs is done
4242
putCliqueMsgUp!(cda, LikelihoodMessage())

IncrementalInference/test/testSpecialEuclidean2Mani.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ p = addFactor!(fg, [:x0], mp)
4848
doautoinit!(fg, :x0)
4949

5050
##
51-
vnd = getVariableSolverData(fg, :x0)
51+
vnd = getVariableState(fg, :x0, :default)
5252
@test all(isapprox.(mean(vnd.val), ArrayPartition(SA[0.0,0.0], SA[1.0 0.0; 0.0 1.0]), atol=0.1))
5353
@test all(is_point.(Ref(M), vnd.val))
5454

@@ -59,7 +59,7 @@ f = addFactor!(fg, [:x0, :x1], mf)
5959

6060
doautoinit!(fg, :x1)
6161

62-
vnd = getVariableSolverData(fg, :x1)
62+
vnd = getVariableState(fg, :x1, :default)
6363
@test all(isapprox(M, mean(M,vnd.val), ArrayPartition(SA[1.0,2.0], SA[0.7071 -0.7071; 0.7071 0.7071]), atol=0.1))
6464
@test all(is_point.(Ref(M), vnd.val))
6565

@@ -68,11 +68,11 @@ smtasks = Task[]
6868
solveTree!(fg; smtasks, verbose=true) #, recordcliqs=ls(fg))
6969
# hists = fetchCliqHistoryAll!(smtasks);
7070

71-
vnd = getVariableSolverData(fg, :x0)
71+
vnd = getVariableState(fg, :x0, :default)
7272
@test all(isapprox.(mean(vnd.val), ArrayPartition(SA[0.0,0.0], SA[1.0 0.0; 0.0 1.0]), atol=0.1))
7373
@test all(is_point.(Ref(M), vnd.val))
7474

75-
vnd = getVariableSolverData(fg, :x1)
75+
vnd = getVariableState(fg, :x1, :default)
7676
@test all(isapprox.(mean(vnd.val), ArrayPartition(SA[1.0,2.0], SA[0.7071 -0.7071; 0.7071 0.7071]), atol=0.1))
7777
@test all(is_point.(Ref(M), vnd.val))
7878

@@ -194,13 +194,13 @@ addFactor!(fg, [:x6; :l1], mf)
194194
smtasks = Task[]
195195
solveTree!(fg; smtasks);
196196

197-
vnd = getVariableSolverData(fg, :x0)
197+
vnd = getVariableState(fg, :x0, :default)
198198
@test isapprox(M, mean(M, vnd.val), ArrayPartition([10.0,10.0], [-1.0 0.0; 0.0 -1.0]), atol=0.2)
199199

200-
vnd = getVariableSolverData(fg, :x1)
200+
vnd = getVariableState(fg, :x1, :default)
201201
@test isapprox(M, mean(M, vnd.val), ArrayPartition([0.0,10.0], [-0.5 0.866; -0.866 -0.5]), atol=0.4)
202202

203-
vnd = getVariableSolverData(fg, :x6)
203+
vnd = getVariableState(fg, :x6, :default)
204204
@test isapprox(M, mean(M, vnd.val), ArrayPartition([10.0,10.0], [-1.0 0.0; 0.0 -1.0]), atol=0.5)
205205

206206
## Special test for manifold based messages
@@ -310,18 +310,18 @@ f = addFactor!(fg, [:x0, :x1], mf)
310310

311311
doautoinit!(fg, :x1)
312312

313-
vnd = getVariableSolverData(fg, :x1)
313+
vnd = getVariableState(fg, :x1, :default)
314314
@test all(isapprox.(mean(vnd.val), [1.0,2.0], atol=0.1))
315315

316316
##
317317
smtasks = Task[]
318318
solveTree!(fg; smtasks, verbose=true, recordcliqs=ls(fg))
319319
# # hists = fetchCliqHistoryAll!(smtasks);
320320

321-
vnd = getVariableSolverData(fg, :x0)
321+
vnd = getVariableState(fg, :x0, :default)
322322
@test isapprox(mean(getManifold(fg,:x0),vnd.val), ArrayPartition([0.0,0.0], [1.0 0.0; 0.0 1.0]), atol=0.1)
323323

324-
vnd = getVariableSolverData(fg, :x1)
324+
vnd = getVariableState(fg, :x1, :default)
325325
@test all(isapprox.(mean(vnd.val), [1.0,2.0], atol=0.1))
326326

327327
##
@@ -551,7 +551,7 @@ f = addFactor!(fg, [:x0, :x1a, :x1b], mf; multihypo=[1,0.5,0.5])
551551

552552
solveTree!(fg)
553553

554-
vnd = getVariableSolverData(fg, :x0)
554+
vnd = getVariableState(fg, :x0, :default)
555555
@test isapprox(SpecialEuclidean(2; vectors=HybridTangentRepresentation()), mean(SpecialEuclidean(2; vectors=HybridTangentRepresentation()), vnd.val), ArrayPartition([0.0,0.0], [1.0 0; 0 1]), atol=0.1)
556556

557557
#FIXME I would expect close to 50% of particles to land on the correct place
@@ -622,7 +622,7 @@ f = addFactor!(fg, [:x0, :x1a, :x1b], mf; multihypo=[1,0.5,0.5])
622622

623623
solveTree!(fg)
624624

625-
vnd = getVariableSolverData(fg, :x0)
625+
vnd = getVariableState(fg, :x0, :default)
626626
@test isapprox(SpecialEuclidean(2; vectors=HybridTangentRepresentation()), mean(SpecialEuclidean(2; vectors=HybridTangentRepresentation()), vnd.val), ArrayPartition([0.0,0.0], [1.0 0; 0 1]), atol=0.1)
627627

628628
#FIXME I would expect close to 50% of particles to land on the correct place

0 commit comments

Comments
 (0)