@@ -23,16 +23,21 @@ function Trace(f, rng::TracedRNG)
23
23
end
24
24
25
25
function Trace (f, ctask:: Libtask.CTask )
26
- rng = TracedRNG ()
27
- return Trace (f, ctask, rng)
26
+ return Trace (f, ctask, TracedRNG ())
28
27
end
29
28
30
- # Copy task and RNG
31
- Base. copy (trace:: Trace ) = Trace (trace. f, copy (trace. ctask))
29
+ # Copy task
30
+ Base. copy (trace:: Trace ) = Trace (trace. f, copy (trace. ctask), deepcopy (trace . rng) )
32
31
33
32
# step to the next observe statement and
34
33
# return the log probability of the transition (or nothing if done)
35
- advance! (t:: Trace ) = Libtask. consume (t. ctask)
34
+ function advance! (t:: Trace , isref:: Bool )
35
+ isref ? reset_rng! (t. rng) : save_state! (t. rng)
36
+ inc_count! (t. rng)
37
+
38
+ # Move to next step
39
+ return Libtask. consume (t. ctask)
40
+ end
36
41
37
42
# reset log probability
38
43
reset_logprob! (t:: Trace ) = nothing
55
60
# Create new task and copy randomness
56
61
function forkr (trace:: Trace )
57
62
newf = reset_model (trace. f)
63
+ set_count! (trace. rng, 1 )
58
64
59
65
ctask = let f = trace. ctask. task. code
60
66
Libtask. CTask () do
@@ -89,23 +95,15 @@ Data structure for particle filters
89
95
- normalise!(pc::ParticleContainer)
90
96
- consume(pc::ParticleContainer): return incremental likelihood
91
97
"""
92
- mutable struct ParticleContainer{T<: Particle ,R <: Random.AbstractRNG }
98
+ mutable struct ParticleContainer{T<: Particle }
93
99
" Particles."
94
100
vals:: Vector{T}
95
101
" Unnormalized logarithmic weights."
96
102
logWs:: Vector{Float64}
97
- " TracedRNG to track the resampling step"
98
- rng:: TracedRNG{R}
99
103
end
100
104
101
105
function ParticleContainer (particles:: Vector{<:Particle} )
102
- return ParticleContainer (particles, zeros (length (particles)), TracedRNG ())
103
- end
104
-
105
- function ParticleContainer (
106
- particles:: Vector{<:Particle} , rng:: T
107
- ) where {T<: Random.AbstractRNG }
108
- return ParticleContainer (particles, zeros (length (particles)), TracedRNG (rng))
106
+ return ParticleContainer (particles, zeros (length (particles)))
109
107
end
110
108
111
109
Base. collect (pc:: ParticleContainer ) = pc. vals
@@ -132,7 +130,7 @@ function Base.copy(pc::ParticleContainer)
132
130
# copy weights
133
131
logWs = copy (pc. logWs)
134
132
135
- return ParticleContainer (vals, logWs, pc . rng )
133
+ return ParticleContainer (vals, logWs)
136
134
end
137
135
138
136
"""
@@ -231,9 +229,12 @@ function resample_propagate!(
231
229
p = isref ? fork (pi , isref) : pi
232
230
children[j += 1 ] = p
233
231
232
+ seeds = split (pi . rng, ni)
234
233
# fork additional children
235
- for _ in 2 : ni
236
- children[j += 1 ] = fork (p, isref)
234
+ for k in 2 : ni
235
+ part = fork (p, isref)
236
+ seed! (part. rng, seeds[k])
237
+ children[j += 1 ] = part
237
238
end
238
239
end
239
240
end
274
275
Check if the final time step is reached, and otherwise reweight the particles by
275
276
considering the next observation.
276
277
"""
277
- function reweight! (pc:: ParticleContainer )
278
+ function reweight! (pc:: ParticleContainer , ref :: Union{Particle,Nothing} = nothing )
278
279
n = length (pc)
279
280
280
281
particles = collect (pc)
@@ -286,7 +287,8 @@ function reweight!(pc::ParticleContainer)
286
287
# the execution of the model is finished.
287
288
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
288
289
# ``θᵢ`` are variables of other samplers.
289
- score = advance! (p)
290
+ isref = p === ref
291
+ score = advance! (p, isref)
290
292
291
293
if score === nothing
292
294
numdone += 1
@@ -337,7 +339,6 @@ function sweep!(
337
339
ref:: Union{Particle,Nothing} = nothing ,
338
340
)
339
341
# Initial step:
340
-
341
342
# Resample and propagate particles.
342
343
resample_propagate! (rng, pc, resampler, ref)
343
344
@@ -349,7 +350,7 @@ function sweep!(
349
350
logZ0 = logZ (pc)
350
351
351
352
# Reweight the particles by including the first observation ``y₁``.
352
- isdone = reweight! (pc)
353
+ isdone = reweight! (pc, ref )
353
354
354
355
# Compute the normalizing constant ``Z₁`` after reweighting.
355
356
logZ1 = logZ (pc)
@@ -367,7 +368,7 @@ function sweep!(
367
368
logZ0 = logZ (pc)
368
369
369
370
# Reweight the particles by including the next observation ``yₜ``.
370
- isdone = reweight! (pc)
371
+ isdone = reweight! (pc, ref )
371
372
372
373
# Compute the normalizing constant ``Z₁`` after reweighting.
373
374
logZ1 = logZ (pc)
0 commit comments