Skip to content

Commit f46feb3

Browse files
authored
Merge pull request #1331 from JuliaRobotics/21Q3/enh/relpartl
towards partial IPC on relative via gradients
2 parents db7ff84 + 43d61e7 commit f46feb3

14 files changed

+232
-121
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The list below highlights major breaking changes, and please note that significa
1919
- Deprecating use of `ensureAllInitialized!`, use `initAll!` instead.
2020
- New helper function `randToPoints(::SamplableBelief, N=1)::Vector{P}` to help with `getSample` for cases with new `ManifoldKernelDensity` beliefs for manifolds containing points of type `P`.
2121
- Upstream `calcHelix_T` canonical generator utility from RoME.jl.
22+
- Deserialization of factors with DFG needs new API and change of solverData and CCW type in factor.
2223

2324
# Major changes in v0.24
2425

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ ApproxManifoldProducts = "0.4.11"
4747
BSON = "0.2, 0.3"
4848
Combinatorics = "1.0"
4949
DataStructures = "0.16, 0.17, 0.18"
50-
DistributedFactorGraphs = "0.15"
50+
DistributedFactorGraphs = "0.15.1"
5151
Distributions = "0.24, 0.25"
5252
DocStringExtensions = "0.8"
5353
FileIO = "1"

src/ApproxConv.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,7 @@ function evalPotentialSpecific( Xi::AbstractVector{<:DFGVariable},
399399
mani = getManifold(getVariableType(Xi[sfidx]))
400400

401401
# perform the numeric solutions on the indicated elements
402-
# error("ccwl.xDim=$(ccwl.xDim)")
403402
# FIXME consider repeat solve as workaround for inflation off-zero
404-
# @info "EVALSPEC END" ccwl.varidx string(ccwl.params) string(allelements)
405403
computeAcrossHypothesis!( ccwl, allelements, activehypo, certainidx,
406404
sfidx, maxlen, mani, spreadNH=spreadNH,
407405
inflateCycles=inflateCycles, skipSolve=skipSolve,

src/CalcFactor.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,12 @@ function calcFactorResidualTemporary( fct::AbstractRelative,
114114
measurement::Tuple,
115115
varTypes::Tuple,
116116
pts::Tuple;
117-
tfg::AbstractDFG = initfg() )
117+
tfg::AbstractDFG = initfg(),
118+
_blockRecursion::Bool=false )
118119
#
119120

120121
# build a new temporary graph
121-
_, _dfgfct = _buildGraphByFactorAndTypes!(fct, varTypes, pts, dfg=tfg)
122+
_, _dfgfct = _buildGraphByFactorAndTypes!(fct, varTypes, pts, dfg=tfg, _blockRecursion=_blockRecursion)
122123

123124
# get a fresh measurement if needed
124125
_measurement = if length(measurement) != 0

src/Deprecated.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ end
2525
## Deprecate code below before v0.27
2626
##==============================================================================
2727

28+
@deprecate calcPerturbationFromVariable( fgc::FactorGradientsCached!, fromVar::Int, infoPerCoord::AbstractVector;tol::Real=0.02*fgc._h ) calcPerturbationFromVariable(fgc, [fromVar => infoPerCoord;], tol=tol )
29+
2830
@deprecate findRelatedFromPotential(w...;kw...) (calcProposalBelief(w...;kw...),)
2931

3032
# function generateNullhypoEntropy( val::AbstractMatrix{<:Real},

src/DispatchPackedConversions.jl

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,17 @@ end
6363
After deserializing a factor using decodePackedType, use this to
6464
completely rebuild the factor's CCW and user data.
6565
"""
66-
function rebuildFactorMetadata!(dfg::AbstractDFG{SolverParams}, factor::DFGFactor)
66+
function rebuildFactorMetadata!(dfg::AbstractDFG{SolverParams},
67+
factor::DFGFactor,
68+
neighbors = map(vId->getVariable(dfg, vId), getNeighbors(dfg, factor)) )
69+
#
6770
# Set up the neighbor data
68-
neighbors = map(vId->getVariable(dfg, vId), getNeighbors(dfg, factor))
69-
neighborUserData = map(v->getVariableType(v), neighbors)
7071

7172
# Rebuilding the CCW
7273
fsd = getSolverData(factor)
73-
ccw_new = getDefaultFactorData( dfg, neighbors, getFactorType(factor),
74+
fnd_new = getDefaultFactorData( dfg,
75+
neighbors,
76+
getFactorType(factor),
7477
multihypo=fsd.multihypo,
7578
nullhypo=fsd.nullhypo,
7679
# special inflation override
@@ -79,7 +82,29 @@ function rebuildFactorMetadata!(dfg::AbstractDFG{SolverParams}, factor::DFGFacto
7982
potentialused=fsd.potentialused,
8083
edgeIDs=fsd.edgeIDs,
8184
solveInProgress=fsd.solveInProgress)
82-
setSolverData!(factor, ccw_new)
85+
#
86+
87+
factor_ = if typeof(fnd_new) != typeof(getSolverData(factor))
88+
# must change the type of factor solver data FND{CCW{...}}
89+
# create a new factor
90+
factor__ = DFGFactor(getLabel(factor),
91+
getTimestamp(factor),
92+
factor.nstime,
93+
getTags(factor),
94+
fnd_new,
95+
getSolvable(factor),
96+
Tuple(getVariableOrder(factor)))
97+
#
98+
99+
# replace old factor in dfg with a new one
100+
deleteFactor!(dfg, factor)
101+
addFactor!(dfg, factor__)
102+
103+
factor__
104+
else
105+
setSolverData!(factor, fnd_new)
106+
factor
107+
end
83108

84109
#... Copying neighbor data into the factor?
85110
# JT TODO it looks like this is already updated in getDefaultFactorData -> prepgenericconvolution
@@ -88,7 +113,7 @@ function rebuildFactorMetadata!(dfg::AbstractDFG{SolverParams}, factor::DFGFacto
88113
# ccw_new.fnc.cpt[i].factormetadata.variableuserdata = deepcopy(neighborUserData)
89114
# end
90115

91-
return factor
116+
return factor_
92117
end
93118

94119

src/FactorGraph.jl

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ function calcZDim(cf::CalcFactor{T}) where {T <: AbstractFactor}
605605
M = getManifold(cf.factor)
606606
return manifold_dimension(M)
607607
catch
608-
@warn "no method getManifold(::$T), calcZDim will attempt legacy length(sample) method instead"
608+
@warn "no method getManifold(::$(string(T))), calcZDim will attempt legacy length(sample) method instead"
609609
end
610610
end
611611

@@ -620,34 +620,51 @@ calcZDim(cf::CalcFactor{<:GenericMarginal}) = 0
620620

621621
calcZDim(cf::CalcFactor{<:ManifoldPrior}) = manifold_dimension(cf.factor.M)
622622

623+
# return a BitVector masking the fractional portion, assuming converted 0's on 100% confident variables
624+
_getFractionalVars(varList::Union{<:Tuple, <:AbstractVector}, mh::Nothing) = zeros(length(varList)) .== 1
625+
_getFractionalVars(varList::Union{<:Tuple, <:AbstractVector}, mh::Categorical) = 0 .< mh.p
626+
627+
function _selectHypoVariables(allVars::Union{<:Tuple, <:AbstractVector},
628+
mh::Categorical,
629+
sel::Integer = rand(mh) )
630+
#
631+
mask = mh.p .≈ 0.0
632+
mask[sel] = true
633+
(1:length(allVars))[mask]
634+
end
635+
636+
_selectHypoVariables(allVars::Union{<:Tuple, <:AbstractVector},mh::Nothing,sel::Integer=0 ) = collect(1:length(allVars))
637+
623638

624639
function prepgenericconvolution(Xi::Vector{<:DFGVariable},
625640
usrfnc::T;
626641
multihypo::Union{Nothing, Distributions.Categorical}=nothing,
627642
nullhypo::Real=0.0,
628643
threadmodel=MultiThreaded,
629-
inflation::Real=0.0 ) where {T <: FunctorInferenceType}
644+
inflation::Real=0.0,
645+
_blockRecursion::Bool=false ) where {T <: AbstractFactor}
630646
#
631647
pttypes = getVariableType.(Xi) .|> getPointType
632648
PointType = 0 < length(pttypes) ? pttypes[1] : Vector{Float64}
633649
# FIXME stop using Any, see #1321
634-
varParams = Vector{Vector{Any}}()
635-
maxlen, sfidx, mani = prepareparamsarray!(varParams, Xi, nothing, 0) # Nothing for init.
650+
varParamsAll = Vector{Vector{Any}}()
651+
maxlen, sfidx, mani = prepareparamsarray!(varParamsAll, Xi, nothing, 0) # Nothing for init.
636652

637653
# standard factor metadata
638654
sflbl = 0==length(Xi) ? :null : getLabel(Xi[end])
639-
fmd = FactorMetadata(Xi, getLabel.(Xi), varParams, sflbl, nothing)
655+
fmd = FactorMetadata(Xi, getLabel.(Xi), varParamsAll, sflbl, nothing)
640656

641657
# create a temporary CalcFactor object for extracting the first sample
642658
# TODO, deprecate this: guess measurement points type
643659
# MeasType = Vector{Float64} # FIXME use `usrfnc` to get this information instead
644-
_cf = CalcFactor( usrfnc, fmd, 0, 1, nothing, varParams) # (Vector{MeasType}(),)
660+
_cf = CalcFactor( usrfnc, fmd, 0, 1, nothing, varParamsAll) # (Vector{MeasType}(),)
645661

646662
# get a measurement sample
647663
meas_single = sampleFactor(_cf, 1)
648664

665+
# get the measurement dimension
649666
zdim = calcZDim(_cf)
650-
# zdim = T != GenericMarginal ? size(getSample(usrfnc, 2)[1],1) : 0
667+
# some hypo resolution
651668
certainhypo = multihypo !== nothing ? collect(1:length(multihypo.p))[multihypo.p .== 0.0] : collect(1:length(Xi))
652669

653670
# sort out partialDims here
@@ -662,19 +679,32 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
662679
varTypes::Vector{DataType} = typeof.(getVariableType.(Xi))
663680
gradients = nothing
664681
# prepare new cached gradient lambdas (attempt)
665-
# try
666-
# measurement = tuple(((x->x[1]).(meas_single))...)
667-
# pts = tuple(((x->x[1]).(varParams))...)
668-
# gradients = FactorGradientsCached!(usrfnc, varTypes, measurement, pts);
669-
# catch e
670-
# @warn "Unable to create measurements and gradients for $usrfnc during prep of CCW, falling back on no-partial information assumption."
671-
# end
682+
try
683+
# this try block definitely fails on deserialization, due to empty DFGVariable[] vector here:
684+
# https://github.com/JuliaRobotics/IncrementalInference.jl/blob/db7ff84225cc848c325e57b5fb9d0d85cb6c79b8/src/DispatchPackedConversions.jl#L46
685+
# also https://github.com/JuliaRobotics/DistributedFactorGraphs.jl/issues/590#issuecomment-891450762
686+
if (!_blockRecursion) && usrfnc isa AbstractRelative
687+
# take first value from each measurement-tuple-element
688+
measurement_ = map(x->x[1], meas_single)
689+
# compensate if no info available during deserialization
690+
# take the first value from each variable param
691+
pts_ = map(x->x[1], varParamsAll)
692+
# FIXME, only using first meas and params values at this time...
693+
# NOTE, must block recurions here, since FGC uses this function to calculate numerical gradients on a temp fg.
694+
# assume for now fractional-var in multihypo have same varType
695+
hypoidxs = _selectHypoVariables(pts_, multihypo)
696+
gradients = FactorGradientsCached!(usrfnc, tuple(varTypes[hypoidxs]...), measurement_, tuple(pts_[hypoidxs]...), _blockRecursion=true);
697+
end
698+
catch e
699+
@warn "Unable to create measurements and gradients for $usrfnc during prep of CCW, falling back on no-partial information assumption. Enable @debug printing to see the error."
700+
@debug(e)
701+
end
672702

673703
ccw = CommonConvWrapper(
674704
usrfnc,
675705
PointType[],
676706
zdim,
677-
varParams,
707+
varParamsAll,
678708
fmd,
679709
specialzDim = hasfield(T, :zDim),
680710
partial = ispartl,
@@ -707,15 +737,16 @@ function getDefaultFactorData(dfg::AbstractDFG,
707737
potentialused::Bool = false,
708738
edgeIDs = Int[],
709739
solveInProgress = 0,
710-
inflation::Real=getSolverParams(dfg).inflation ) where T <: FunctorInferenceType
740+
inflation::Real=getSolverParams(dfg).inflation,
741+
_blockRecursion::Bool=false ) where T <: FunctorInferenceType
711742
#
712743

713744
# prepare multihypo particulars
714745
# storeMH::Vector{Float64} = multihypo == nothing ? Float64[] : [multihypo...]
715746
mhcat, nh = parseusermultihypo(multihypo, nullhypo)
716747

717748
# allocate temporary state for convolutional operations (not stored)
718-
ccw = prepgenericconvolution(Xi, usrfnc, multihypo=mhcat, nullhypo=nh, threadmodel=threadmodel, inflation=inflation)
749+
ccw = prepgenericconvolution(Xi, usrfnc, multihypo=mhcat, nullhypo=nh, threadmodel=threadmodel, inflation=inflation, _blockRecursion=_blockRecursion)
719750

720751
# and the factor data itself
721752
return FunctionNodeData{typeof(ccw)}(eliminated, potentialused, edgeIDs, ccw, multihypo, ccw.certainhypo, nullhypo, solveInProgress, inflation)
@@ -1177,7 +1208,7 @@ Experimental
11771208
- `inflation`, to better disperse kernels before convolution solve, see IIF #1051.
11781209
"""
11791210
function DFG.addFactor!(dfg::AbstractDFG,
1180-
Xi::Vector{<:DFGVariable},
1211+
Xi::AbstractVector{<:DFGVariable},
11811212
usrfnc::AbstractFactor;
11821213
multihypo::Vector{Float64}=Float64[],
11831214
nullhypo::Float64=0.0,
@@ -1188,7 +1219,8 @@ function DFG.addFactor!(dfg::AbstractDFG,
11881219
threadmodel=SingleThreaded,
11891220
suppressChecks::Bool=false,
11901221
inflation::Real=getSolverParams(dfg).inflation,
1191-
namestring::Symbol = assembleFactorName(dfg, Xi) )
1222+
namestring::Symbol = assembleFactorName(dfg, Xi),
1223+
_blockRecursion::Bool=false )
11921224
#
11931225
# depcrecation
11941226

@@ -1199,7 +1231,8 @@ function DFG.addFactor!(dfg::AbstractDFG,
11991231
multihypo=multihypo,
12001232
nullhypo=nullhypo,
12011233
threadmodel=threadmodel,
1202-
inflation=inflation)
1234+
inflation=inflation,
1235+
_blockRecursion=_blockRecursion)
12031236
newFactor = DFGFactor(Symbol(namestring),
12041237
varOrderLabels,
12051238
solverData;
@@ -1208,16 +1241,23 @@ function DFG.addFactor!(dfg::AbstractDFG,
12081241
timestamp=timestamp)
12091242
#
12101243

1211-
success = DFG.addFactor!(dfg, newFactor)
1244+
success = addFactor!(dfg, newFactor)
12121245

12131246
# TODO: change this operation to update a conditioning variable
12141247
graphinit && doautoinit!(dfg, Xi, singles=false)
12151248

12161249
return newFactor
12171250
end
12181251

1252+
function _checkFactorAdd(usrfnc, xisyms)
1253+
if length(xisyms) == 1 && !(usrfnc isa AbstractPrior) && !(usrfnc isa Mixture)
1254+
@warn("Listing only one variable $xisyms for non-unary factor type $(typeof(usrfnc))")
1255+
end
1256+
nothing
1257+
end
1258+
12191259
function DFG.addFactor!(dfg::AbstractDFG,
1220-
xisyms::Vector{Symbol},
1260+
xisyms::AbstractVector{Symbol},
12211261
usrfnc::AbstractFactor;
12221262
suppressChecks::Bool=false,
12231263
kw... )
@@ -1233,12 +1273,12 @@ function DFG.addFactor!(dfg::AbstractDFG,
12331273
# depcrecation
12341274

12351275
# basic sanity check for unary vs n-ary
1236-
if !suppressChecks && length(xisyms) == 1 && !(usrfnc isa AbstractPrior) && !(usrfnc isa Mixture)
1237-
@warn("Listing only one variable $xisyms for non-unary factor type $(typeof(usrfnc))")
1276+
if !suppressChecks
1277+
_checkFactorAdd(usrfnc, xisyms)
12381278
end
12391279

1240-
variables = getVariable.(dfg, xisyms)
1241-
# verts = map(vid -> DFG.getVariable(dfg, vid), xisyms)
1280+
# variables = getVariable.(dfg, xisyms)
1281+
variables = map(vid -> getVariable(dfg, vid), xisyms)
12421282
addFactor!(dfg, variables, usrfnc; suppressChecks=suppressChecks, kw... ) # multihypo=multihypo, nullhypo=nullhypo, solvable=solvable, tags=tags, graphinit=graphinit, threadmodel=threadmodel, timestamp=timestamp, inflation=inflation )
12431283
end
12441284

src/Factors/GenericFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ export ManifoldFactor
7171
# For now, `Z` is on the tangent space in coordinates at the point used in the factor.
7272
# For groups just the lie algebra
7373
# As transition it will be easier this way, we can reevaluate
74-
struct ManifoldFactor{M <: AbstractManifold, T <: SamplableBelief} <: AbstractManifoldMinimize#AbstractFactor
74+
struct ManifoldFactor{M <: AbstractManifold, T <: SamplableBelief} <: AbstractManifoldMinimize #AbstractFactor
7575
M::M
7676
Z::T
7777
end

src/NumericalCalculations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ function _buildCalcFactorLambdaSample(ccwl::CommonConvWrapper,
199199
# slack is used to shift the residual away from the natural "zero" tension position of a factor,
200200
# this is useful when calculating factor gradients at a variety of param locations resulting in "non-zero slack" of the residual.
201201
# see `IIF.calcFactorResidualTemporary`
202+
# NOTE this minus operation assumes _slack is either coordinate or tangent vector element (not a manifold or group element)
202203
() -> cf( (_getindextuple(measurement_, smpid))..., (getindex.(varParams, smpid))... ) .- _slack
203204
end
204205

src/SolverUtilities.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,20 +165,21 @@ function _buildGraphByFactorAndTypes!(fct::AbstractFactor,
165165
_allVars::AbstractVector{Symbol} = sortDFG(ls(dfg, destPattern)),
166166
currLabel::Symbol = 0 < length(_allVars) ? _allVars[end] : Symbol(destPrefix, 0),
167167
currNumber::Integer = reverse(match(r"\d+", reverse(string(currLabel))).match) |> x->parse(Int,x),
168-
graphinit::Bool = false )
168+
graphinit::Bool = false,
169+
_blockRecursion::Bool=false )
169170
#
170171

171172
# TODO generalize beyond binary
172173
len = length(varTypes)
173-
vars = [Symbol(destPrefix, s_) for s_ in (currNumber .+ (1:len))]
174+
vars = Symbol[Symbol(destPrefix, s_) for s_ in (currNumber .+ (1:len))]
174175
for (s_, vTyp) in enumerate(varTypes)
175176
# add the necessary variables
176177
exists(dfg, vars[s_]) ? nothing : addVariable!(dfg, vars[s_], vTyp)
177178
# set the numerical values if available
178179
((0 < length(pts)) && (pts[s_] isa Nothing)) ? nothing : initManual!(dfg, vars[s_], [pts[s_],], solveKey, bw=ones(getDimension(vTyp)))
179180
end
180181
# if newFactor then add the factor on vars, else assume only one existing factor between vars
181-
_dfgfct = newFactor ? addFactor!(dfg, vars, fct, graphinit=graphinit) : getFactor(dfg, intersect((ls.(dfg, vars))...)[1] )
182+
_dfgfct = newFactor ? addFactor!(dfg, vars, fct, graphinit=graphinit, _blockRecursion=_blockRecursion) : getFactor(dfg, intersect((ls.(dfg, vars))...)[1] )
182183

183184
return dfg, _dfgfct
184185
end

0 commit comments

Comments
 (0)