Skip to content

Commit 9af49c2

Browse files
committed
fixes, debugging, wip on derel test
1 parent bb689ff commit 9af49c2

File tree

8 files changed

+225
-37
lines changed

8 files changed

+225
-37
lines changed

ext/IncrInfrDiffEqFactorExt.jl

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import DifferentialEquations: solve
88
using Dates
99

1010
using IncrementalInference
11-
import IncrementalInference: getSample, getManifold, DERelative
12-
import IncrementalInference: sampleFactor
11+
import IncrementalInference: DERelative, _solveFactorODE!
12+
import IncrementalInference: getSample, sampleFactor, getManifold
1313

1414
using DocStringExtensions
1515

@@ -174,12 +174,12 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
174174
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
175175
solveforIdx = 2
176176
# use forward solve for all solvefor not in [1;2]
177-
u0pts = getBelief(cf.fullvariables[1]) |> getPoints
177+
# u0pts = getBelief(cf.fullvariables[1]) |> getPoints
178178
# update parameters for additional variables
179179
_solveFactorODE!(
180180
meas1,
181181
oderel.forwardProblem,
182-
u0pts[cf._sampleIdx],
182+
X[1], # u0pts[cf._sampleIdx],
183183
_maketuplebeyond2args(X...)...,
184184
)
185185
end
@@ -192,13 +192,52 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
192192
#FIXME
193193
res = zeros(size(X[2], 1))
194194
for i = 1:size(X[2], 1)
195-
# diffop( test, reference ) <===> ΔX = test \ reference
195+
# diffop( reference?, test? ) <===> ΔX = test \ reference
196196
res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
197197
end
198198
return res
199199
end
200200

201201

202+
# # FIXME see #1025, `multihypo=` will not work properly yet
203+
# function getSample(cf::CalcFactor{<:DERelative})
204+
205+
# oder = cf.factor
206+
207+
# # how many trajectories to propagate?
208+
# # @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
209+
# meas = zeros(getDimension(cf.fullvariables[2]))
210+
211+
# # pick forward or backward direction
212+
# # set boundary condition
213+
# u0pts = if cf.solvefor == 1
214+
# # backward direction
215+
# prob = oder.backwardProblem
216+
# addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
217+
# convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
218+
# )
219+
# cf._legacyParams[2]
220+
# else
221+
# # forward backward
222+
# prob = oder.forwardProblem
223+
# # buffer manifold operations for use during factor evaluation
224+
# addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
225+
# convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
226+
# )
227+
# cf._legacyParams[1]
228+
# end
229+
230+
# i = cf._sampleIdx
231+
# # solve likely elements
232+
# # TODO, does this respect hyporecipe ???
233+
# idxArr = (k -> cf._legacyParams[k][i]).(1:length(cf._legacyParams))
234+
# _solveFactorODE!(meas, prob, u0pts[i], _maketuplebeyond2args(idxArr...)...)
235+
# # _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
236+
237+
# return meas, diffOp
238+
# end
239+
240+
202241

203242

204243
## =========================================================================
@@ -221,15 +260,17 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
221260
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
222261
convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
223262
)
224-
getBelief(cf.fullvariables[2]) |> getPoints
263+
# getBelief(cf.fullvariables[2]) |> getPoints
264+
cf._legacyParams[2]
225265
else
226266
# forward backward
227267
prob = oder.forwardProblem
228268
# buffer manifold operations for use during factor evaluation
229269
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
230270
convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
231271
)
232-
getBelief(cf.fullvariables[1]) |> getPoints
272+
# getBelief(cf.fullvariables[1]) |> getPoints
273+
cf._legacyParams[1]
233274
end
234275

235276
# solve likely elements

ext/WeakDepsPrototypes.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
function _ccolamd! end
44
function _ccolamd end
55

6+
# DiffEq
7+
function _solveFactorODE! end
8+
69
# Flux.jl
710
function MixtureFluxModels end
811

src/services/CalcFactor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ end
2525
function CalcFactor(
2626
ccwl::CommonConvWrapper;
2727
factor = ccwl.usrfnc!,
28-
_sampleIdx = 0,
28+
_sampleIdx = ccwl.particleidx[],
2929
_legacyParams = ccwl.varValsAll[],
3030
_allowThreads = true,
3131
cache = ccwl.dummyCache,
@@ -399,7 +399,7 @@ function _createCCW(
399399
# MeasType = Vector{Float64} # FIXME use `usrfnc` to get this information instead
400400
_cf = CalcFactor(
401401
usrfnc,
402-
0,
402+
1,
403403
_varValsAll,
404404
false,
405405
userCache,

src/services/FactorGraph.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ function setValKDE!(
206206
) where {P}
207207
# recover variableType information
208208
setValKDE!(getSolverData(v, solveKey), val, setinit, ipc)
209+
# TODO setPPE!
209210
return nothing
210211
end
211212
function setValKDE!(
@@ -246,7 +247,7 @@ end
246247

247248
function setValKDE!(
248249
vnd::VariableNodeData,
249-
mkd::ManifoldKernelDensity{M, B, Nothing},
250+
mkd::ManifoldKernelDensity{M, B, Nothing}, # TBD dispatch without partial?
250251
setinit::Bool = true,
251252
ipc::AbstractVector{<:Real} = [0.0;],
252253
) where {M, B}

src/services/GraphProductOperations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function propagateBelief(
2929

3030
# get proposal beliefs
3131
destlbl = getLabel(destvar)
32-
ipc = proposalbeliefs!(dfg, destlbl, factors, dens; solveKey = solveKey, N = N, dbg = dbg)
32+
ipc = proposalbeliefs!(dfg, destlbl, factors, dens; solveKey, N, dbg)
3333

3434
# @show dens[1].manifold
3535

@@ -52,7 +52,7 @@ function propagateBelief(
5252
M;
5353
Niter = 1,
5454
oldPoints = oldpts,
55-
N = N,
55+
N,
5656
u0 = getPointDefault(varType),
5757
)
5858

src/services/SolveTree.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function doFMCIteration(
7070
logger,
7171
)
7272

73-
if 0 < length(getPoints(dens))
73+
if 0 < Npts(dens)
7474
setBelief!(vert, dens, true, ipc)
7575
# setValKDE!(vert, densPts, true, ipc)
7676
# TODO perhaps more debugging inside `propagateBelief`?

test/testBasicGraphs.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,15 +306,57 @@ pts_ = getPoints(getBelief(fg, :x4))
306306
TensorCast.@cast pts[i,j] := pts_[j][i]
307307
@test 0.2 < Statistics.cov( pts[1,:] ) < 3.2
308308

309+
310+
309311
@testset "Test localProduct on solveKey" begin
310312

311313
localProduct(fg,:x2)
312-
313314
localProduct(fg,:x2, solveKey=:graphinit)
314315

315316
end
316317

318+
end
319+
320+
321+
##
322+
@testset "consistency check on more factors (origin is a DERelative fail case)" begin
323+
##
324+
325+
fg = initfg()
326+
327+
addVariable!(fg, :x0, Position{1})
328+
addFactor!(fg, [:x0], Prior(Normal(1.0, 0.01)))
317329

330+
# force a basic setup
331+
initAll!(fg)
332+
@test isapprox( 1, getPPE(fg, :x0).suggested[1]; atol=0.1)
333+
334+
##
335+
336+
addVariable!(fg, :x1, Position{1})
337+
addFactor!(fg, [:x0;:x1], LinearRelative(Normal(1.0, 0.01)))
338+
339+
addVariable!(fg, :x2, Position{1})
340+
addFactor!(fg, [:x1;:x2], LinearRelative(Normal(1.0, 0.01)))
341+
342+
addVariable!(fg, :x3, Position{1})
343+
addFactor!(fg, [:x2;:x3], LinearRelative(Normal(1.0, 0.01)))
344+
345+
##
346+
347+
tree = solveGraph!(fg)
348+
349+
##
350+
351+
@test isapprox( 1, getPPE(fg, :x0).suggested[1]; atol=0.1)
352+
@test isapprox( 4, getPPE(fg, :x3).suggested[1]; atol=0.3)
353+
354+
## check contents of tree messages
355+
356+
tree[1]
357+
msg1 = IIF.getMessageBuffer(tree[1])
358+
359+
##
318360
end
319361

320362

0 commit comments

Comments
 (0)