Skip to content

Commit 43d61e7

Browse files
committed
deserialize and multihypo fixes for CCW's FGC
1 parent d1f8ca2 commit 43d61e7

File tree

4 files changed

+124
-61
lines changed

4 files changed

+124
-61
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The list below highlights major breaking changes, and please note that significa
1919
- Deprecating use of `ensureAllInitialized!`, use `initAll!` instead.
2020
- New helper function `randToPoints(::SamplableBelief, N=1)::Vector{P}` to help with `getSample` for cases with new `ManifoldKernelDensity` beliefs for manifolds containing points of type `P`.
2121
- Upstream `calcHelix_T` canonical generator utility from RoME.jl.
22+
- Deserialization of factors with DFG needs new API and change of solverData and CCW type in factor.
2223

2324
# Major changes in v0.24
2425

src/DispatchPackedConversions.jl

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,17 @@ end
6363
After deserializing a factor using decodePackedType, use this to
6464
completely rebuild the factor's CCW and user data.
6565
"""
66-
function rebuildFactorMetadata!(dfg::AbstractDFG{SolverParams}, factor::DFGFactor)
66+
function rebuildFactorMetadata!(dfg::AbstractDFG{SolverParams},
67+
factor::DFGFactor,
68+
neighbors = map(vId->getVariable(dfg, vId), getNeighbors(dfg, factor)) )
69+
#
6770
# Set up the neighbor data
68-
neighbors = map(vId->getVariable(dfg, vId), getNeighbors(dfg, factor))
69-
neighborUserData = map(v->getVariableType(v), neighbors)
7071

7172
# Rebuilding the CCW
7273
fsd = getSolverData(factor)
73-
ccw_new = getDefaultFactorData( dfg, neighbors, getFactorType(factor),
74+
fnd_new = getDefaultFactorData( dfg,
75+
neighbors,
76+
getFactorType(factor),
7477
multihypo=fsd.multihypo,
7578
nullhypo=fsd.nullhypo,
7679
# special inflation override
@@ -79,7 +82,29 @@ function rebuildFactorMetadata!(dfg::AbstractDFG{SolverParams}, factor::DFGFacto
7982
potentialused=fsd.potentialused,
8083
edgeIDs=fsd.edgeIDs,
8184
solveInProgress=fsd.solveInProgress)
82-
setSolverData!(factor, ccw_new)
85+
#
86+
87+
factor_ = if typeof(fnd_new) != typeof(getSolverData(factor))
88+
# must change the type of factor solver data FND{CCW{...}}
89+
# create a new factor
90+
factor__ = DFGFactor(getLabel(factor),
91+
getTimestamp(factor),
92+
factor.nstime,
93+
getTags(factor),
94+
fnd_new,
95+
getSolvable(factor),
96+
Tuple(getVariableOrder(factor)))
97+
#
98+
99+
# replace old factor in dfg with a new one
100+
deleteFactor!(dfg, factor)
101+
addFactor!(dfg, factor__)
102+
103+
factor__
104+
else
105+
setSolverData!(factor, fnd_new)
106+
factor
107+
end
83108

84109
#... Copying neighbor data into the factor?
85110
# JT TODO it looks like this is already updated in getDefaultFactorData -> prepgenericconvolution
@@ -88,7 +113,7 @@ function rebuildFactorMetadata!(dfg::AbstractDFG{SolverParams}, factor::DFGFacto
88113
# ccw_new.fnc.cpt[i].factormetadata.variableuserdata = deepcopy(neighborUserData)
89114
# end
90115

91-
return factor
116+
return factor_
92117
end
93118

94119

src/FactorGraph.jl

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,21 @@ calcZDim(cf::CalcFactor{<:GenericMarginal}) = 0
620620

621621
calcZDim(cf::CalcFactor{<:ManifoldPrior}) = manifold_dimension(cf.factor.M)
622622

623+
# return a BitVector masking the fractional portion, assuming converted 0's on 100% confident variables
624+
_getFractionalVars(varList::Union{<:Tuple, <:AbstractVector}, mh::Nothing) = zeros(length(varList)) .== 1
625+
_getFractionalVars(varList::Union{<:Tuple, <:AbstractVector}, mh::Categorical) = 0 .< mh.p
626+
627+
function _selectHypoVariables(allVars::Union{<:Tuple, <:AbstractVector},
628+
mh::Categorical,
629+
sel::Integer = rand(mh) )
630+
#
631+
mask = mh.p .≈ 0.0
632+
mask[sel] = true
633+
(1:length(allVars))[mask]
634+
end
635+
636+
_selectHypoVariables(allVars::Union{<:Tuple, <:AbstractVector},mh::Nothing,sel::Integer=0 ) = collect(1:length(allVars))
637+
623638

624639
function prepgenericconvolution(Xi::Vector{<:DFGVariable},
625640
usrfnc::T;
@@ -647,8 +662,9 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
647662
# get a measurement sample
648663
meas_single = sampleFactor(_cf, 1)
649664

665+
# get the measurement dimension
650666
zdim = calcZDim(_cf)
651-
# zdim = T != GenericMarginal ? size(getSample(usrfnc, 2)[1],1) : 0
667+
# some hypo resolution
652668
certainhypo = multihypo !== nothing ? collect(1:length(multihypo.p))[multihypo.p .== 0.0] : collect(1:length(Xi))
653669

654670
# sort out partialDims here
@@ -664,17 +680,24 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
664680
gradients = nothing
665681
# prepare new cached gradient lambdas (attempt)
666682
try
683+
# this try block definitely fails on deserialization, due to empty DFGVariable[] vector here:
684+
# https://github.com/JuliaRobotics/IncrementalInference.jl/blob/db7ff84225cc848c325e57b5fb9d0d85cb6c79b8/src/DispatchPackedConversions.jl#L46
685+
# also https://github.com/JuliaRobotics/DistributedFactorGraphs.jl/issues/590#issuecomment-891450762
667686
if (!_blockRecursion) && usrfnc isa AbstractRelative
668687
# take first value from each measurement-tuple-element
669688
measurement_ = map(x->x[1], meas_single)
689+
# compensate if no info available during deserialization
670690
# take the first value from each variable param
671691
pts_ = map(x->x[1], varParamsAll)
672692
# FIXME, only using first meas and params values at this time...
673693
# NOTE, must block recurions here, since FGC uses this function to calculate numerical gradients on a temp fg.
674-
gradients = FactorGradientsCached!(usrfnc, tuple(varTypes...), measurement_, tuple(pts_...), _blockRecursion=true);
694+
# assume for now fractional-var in multihypo have same varType
695+
hypoidxs = _selectHypoVariables(pts_, multihypo)
696+
gradients = FactorGradientsCached!(usrfnc, tuple(varTypes[hypoidxs]...), measurement_, tuple(pts_[hypoidxs]...), _blockRecursion=true);
675697
end
676698
catch e
677-
@warn "Unable to create measurements and gradients for $usrfnc during prep of CCW, falling back on no-partial information assumption."
699+
@warn "Unable to create measurements and gradients for $usrfnc during prep of CCW, falling back on no-partial information assumption. Enable @debug printing to see the error."
700+
@debug(e)
678701
end
679702

680703
ccw = CommonConvWrapper(
@@ -1185,7 +1208,7 @@ Experimental
11851208
- `inflation`, to better disperse kernels before convolution solve, see IIF #1051.
11861209
"""
11871210
function DFG.addFactor!(dfg::AbstractDFG,
1188-
Xi::Vector{<:DFGVariable},
1211+
Xi::AbstractVector{<:DFGVariable},
11891212
usrfnc::AbstractFactor;
11901213
multihypo::Vector{Float64}=Float64[],
11911214
nullhypo::Float64=0.0,
@@ -1218,16 +1241,23 @@ function DFG.addFactor!(dfg::AbstractDFG,
12181241
timestamp=timestamp)
12191242
#
12201243

1221-
success = DFG.addFactor!(dfg, newFactor)
1244+
success = addFactor!(dfg, newFactor)
12221245

12231246
# TODO: change this operation to update a conditioning variable
12241247
graphinit && doautoinit!(dfg, Xi, singles=false)
12251248

12261249
return newFactor
12271250
end
12281251

1252+
function _checkFactorAdd(usrfnc, xisyms)
1253+
if length(xisyms) == 1 && !(usrfnc isa AbstractPrior) && !(usrfnc isa Mixture)
1254+
@warn("Listing only one variable $xisyms for non-unary factor type $(typeof(usrfnc))")
1255+
end
1256+
nothing
1257+
end
1258+
12291259
function DFG.addFactor!(dfg::AbstractDFG,
1230-
xisyms::Vector{Symbol},
1260+
xisyms::AbstractVector{Symbol},
12311261
usrfnc::AbstractFactor;
12321262
suppressChecks::Bool=false,
12331263
kw... )
@@ -1243,12 +1273,12 @@ function DFG.addFactor!(dfg::AbstractDFG,
12431273
# depcrecation
12441274

12451275
# basic sanity check for unary vs n-ary
1246-
if !suppressChecks && length(xisyms) == 1 && !(usrfnc isa AbstractPrior) && !(usrfnc isa Mixture)
1247-
@warn("Listing only one variable $xisyms for non-unary factor type $(typeof(usrfnc))")
1276+
if !suppressChecks
1277+
_checkFactorAdd(usrfnc, xisyms)
12481278
end
12491279

1250-
variables = getVariable.(dfg, xisyms)
1251-
# verts = map(vid -> DFG.getVariable(dfg, vid), xisyms)
1280+
# variables = getVariable.(dfg, xisyms)
1281+
variables = map(vid -> getVariable(dfg, vid), xisyms)
12521282
addFactor!(dfg, variables, usrfnc; suppressChecks=suppressChecks, kw... ) # multihypo=multihypo, nullhypo=nullhypo, solvable=solvable, tags=tags, graphinit=graphinit, threadmodel=threadmodel, timestamp=timestamp, inflation=inflation )
12531283
end
12541284

test/testSaveLoadDFG.jl

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,73 +2,80 @@
22
using IncrementalInference
33
using Test
44

5+
##
6+
57
@testset "saving to and loading from FileDFG" begin
8+
##
9+
10+
fg = generateCanonicalFG_Kaess()
11+
addVariable!(fg, :x4, ContinuousScalar)
12+
addFactor!(fg, [:x2;:x3;:x4], LinearRelative(Normal()), multihypo=[1.0;0.6;0.4])
613

7-
fg = generateCanonicalFG_Kaess()
8-
addVariable!(fg, :x4, ContinuousScalar)
9-
addFactor!(fg, [:x2;:x3;:x4], LinearRelative(Normal()), multihypo=[1.0;0.6;0.4])
14+
saveFolder = "/tmp/dfg_test"
15+
saveDFG(fg, saveFolder)
16+
# VERSION above 1.0.x hack required since Julia 1.0 does not seem to havfunction `splitpath`
1017

11-
saveFolder = "/tmp/dfg_test"
12-
saveDFG(fg, saveFolder)
13-
# VERSION above 1.0.x hack required since Julia 1.0 does not seem to havfunction `splitpath`
14-
if v"1.1" <= VERSION
15-
retDFG = initfg()
16-
retDFG = loadDFG!(retDFG, saveFolder)
17-
Base.rm(saveFolder*".tar.gz")
18+
retDFG = initfg()
19+
retDFG = loadDFG!(retDFG, saveFolder)
20+
Base.rm(saveFolder*".tar.gz")
1821

19-
@test symdiff(ls(fg), ls(retDFG)) == []
20-
@test symdiff(lsf(fg), lsf(retDFG)) == []
22+
@test symdiff(ls(fg), ls(retDFG)) == []
23+
@test symdiff(lsf(fg), lsf(retDFG)) == []
2124

22-
@show getFactor(fg, :x2x3x4f1).solverData.multihypo
23-
@show getFactor(retDFG, :x2x3x4f1).solverData.multihypo
25+
@show getFactor(fg, :x2x3x4f1).solverData.multihypo
26+
@show getFactor(retDFG, :x2x3x4f1).solverData.multihypo
2427

25-
# check for match
26-
@test getFactor(fg, :x2x3x4f1).solverData.multihypo - getFactor(retDFG, :x2x3x4f1).solverData.multihypo |> norm < 1e-10
27-
@test getFactor(fg, :x2x3x4f1).solverData.certainhypo - getFactor(retDFG, :x2x3x4f1).solverData.certainhypo |> norm < 1e-10
28-
end
28+
# check for match
29+
@test getFactor(fg, :x2x3x4f1).solverData.multihypo - getFactor(retDFG, :x2x3x4f1).solverData.multihypo |> norm < 1e-10
30+
@test getFactor(fg, :x2x3x4f1).solverData.certainhypo - getFactor(retDFG, :x2x3x4f1).solverData.certainhypo |> norm < 1e-10
2931

32+
##
3033
end
3134

3235

3336
@testset "saving to and loading from FileDFG with nullhypo, eliminated, solveInProgress" begin
34-
fg = generateCanonicalFG_Kaess()
35-
addVariable!(fg, :x4, ContinuousScalar)
36-
addFactor!(fg, [:x2;:x3;:x4], LinearRelative(Normal()), multihypo=[1.0;0.6;0.4])
37-
addFactor!(fg, [:x1;], Prior(Normal(10,1)), nullhypo=0.5)
37+
##
38+
39+
fg = generateCanonicalFG_Kaess()
40+
addVariable!(fg, :x4, ContinuousScalar)
41+
addFactor!(fg, [:x2;:x3;:x4], LinearRelative(Normal()), multihypo=[1.0;0.6;0.4])
42+
addFactor!(fg, [:x1;], Prior(Normal(10,1)), nullhypo=0.5)
43+
44+
solveTree!(fg)
3845

39-
solveTree!(fg)
46+
#manually change a few fields to test if they are preserved
47+
fa = getFactor(fg, :x2x3x4f1)
48+
fa.solverData.eliminated = true
49+
fa.solverData.solveInProgress = 1
50+
fa.solverData.nullhypo = 0.5
4051

41-
#manually change a few fields to test if they are preserved
42-
fa = getFactor(fg, :x2x3x4f1)
43-
fa.solverData.eliminated = true
44-
fa.solverData.solveInProgress = 1
45-
fa.solverData.nullhypo = 0.5
4652

53+
saveFolder = "/tmp/dfg_test"
54+
saveDFG(fg, saveFolder)
4755

48-
saveFolder = "/tmp/dfg_test"
49-
saveDFG(fg, saveFolder)
56+
retDFG = initfg()
57+
loadDFG!(retDFG, saveFolder)
58+
Base.rm(saveFolder*".tar.gz")
5059

51-
retDFG = initfg()
52-
loadDFG!(retDFG, saveFolder)
53-
Base.rm(saveFolder*".tar.gz")
60+
@test issetequal(ls(fg), ls(retDFG))
61+
@test issetequal(lsf(fg), lsf(retDFG))
5462

55-
@test issetequal(ls(fg), ls(retDFG))
56-
@test issetequal(lsf(fg), lsf(retDFG))
63+
@show getFactor(fg, :x2x3x4f1).solverData.multihypo
64+
@show getFactor(retDFG, :x2x3x4f1).solverData.multihypo
5765

58-
@show getFactor(fg, :x2x3x4f1).solverData.multihypo
59-
@show getFactor(retDFG, :x2x3x4f1).solverData.multihypo
66+
# check for match
67+
@test isapprox(getFactor(fg, :x2x3x4f1).solverData.multihypo, getFactor(retDFG, :x2x3x4f1).solverData.multihypo)
68+
@test isapprox(getFactor(fg, :x2x3x4f1).solverData.certainhypo, getFactor(retDFG, :x2x3x4f1).solverData.certainhypo)
6069

61-
# check for match
62-
@test isapprox(getFactor(fg, :x2x3x4f1).solverData.multihypo, getFactor(retDFG, :x2x3x4f1).solverData.multihypo)
63-
@test isapprox(getFactor(fg, :x2x3x4f1).solverData.certainhypo, getFactor(retDFG, :x2x3x4f1).solverData.certainhypo)
6470

71+
fb = getFactor(retDFG, :x2x3x4f1)
72+
@test fa == fb
6573

66-
fb = getFactor(retDFG, :x2x3x4f1)
67-
@test fa == fb
74+
fa = getFactor(fg, :x1f2)
75+
fb = getFactor(retDFG, :x1f2)
6876

69-
fa = getFactor(fg, :x1f2)
70-
fb = getFactor(retDFG, :x1f2)
77+
@test fa == fb
7178

72-
@test fa == fb
79+
##
7380
end
7481

0 commit comments

Comments
 (0)