@@ -36,10 +36,10 @@ function unset_all_del!(vi::AbstractVarInfo)
36
36
return nothing
37
37
end
38
38
39
- # TODO (penelopeysm / DPPL 0.38): Figure this out
40
39
struct ParticleMCMCContext{R<: AbstractRNG } <: DynamicPPL.AbstractContext
41
40
rng:: R
42
41
end
42
+ DynamicPPL. NodeTrait (:: ParticleMCMCContext ) = DynamicPPL. IsLeaf ()
43
43
44
44
struct TracedModel{V<: AbstractVarInfo ,M<: Model ,E<: Tuple } <: AdvancedPS.AbstractGenericModel
45
45
model:: M
@@ -75,8 +75,7 @@ function AdvancedPS.delete_retained!(trace::TracedModel)
75
75
# In such a case, we need to ensure that when we continue sampling (i.e.
76
76
# the next time we hit tilde_assume!!), we don't use the values in the
77
77
# reference particle but rather sample new values.
78
- trace = Accessors. @set trace. resample = true
79
- return trace
78
+ return TracedModel (trace. model, trace. varinfo, trace. evaluator, true )
80
79
end
81
80
82
81
function AdvancedPS. reset_model (trace:: TracedModel )
@@ -309,8 +308,6 @@ function DynamicPPL.initialstep(
309
308
kwargs... ,
310
309
)
311
310
vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
312
- # Reset the VarInfo before new sweep
313
- set_all_del! (vi)
314
311
315
312
# Create a new set of particles
316
313
num_particles = spl. alg. nparticles
@@ -348,9 +345,6 @@ function AbstractMCMC.step(
348
345
# Create reference particle for which the samples will be retained.
349
346
reference = AdvancedPS. forkr (AdvancedPS. Trace (model, vi, state. rng, false ))
350
347
351
- # For all other particles, do not retain the variables but resample them.
352
- set_all_del! (vi)
353
-
354
348
# Create a new set of particles.
355
349
num_particles = spl. alg. nparticles
356
350
x = map (1 : num_particles) do i
@@ -410,7 +404,7 @@ function get_trace_local_resampled_maybe(fallback_resampled::Bool)
410
404
catch e
411
405
e == KeyError (:task_variable ) ? nothing : rethrow (e)
412
406
end
413
- return (trace === nothing ? fallback_resampled : trace. resample):: Bool
407
+ return (trace === nothing ? fallback_resampled : trace. model . f . resample):: Bool
414
408
end
415
409
416
410
"""
@@ -479,7 +473,13 @@ function DynamicPPL.tilde_assume!!(
479
473
return x, vi
480
474
end
481
475
482
- function DynamicPPL. tilde_observe!! (:: ParticleMCMCContext , right, left, vn, vi)
476
+ function DynamicPPL. tilde_observe!! (
477
+ :: ParticleMCMCContext ,
478
+ right:: Distribution ,
479
+ left,
480
+ vn:: Union{VarName,Nothing} ,
481
+ vi:: AbstractVarInfo ,
482
+ )
483
483
arg_vi_id = objectid (vi)
484
484
vi = get_trace_local_varinfo_maybe (vi)
485
485
using_local_vi = objectid (vi) == arg_vi_id
0 commit comments