@@ -605,7 +605,7 @@ function calcZDim(cf::CalcFactor{T}) where {T <: AbstractFactor}
     M = getManifold(cf.factor)
     return manifold_dimension(M)
   catch
-    @warn "no method getManifold(::$T), calcZDim will attempt legacy length(sample) method instead"
+    @warn "no method getManifold(::$(string(T))), calcZDim will attempt legacy length(sample) method instead"
   end
 end
 
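For reference, `calcZDim` now simply reports the dimension of the factor's manifold; the values below come from Manifolds.jl and are only an illustrative sketch, not part of this commit:

using Manifolds
manifold_dimension(SpecialEuclidean(2))   # == 3, so a factor measuring an SE(2) transform has zdim == 3
manifold_dimension(TranslationGroup(1))   # == 1, e.g. a scalar prior
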
@@ -620,34 +620,51 @@ calcZDim(cf::CalcFactor{<:GenericMarginal}) = 0
 
 calcZDim(cf::CalcFactor{<:ManifoldPrior}) = manifold_dimension(cf.factor.M)
 
+# return a BitVector masking the fractional portion, assuming converted 0's on 100% confident variables
+_getFractionalVars(varList::Union{<:Tuple, <:AbstractVector}, mh::Nothing) = zeros(length(varList)) .== 1
+_getFractionalVars(varList::Union{<:Tuple, <:AbstractVector}, mh::Categorical) = 0 .< mh.p
+
+function _selectHypoVariables( allVars::Union{<:Tuple, <:AbstractVector},
+                               mh::Categorical,
+                               sel::Integer=rand(mh) )
+  #
+  mask = mh.p .≈ 0.0
+  mask[sel] = true
+  (1:length(allVars))[mask]
+end
+
+_selectHypoVariables(allVars::Union{<:Tuple, <:AbstractVector}, mh::Nothing, sel::Integer=0) = collect(1:length(allVars))
+
 
 function prepgenericconvolution(Xi::Vector{<:DFGVariable},
                                 usrfnc::T;
                                 multihypo::Union{Nothing, Distributions.Categorical}=nothing,
                                 nullhypo::Real=0.0,
                                 threadmodel=MultiThreaded,
-                                inflation::Real=0.0 ) where {T <: FunctorInferenceType}
+                                inflation::Real=0.0,
+                                _blockRecursion::Bool=false ) where {T <: AbstractFactor}
   #
   pttypes = getVariableType.(Xi) .|> getPointType
   PointType = 0 < length(pttypes) ? pttypes[1] : Vector{Float64}
   # FIXME stop using Any, see #1321
-  varParams = Vector{Vector{Any}}()
-  maxlen, sfidx, mani = prepareparamsarray!(varParams, Xi, nothing, 0) # Nothing for init.
+  varParamsAll = Vector{Vector{Any}}()
+  maxlen, sfidx, mani = prepareparamsarray!(varParamsAll, Xi, nothing, 0) # Nothing for init.
 
   # standard factor metadata
   sflbl = 0 == length(Xi) ? :null : getLabel(Xi[end])
-  fmd = FactorMetadata(Xi, getLabel.(Xi), varParams, sflbl, nothing)
+  fmd = FactorMetadata(Xi, getLabel.(Xi), varParamsAll, sflbl, nothing)
 
   # create a temporary CalcFactor object for extracting the first sample
   # TODO, deprecate this: guess measurement points type
   # MeasType = Vector{Float64} # FIXME use `usrfnc` to get this information instead
-  _cf = CalcFactor( usrfnc, fmd, 0, 1, nothing, varParams) # (Vector{MeasType}(),)
+  _cf = CalcFactor( usrfnc, fmd, 0, 1, nothing, varParamsAll) # (Vector{MeasType}(),)
 
   # get a measurement sample
   meas_single = sampleFactor(_cf, 1)
 
+  # get the measurement dimension
   zdim = calcZDim(_cf)
-  # zdim = T != GenericMarginal ? size(getSample(usrfnc, 2)[1],1) : 0
+  # some hypo resolution
   certainhypo = multihypo !== nothing ? collect(1:length(multihypo.p))[multihypo.p .== 0.0] : collect(1:length(Xi))
 
   # sort out partialDims here
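A minimal sketch of how the two new multihypo helpers behave, assuming a three-variable factor whose first variable is 100% certain; the labels and weights are illustrative only, not part of this commit:

using Distributions
mh = Categorical([0.0; 0.6; 0.4])              # weight 0.0 marks the certain variable
_getFractionalVars([:x1; :x2; :x3], mh)        # BitVector [false, true, true], the fractional variables
_getFractionalVars([:x1; :x2; :x3], nothing)   # all false, no multihypo in play
_selectHypoVariables([:x1; :x2; :x3], mh, 2)   # [1, 2], the certain variable plus the selected hypothesis
_selectHypoVariables([:x1; :x2; :x3], nothing) # [1, 2, 3], every variable participates
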
@@ -662,19 +679,32 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
   varTypes::Vector{DataType} = typeof.(getVariableType.(Xi))
   gradients = nothing
   # prepare new cached gradient lambdas (attempt)
-  # try
-  #   measurement = tuple(((x->x[1]).(meas_single))...)
-  #   pts = tuple(((x->x[1]).(varParams))...)
-  #   gradients = FactorGradientsCached!(usrfnc, varTypes, measurement, pts);
-  # catch e
-  #   @warn "Unable to create measurements and gradients for $usrfnc during prep of CCW, falling back on no-partial information assumption."
-  # end
+  try
+    # this try block definitely fails on deserialization, due to empty DFGVariable[] vector here:
+    # https://github.com/JuliaRobotics/IncrementalInference.jl/blob/db7ff84225cc848c325e57b5fb9d0d85cb6c79b8/src/DispatchPackedConversions.jl#L46
+    # also https://github.com/JuliaRobotics/DistributedFactorGraphs.jl/issues/590#issuecomment-891450762
+    if (!_blockRecursion) && usrfnc isa AbstractRelative
+      # take the first value from each measurement-tuple-element
+      measurement_ = map(x->x[1], meas_single)
+      # compensate if no info available during deserialization
+      # take the first value from each variable param
+      pts_ = map(x->x[1], varParamsAll)
+      # FIXME, only using first meas and params values at this time...
+      # NOTE, must block recursions here, since FGC uses this function to calculate numerical gradients on a temp fg.
+      # assume for now fractional-var in multihypo have same varType
+      hypoidxs = _selectHypoVariables(pts_, multihypo)
+      gradients = FactorGradientsCached!(usrfnc, tuple(varTypes[hypoidxs]...), measurement_, tuple(pts_[hypoidxs]...), _blockRecursion=true);
+    end
+  catch e
+    @warn "Unable to create measurements and gradients for $usrfnc during prep of CCW, falling back on no-partial information assumption.  Enable @debug printing to see the error."
+    @debug(e)
+  end
 
   ccw = CommonConvWrapper(
           usrfnc,
           PointType[],
           zdim,
-          varParams,
+          varParamsAll,
           fmd,
           specialzDim = hasfield(T, :zDim),
           partial = ispartl,
@@ -707,15 +737,16 @@ function getDefaultFactorData(dfg::AbstractDFG,
                               potentialused::Bool=false,
                               edgeIDs=Int[],
                               solveInProgress=0,
-                              inflation::Real=getSolverParams(dfg).inflation ) where T <: FunctorInferenceType
+                              inflation::Real=getSolverParams(dfg).inflation,
+                              _blockRecursion::Bool=false ) where T <: FunctorInferenceType
   #
 
   # prepare multihypo particulars
   # storeMH::Vector{Float64} = multihypo == nothing ? Float64[] : [multihypo...]
   mhcat, nh = parseusermultihypo(multihypo, nullhypo)
 
   # allocate temporary state for convolutional operations (not stored)
-  ccw = prepgenericconvolution(Xi, usrfnc, multihypo=mhcat, nullhypo=nh, threadmodel=threadmodel, inflation=inflation)
+  ccw = prepgenericconvolution(Xi, usrfnc, multihypo=mhcat, nullhypo=nh, threadmodel=threadmodel, inflation=inflation, _blockRecursion=_blockRecursion)
 
   # and the factor data itself
   return FunctionNodeData{typeof(ccw)}(eliminated, potentialused, edgeIDs, ccw, multihypo, ccw.certainhypo, nullhypo, solveInProgress, inflation)
@@ -1177,7 +1208,7 @@ Experimental
 - `inflation`, to better disperse kernels before convolution solve, see IIF #1051.
 """
 function DFG.addFactor!(dfg::AbstractDFG,
-                        Xi::Vector{<:DFGVariable},
+                        Xi::AbstractVector{<:DFGVariable},
                         usrfnc::AbstractFactor;
                         multihypo::Vector{Float64}=Float64[],
                         nullhypo::Float64=0.0,
@@ -1188,7 +1219,8 @@ function DFG.addFactor!(dfg::AbstractDFG,
                         threadmodel=SingleThreaded,
                         suppressChecks::Bool=false,
                         inflation::Real=getSolverParams(dfg).inflation,
-                        namestring::Symbol = assembleFactorName(dfg, Xi) )
+                        namestring::Symbol = assembleFactorName(dfg, Xi),
+                        _blockRecursion::Bool=false )
   #
   # deprecation
 
@@ -1199,7 +1231,8 @@ function DFG.addFactor!(dfg::AbstractDFG,
                                     multihypo=multihypo,
                                     nullhypo=nullhypo,
                                     threadmodel=threadmodel,
-                                    inflation=inflation)
+                                    inflation=inflation,
+                                    _blockRecursion=_blockRecursion)
   newFactor = DFGFactor(Symbol(namestring),
                         varOrderLabels,
                         solverData;
@@ -1208,16 +1241,23 @@ function DFG.addFactor!(dfg::AbstractDFG,
                         timestamp=timestamp)
   #
 
-  success = DFG.addFactor!(dfg, newFactor)
+  success = addFactor!(dfg, newFactor)
 
   # TODO: change this operation to update a conditioning variable
   graphinit && doautoinit!(dfg, Xi, singles=false)
 
   return newFactor
 end
 
+function _checkFactorAdd(usrfnc, xisyms)
+  if length(xisyms) == 1 && !(usrfnc isa AbstractPrior) && !(usrfnc isa Mixture)
+    @warn("Listing only one variable $xisyms for non-unary factor type $(typeof(usrfnc))")
+  end
+  nothing
+end
+
 function DFG.addFactor!(dfg::AbstractDFG,
-                        xisyms::Vector{Symbol},
+                        xisyms::AbstractVector{Symbol},
                         usrfnc::AbstractFactor;
                         suppressChecks::Bool=false,
                         kw...)
@@ -1233,12 +1273,12 @@ function DFG.addFactor!(dfg::AbstractDFG,
   # deprecation
 
   # basic sanity check for unary vs n-ary
-  if !suppressChecks && length(xisyms) == 1 && !(usrfnc isa AbstractPrior) && !(usrfnc isa Mixture)
-    @warn("Listing only one variable $xisyms for non-unary factor type $(typeof(usrfnc))")
+  if !suppressChecks
+    _checkFactorAdd(usrfnc, xisyms)
   end
 
-  variables = getVariable.(dfg, xisyms)
-  # verts = map(vid -> DFG.getVariable(dfg, vid), xisyms)
+  # variables = getVariable.(dfg, xisyms)
+  variables = map(vid -> getVariable(dfg, vid), xisyms)
   addFactor!(dfg, variables, usrfnc; suppressChecks=suppressChecks, kw...) # multihypo=multihypo, nullhypo=nullhypo, solvable=solvable, tags=tags, graphinit=graphinit, threadmodel=threadmodel, timestamp=timestamp, inflation=inflation )
 end
 
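For context, a rough usage sketch of the updated `addFactor!` keywords; the variable and factor types below come from RoME.jl and the values are hypothetical, not part of this commit:

using IncrementalInference, RoME, Distributions, LinearAlgebra
fg = initfg()
addVariable!(fg, :x0, Pose2)
addVariable!(fg, :x1, Pose2)
# any AbstractVector of labels is now accepted, not only Vector{Symbol}
addFactor!(fg, [:x0; :x1], Pose2Pose2(MvNormal(zeros(3), diagm([0.1; 0.1; 0.01]))))
# _blockRecursion=true skips the cached-gradient construction, e.g. while
# FactorGradientsCached! is building its own temporary factor graph
addFactor!(fg, [:x0], PriorPose2(MvNormal(zeros(3), 0.01*diagm(ones(3)))), _blockRecursion=true)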