1
- struct Trace{F}
1
+ struct Trace{F,U,N,V <: Random123.AbstractR123{U} }
2
2
f:: F
3
3
ctask:: Libtask.CTask
4
+ rng:: TracedRNG{U,N,V}
4
5
end
5
6
6
7
const Particle = Trace
7
8
8
- function Trace (f)
9
+ function Trace (f, rng :: TracedRNG )
9
10
ctask = let f = f
10
11
Libtask. CTask () do
11
- res = f ()
12
+ res = f (rng )
12
13
Libtask. produce (nothing )
13
14
return res
14
15
end
15
16
end
16
17
17
18
# add backward reference
18
- newtrace = Trace (f, ctask)
19
+ newtrace = Trace (f, ctask, rng )
19
20
addreference! (ctask. task, newtrace)
20
21
21
22
return newtrace
22
23
end
23
24
24
- Base. copy (trace:: Trace ) = Trace (trace. f, copy (trace. ctask))
25
+ function Trace (f, ctask:: Libtask.CTask )
26
+ return Trace (f, ctask, TracedRNG ())
27
+ end
28
+
29
+ # Copy task
30
+ Base. copy (trace:: Trace ) = Trace (trace. f, copy (trace. ctask), deepcopy (trace. rng))
25
31
26
32
# step to the next observe statement and
27
33
# return the log probability of the transition (or nothing if done)
28
- advance! (t:: Trace ) = Libtask. consume (t. ctask)
34
+ function advance! (t:: Trace , isref:: Bool )
35
+ isref ? load_state! (t. rng) : save_state! (t. rng)
36
+ inc_counter! (t. rng)
37
+
38
+ # Move to next step
39
+ return Libtask. consume (t. ctask)
40
+ end
29
41
30
42
# reset log probability
31
43
reset_logprob! (t:: Trace ) = nothing
48
60
# Create new task and copy randomness
49
61
function forkr (trace:: Trace )
50
62
newf = reset_model (trace. f)
63
+ Random123. set_counter! (trace. rng, 1 )
64
+
51
65
ctask = let f = trace. ctask. task. code
52
66
Libtask. CTask () do
53
- res = f ()
67
+ res = f ()(trace . rng)
54
68
Libtask. produce (nothing )
55
69
return res
56
70
end
57
71
end
58
72
59
73
# add backward reference
60
- newtrace = Trace (newf, ctask)
74
+ newtrace = Trace (newf, ctask, trace . rng )
61
75
addreference! (ctask. task, newtrace)
62
76
63
77
return newtrace
@@ -81,15 +95,21 @@ Data structure for particle filters
81
95
- normalise!(pc::ParticleContainer)
82
96
- consume(pc::ParticleContainer): return incremental likelihood
83
97
"""
84
- mutable struct ParticleContainer{T<: Particle }
98
+ mutable struct ParticleContainer{T<: Particle ,U,N,V <: Random123.AbstractR123{U} }
85
99
" Particles."
86
100
vals:: Vector{T}
87
101
" Unnormalized logarithmic weights."
88
102
logWs:: Vector{Float64}
103
+ " Traced RNG to replay the resampling step"
104
+ rng:: TracedRNG{U,N,V}
89
105
end
90
106
91
107
function ParticleContainer (particles:: Vector{<:Particle} )
92
- return ParticleContainer (particles, zeros (length (particles)))
108
+ return ParticleContainer (particles, zeros (length (particles)), TracedRNG ())
109
+ end
110
+
111
+ function ParticleContainer (particles:: Vector{<:Particle} , r:: TracedRNG )
112
+ return ParticleContainer (particles, zeros (length (particles)), r)
93
113
end
94
114
95
115
Base. collect (pc:: ParticleContainer ) = pc. vals
@@ -116,7 +136,10 @@ function Base.copy(pc::ParticleContainer)
116
136
# copy weights
117
137
logWs = copy (pc. logWs)
118
138
119
- return ParticleContainer (vals, logWs)
139
+ # Copy rng and states
140
+ rng = copy (pc. rng)
141
+
142
+ return ParticleContainer (vals, logWs, rng)
120
143
end
121
144
122
145
"""
@@ -170,6 +193,22 @@ function effectiveSampleSize(pc::ParticleContainer)
170
193
return inv (sum (abs2, Ws))
171
194
end
172
195
196
+ """
197
+ update_keys!(pc::ParticleContainer)
198
+
199
+ Create new unique keys for the particles in the ParticleContainer
200
+ """
201
+ function update_keys! (pc:: ParticleContainer , ref:: Union{Particle,Nothing} = nothing )
202
+ # Update keys to new particle ids
203
+ nparticles = length (pc)
204
+ n = ref === nothing ? nparticles : nparticles - 1
205
+ for i in 1 : n
206
+ pi = pc. vals[i]
207
+ k = split (pi . rng. rng. key)
208
+ Random. seed! (pi . rng, k[1 ])
209
+ end
210
+ end
211
+
173
212
"""
174
213
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
175
214
ref = nothing; weights = getweights(pc)])
@@ -213,11 +252,17 @@ function resample_propagate!(
213
252
pi = particles[i]
214
253
isref = pi === ref
215
254
p = isref ? fork (pi , isref) : pi
216
- children[j += 1 ] = p
255
+ nseeds = isref ? ni - 1 : ni
256
+
257
+ seeds = split (p. rng. rng. key, nseeds)
258
+ ! isref && Random. seed! (p. rng, seeds[1 ])
217
259
260
+ children[j += 1 ] = p
218
261
# fork additional children
219
- for _ in 2 : ni
220
- children[j += 1 ] = fork (p, isref)
262
+ for k in 2 : ni
263
+ part = fork (p, isref)
264
+ Random. seed! (part. rng, seeds[k])
265
+ children[j += 1 ] = part
221
266
end
222
267
end
223
268
end
@@ -247,6 +292,8 @@ function resample_propagate!(
247
292
248
293
if ess ≤ resampler. threshold * length (pc)
249
294
resample_propagate! (rng, pc, resampler. resampler, ref; weights= weights)
295
+ else
296
+ update_keys! (pc, ref)
250
297
end
251
298
252
299
return pc
258
305
Check if the final time step is reached, and otherwise reweight the particles by
259
306
considering the next observation.
260
307
"""
261
- function reweight! (pc:: ParticleContainer )
308
+ function reweight! (pc:: ParticleContainer , ref :: Union{Particle,Nothing} = nothing )
262
309
n = length (pc)
263
310
264
311
particles = collect (pc)
@@ -270,7 +317,8 @@ function reweight!(pc::ParticleContainer)
270
317
# the execution of the model is finished.
271
318
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
272
319
# ``θᵢ`` are variables of other samplers.
273
- score = advance! (p)
320
+ isref = p === ref
321
+ score = advance! (p, isref)
274
322
275
323
if score === nothing
276
324
numdone += 1
@@ -321,7 +369,6 @@ function sweep!(
321
369
ref:: Union{Particle,Nothing} = nothing ,
322
370
)
323
371
# Initial step:
324
-
325
372
# Resample and propagate particles.
326
373
resample_propagate! (rng, pc, resampler, ref)
327
374
@@ -333,7 +380,7 @@ function sweep!(
333
380
logZ0 = logZ (pc)
334
381
335
382
# Reweight the particles by including the first observation ``y₁``.
336
- isdone = reweight! (pc)
383
+ isdone = reweight! (pc, ref )
337
384
338
385
# Compute the normalizing constant ``Z₁`` after reweighting.
339
386
logZ1 = logZ (pc)
@@ -351,7 +398,7 @@ function sweep!(
351
398
logZ0 = logZ (pc)
352
399
353
400
# Reweight the particles by including the next observation ``yₜ``.
354
- isdone = reweight! (pc)
401
+ isdone = reweight! (pc, ref )
355
402
356
403
# Compute the normalizing constant ``Z₁`` after reweighting.
357
404
logZ1 = logZ (pc)
0 commit comments