Skip to content

Commit ad00c3b

Browse files
authored
Merge pull request #1774 from JuliaRobotics/23Q3/enh/derel_getsample
fixes, debugging, wip on derel test
2 parents bb689ff + 60ad0a2 commit ad00c3b

14 files changed

+377
-142
lines changed

ext/IncrInfrDiffEqFactorExt.jl

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import DifferentialEquations: solve
88
using Dates
99

1010
using IncrementalInference
11-
import IncrementalInference: getSample, getManifold, DERelative
12-
import IncrementalInference: sampleFactor
11+
import IncrementalInference: DERelative, _solveFactorODE!
12+
import IncrementalInference: getSample, sampleFactor, getManifold
1313

1414
using DocStringExtensions
1515

@@ -174,12 +174,12 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
174174
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
175175
solveforIdx = 2
176176
# use forward solve for all solvefor not in [1;2]
177-
u0pts = getBelief(cf.fullvariables[1]) |> getPoints
177+
# u0pts = getBelief(cf.fullvariables[1]) |> getPoints
178178
# update parameters for additional variables
179179
_solveFactorODE!(
180180
meas1,
181181
oderel.forwardProblem,
182-
u0pts[cf._sampleIdx],
182+
X[1], # u0pts[cf._sampleIdx],
183183
_maketuplebeyond2args(X...)...,
184184
)
185185
end
@@ -192,13 +192,52 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
192192
#FIXME
193193
res = zeros(size(X[2], 1))
194194
for i = 1:size(X[2], 1)
195-
# diffop( test, reference ) <===> ΔX = test \ reference
195+
# diffop( reference?, test? ) <===> ΔX = test \ reference
196196
res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
197197
end
198198
return res
199199
end
200200

201201

202+
# # FIXME see #1025, `multihypo=` will not work properly yet
203+
# function getSample(cf::CalcFactor{<:DERelative})
204+
205+
# oder = cf.factor
206+
207+
# # how many trajectories to propagate?
208+
# # @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
209+
# meas = zeros(getDimension(cf.fullvariables[2]))
210+
211+
# # pick forward or backward direction
212+
# # set boundary condition
213+
# u0pts = if cf.solvefor == 1
214+
# # backward direction
215+
# prob = oder.backwardProblem
216+
# addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
217+
# convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
218+
# )
219+
# cf._legacyParams[2]
220+
# else
221+
# # forward backward
222+
# prob = oder.forwardProblem
223+
# # buffer manifold operations for use during factor evaluation
224+
# addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
225+
# convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
226+
# )
227+
# cf._legacyParams[1]
228+
# end
229+
230+
# i = cf._sampleIdx
231+
# # solve likely elements
232+
# # TODO, does this respect hyporecipe ???
233+
# idxArr = (k -> cf._legacyParams[k][i]).(1:length(cf._legacyParams))
234+
# _solveFactorODE!(meas, prob, u0pts[i], _maketuplebeyond2args(idxArr...)...)
235+
# # _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
236+
237+
# return meas, diffOp
238+
# end
239+
240+
202241

203242

204243
## =========================================================================
@@ -221,15 +260,17 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
221260
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
222261
convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
223262
)
224-
getBelief(cf.fullvariables[2]) |> getPoints
263+
# getBelief(cf.fullvariables[2]) |> getPoints
264+
cf._legacyParams[2]
225265
else
226266
# forward backward
227267
prob = oder.forwardProblem
228268
# buffer manifold operations for use during factor evaluation
229269
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
230270
convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
231271
)
232-
getBelief(cf.fullvariables[1]) |> getPoints
272+
# getBelief(cf.fullvariables[1]) |> getPoints
273+
cf._legacyParams[1]
233274
end
234275

235276
# solve likely elements

ext/WeakDepsPrototypes.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
function _ccolamd! end
44
function _ccolamd end
55

6+
# DiffEq
7+
function _solveFactorODE! end
8+
69
# Flux.jl
710
function MixtureFluxModels end
811

src/services/ApproxConv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function approxConvBelief(
1313
)
1414
#
1515
v_trg = getVariable(dfg, target)
16-
N = N == 0 ? getNumPts(v_trg; solveKey = solveKey) : N
16+
N = N == 0 ? getNumPts(v_trg; solveKey) : N
1717
# approxConv should push its result into duplicate memory destination, NOT the variable.VND.val itself. ccw.varValsAll always points directly to variable.VND.val
1818
# points and infoPerCoord
1919

src/services/CalcFactor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ end
2525
function CalcFactor(
2626
ccwl::CommonConvWrapper;
2727
factor = ccwl.usrfnc!,
28-
_sampleIdx = 0,
28+
_sampleIdx = ccwl.particleidx[],
2929
_legacyParams = ccwl.varValsAll[],
3030
_allowThreads = true,
3131
cache = ccwl.dummyCache,
@@ -399,7 +399,7 @@ function _createCCW(
399399
# MeasType = Vector{Float64} # FIXME use `usrfnc` to get this information instead
400400
_cf = CalcFactor(
401401
usrfnc,
402-
0,
402+
1,
403403
_varValsAll,
404404
false,
405405
userCache,

src/services/EvalFactor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,10 +572,10 @@ function evalFactor(
572572
dfg::AbstractDFG,
573573
fct::DFGFactor,
574574
solvefor::Symbol,
575-
measurement::AbstractVector = Tuple[];
575+
measurement::AbstractVector = Tuple[]; # FIXME ensure type stable in all cases
576576
needFreshMeasurements::Bool = true,
577577
solveKey::Symbol = :default,
578-
variables = getVariable.(dfg, getVariableOrder(fct)), # because we trying to use StaticArrays, go figure
578+
variables = getVariable.(dfg, getVariableOrder(fct)), # FIXME use tuple instead for type stability
579579
N::Int = length(measurement),
580580
inflateCycles::Int = getSolverParams(dfg).inflateCycles,
581581
nullSurplus::Real = 0,

src/services/FactorGraph.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,12 @@ function setValKDE!(
203203
setinit::Bool = true,
204204
ipc::AbstractVector{<:Real} = [0.0;];
205205
solveKey::Symbol = :default,
206-
) where {P}
206+
ppeType::Type{T} = MeanMaxPPE,
207+
) where {P, T}
208+
vnd = getSolverData(v, solveKey)
207209
# recover variableType information
208-
setValKDE!(getSolverData(v, solveKey), val, setinit, ipc)
210+
setValKDE!(vnd, val, setinit, ipc)
211+
setPPE!(v; solveKey, ppeType)
209212
return nothing
210213
end
211214
function setValKDE!(
@@ -246,7 +249,7 @@ end
246249

247250
function setValKDE!(
248251
vnd::VariableNodeData,
249-
mkd::ManifoldKernelDensity{M, B, Nothing},
252+
mkd::ManifoldKernelDensity{M, B, Nothing}, # TBD dispatch without partial?
250253
setinit::Bool = true,
251254
ipc::AbstractVector{<:Real} = [0.0;],
252255
) where {M, B}
@@ -282,8 +285,15 @@ function setValKDE!(
282285
return nothing
283286
end
284287

285-
function setBelief!(vari::DFGVariable, bel::ManifoldKernelDensity, setinit::Bool=true,ipc::AbstractVector{<:Real}=[0.0;])
286-
setValKDE!(vari,getPoints(bel, false),setinit, ipc)
288+
function setBelief!(
289+
vari::DFGVariable,
290+
bel::ManifoldKernelDensity,
291+
setinit::Bool=true,
292+
ipc::AbstractVector{<:Real}=[0.0;];
293+
solveKey::Symbol = :default
294+
)
295+
setValKDE!(vari, bel, setinit, ipc; solveKey)
296+
# setValKDE!(vari,getPoints(bel, false), setinit, ipc)
287297
end
288298

289299
"""

src/services/GraphProductOperations.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,18 @@ function propagateBelief(
2929

3030
# get proposal beliefs
3131
destlbl = getLabel(destvar)
32-
ipc = proposalbeliefs!(dfg, destlbl, factors, dens; solveKey = solveKey, N = N, dbg = dbg)
32+
ipc = proposalbeliefs!(dfg, destlbl, factors, dens; solveKey, N, dbg)
3333

3434
# @show dens[1].manifold
3535

36-
# make sure oldpts has right number of points
36+
# make sure oldPoints vector has right length
3737
oldBel = getBelief(dfg, destlbl, solveKey)
38-
oldpts = if Npts(oldBel) == N
39-
getPoints(oldBel)
38+
_pts = getPoints(oldBel, false)
39+
oldPoints = if Npts(oldBel) <= N
40+
_pts[1:N]
4041
else
41-
sample(oldBel, N)[1]
42+
nn = N - length(_pts) # should be larger than 0
43+
vcat(_pts, sample(oldBel, nn))
4244
end
4345

4446
# few more data requirements
@@ -51,8 +53,8 @@ function propagateBelief(
5153
dens,
5254
M;
5355
Niter = 1,
54-
oldPoints = oldpts,
55-
N = N,
56+
oldPoints,
57+
N,
5658
u0 = getPointDefault(varType),
5759
)
5860

src/services/SolveTree.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function doFMCIteration(
7070
logger,
7171
)
7272

73-
if 0 < length(getPoints(dens))
73+
if 0 < Npts(dens)
7474
setBelief!(vert, dens, true, ipc)
7575
# setValKDE!(vert, densPts, true, ipc)
7676
# TODO perhaps more debugging inside `propagateBelief`?

src/services/SolverUtilities.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,18 @@ function mmd(
4040
nodeType::Union{InstanceType{<:InferenceVariable}, InstanceType{<:AbstractFactor}},
4141
threads::Bool = true;
4242
bw::AbstractVector{<:Real} = SA[0.001;],
43+
asPartial::Bool = true
4344
)
4445
#
45-
return mmd(getPoints(p1), getPoints(p2), nodeType, threads; bw)
46+
return mmd(getPoints(p1, asPartial), getPoints(p2, asPartial), nodeType, threads; bw)
4647
end
4748

4849
# part of consolidation, see #927
49-
function sampleFactor!(ccwl::CommonConvWrapper, N::Int; _allowThreads::Bool=true)
50+
function sampleFactor!(
51+
ccwl::CommonConvWrapper,
52+
N::Int;
53+
_allowThreads::Bool=true
54+
)
5055
#
5156

5257
# FIXME get allocations here down to 0
@@ -60,23 +65,42 @@ function sampleFactor!(ccwl::CommonConvWrapper, N::Int; _allowThreads::Bool=true
6065
return ccwl.measurement
6166
end
6267

63-
function sampleFactor(ccwl::CommonConvWrapper, N::Int; _allowThreads::Bool=true)
68+
function sampleFactor(
69+
ccwl::CommonConvWrapper,
70+
N::Int;
71+
_allowThreads::Bool=true
72+
)
6473
#
6574
cf = CalcFactor(ccwl; _allowThreads)
6675
return sampleFactor(cf, N)
6776
end
6877

69-
sampleFactor(fct::DFGFactor, N::Int = 1; _allowThreads::Bool=true) = sampleFactor(_getCCW(fct), N; _allowThreads)
78+
sampleFactor(
79+
fct::DFGFactor,
80+
N::Int = 1;
81+
_allowThreads::Bool=true
82+
) = sampleFactor(
83+
_getCCW(fct),
84+
N;
85+
_allowThreads
86+
)
7087

71-
function sampleFactor(dfg::AbstractDFG, sym::Symbol, N::Int = 1; _allowThreads::Bool=true)
88+
function sampleFactor(
89+
dfg::AbstractDFG,
90+
sym::Symbol,
91+
N::Int = 1;
92+
_allowThreads::Bool=true
93+
)
7294
#
7395
return sampleFactor(getFactor(dfg, sym), N; _allowThreads)
7496
end
7597

7698
"""
7799
$(SIGNATURES)
78100
79-
Update cliq `cliqID` in Bayes (Juction) tree `bt` according to contents of `urt` -- intended use is to update main clique after a upward belief propagation computation has been completed per clique.
101+
Update cliq `cliqID` in Bayes (Juction) tree `bt` according to contents of `urt`.
102+
Intended use is to update main clique after a upward belief propagation computation
103+
has been completed per clique.
80104
"""
81105
function updateFGBT!(
82106
fg::AbstractDFG,

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ TEST_GROUP = get(ENV, "IIF_TEST_GROUP", "all")
55
# temporarily moved to start (for debugging)
66
#...
77
if TEST_GROUP in ["all", "tmp_debug_group"]
8-
include("testSpecialOrthogonalMani.jl")
98
include("testDERelative.jl")
9+
include("testSpecialOrthogonalMani.jl")
1010
include("testMultiHypo3Door.jl")
1111
include("priorusetest.jl")
1212
end

0 commit comments

Comments
 (0)