Skip to content

Commit 0893eac

Browse files
Correct replaying mechanism
1 parent 5b01908 commit 0893eac

File tree

5 files changed

+96
-30
lines changed

5 files changed

+96
-30
lines changed

src/container.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng))
3232
# step to the next observe statement and
3333
# return the log probability of the transition (or nothing if done)
3434
function advance!(t::Trace, isref::Bool)
35-
isref ? reset_rng!(t.rng) : save_state!(t.rng)
35+
isref ? load_state(t.rng) : save_state!(t.rng)
3636
inc_count!(t.rng)
3737

3838
# Move to next step
@@ -100,10 +100,16 @@ mutable struct ParticleContainer{T<:Particle}
100100
vals::Vector{T}
101101
"Unnormalized logarithmic weights."
102102
logWs::Vector{Float64}
103+
"Traced RNG"
104+
rng::TracedRNG
103105
end
104106

105107
function ParticleContainer(particles::Vector{<:Particle})
106-
return ParticleContainer(particles, zeros(length(particles)))
108+
return ParticleContainer(particles, zeros(length(particles)), TracedRNG())
109+
end
110+
111+
function ParticleContainer(particles::Vector{<:Particle}, r::TracedRNG)
112+
return ParticleContainer(particles, zeros(length(particles)), r)
107113
end
108114

109115
Base.collect(pc::ParticleContainer) = pc.vals
@@ -130,7 +136,10 @@ function Base.copy(pc::ParticleContainer)
130136
# copy weights
131137
logWs = copy(pc.logWs)
132138

133-
return ParticleContainer(vals, logWs)
139+
# Copy rng and states
140+
rng = copy(pc.rng)
141+
142+
return ParticleContainer(vals, logWs, rng)
134143
end
135144

136145
"""
@@ -184,6 +193,21 @@ function effectiveSampleSize(pc::ParticleContainer)
184193
return inv(sum(abs2, Ws))
185194
end
186195

196+
"""
197+
update_keys!(pc::ParticleContainer)
198+
199+
Create new unique keys for the particles in the ParticleContainer
200+
"""
201+
function update_keys!(pc::ParticleContainer)
202+
# Update keys to new particle ids
203+
for i in 1:length(pc)
204+
pi = pc.vals[i]
205+
k = split(pi.rng, 1)
206+
seed!(pi.rng, k[1])
207+
set_counter!(pi.rng.rng, pi.rng.count)
208+
end
209+
end
210+
187211
"""
188212
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
189213
ref = nothing; weights = getweights(pc)])
@@ -227,9 +251,11 @@ function resample_propagate!(
227251
pi = particles[i]
228252
isref = pi === ref
229253
p = isref ? fork(pi, isref) : pi
230-
children[j += 1] = p
231254

232-
seeds = split(pi.rng, ni)
255+
seeds = split(p.rng, ni)
256+
seed!(p.rng, seeds[1])
257+
258+
children[j += 1] = p
233259
# fork additional children
234260
for k in 2:ni
235261
part = fork(p, isref)
@@ -264,6 +290,8 @@ function resample_propagate!(
264290

265291
if ess resampler.threshold * length(pc)
266292
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights)
293+
else
294+
update_keys!(pc)
267295
end
268296

269297
return pc

src/rng.jl

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ using Distributions
55
import Base.rand
66
import Random.seed!
77

8-
# Use Philox2x for now
9-
BASE_RNG = Philox2x
8+
# Default RNG type for when nothing is specified
9+
_BASE_RNG = Philox2x
1010

1111
"""
1212
TracedRNG{R,T}
@@ -15,14 +15,22 @@ Wrapped random number generator from Random123 to keep track of random streams d
1515
"""
1616
mutable struct TracedRNG{T} <:
1717
Random.AbstractRNG where {T<:(Random123.AbstractR123{R} where {R})}
18+
"Model step counter"
1819
count::Int
20+
"Inner RNG"
1921
rng::T
22+
"Array of keys"
2023
keys
21-
counters
2224
end
2325

26+
"""
27+
TracedRNG(r::Random123.AbstractR123)
28+
29+
Initialize TracedRNG with r as the inner RNG
30+
"""
2431
function TracedRNG(r::Random123.AbstractR123)
25-
return TracedRNG(1, r, typeof(r.key)[], typeof(r.ctr1)[])
32+
set_counter!(r, 1)
33+
return TracedRNG(1, r, typeof(r.key)[])
2634
end
2735

2836
"""
@@ -31,48 +39,50 @@ end
3139
Create a default TracedRNG
3240
"""
3341
function TracedRNG()
34-
r = BASE_RNG()
42+
r = _BASE_RNG()
3543
return TracedRNG(r)
3644
end
3745

38-
# Plug into Random
46+
# Connect to the Random API
3947
Random.rng_native_52(rng::TracedRNG{U}) where {U} = Random.rng_native_52(rng.rng)
4048
Base.rand(rng::TracedRNG{U}, ::Type{T}) where {U,T} = Base.rand(rng.rng, T)
4149

4250
"""
4351
split(r::TracedRNG, n::Integer)
4452
45-
Split keys of the internal Philox2x into n distinct seeds
53+
Split inner RNG into n new TracedRNG
4654
"""
4755
function split(r::TracedRNG{T}, n::Integer) where {T}
48-
n == 1 && return [r.rng.key]
4956
return map(i -> hash(r.rng.key, convert(UInt, r.rng.ctr1 + i)), 1:n)
5057
end
5158

5259
"""
53-
update_rng!(r::TracedRNG, seed::Number)
60+
seed!(r::TracedRNG, seed::Number)
5461
55-
Set the key of the wrapped Philox2x rng
62+
Set the key of the inner RNG as `seed`
5663
"""
5764
function seed!(r::TracedRNG{T}, seed) where {T}
5865
return seed!(r.rng, seed)
5966
end
6067

6168
"""
62-
reset_rng(r::TracedRNG, seed)
69+
load_state(r::TracedRNG, seed)
6370
64-
Reset the rng to the running model step
71+
Load state from current model iteration. Random streams are now replayed
6572
"""
66-
function reset_rng!(rng::TracedRNG{T}) where {T}
73+
function load_state(rng::TracedRNG{T}) where {T}
6774
key = rng.keys[rng.count]
68-
ctr = rng.counters[rng.count]
6975
Random.seed!(rng.rng, key)
70-
return set_counter!(rng.rng, ctr)
76+
return set_counter!(rng.rng, rng.count)
7177
end
7278

79+
"""
80+
save_state!(r::TracedRNG)
81+
82+
Track current key of the inner RNG
83+
"""
7384
function save_state!(r::TracedRNG{T}) where {T}
74-
push!(r.keys, r.rng.key)
75-
return push!(r.counters, r.rng.ctr1)
85+
return push!(r.keys, r.rng.key)
7686
end
7787

7888
Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys))

src/smc.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ function AbstractMCMC.sample(
3838
end
3939

4040
# Create a set of particles.
41-
particles = ParticleContainer([
42-
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
43-
])
41+
particles = ParticleContainer(
42+
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG()
43+
)
4444

4545
# Perform particle sweep.
4646
logevidence = sweep!(rng, particles, sampler.resampler)
@@ -85,9 +85,9 @@ function AbstractMCMC.step(
8585
rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::PG; kwargs...
8686
)
8787
# Create a new set of particles.
88-
particles = ParticleContainer([
89-
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
90-
])
88+
particles = ParticleContainer(
89+
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG()
90+
)
9191

9292
# Perform a particle sweep.
9393
logevidence = sweep!(rng, particles, sampler.resampler)
@@ -115,7 +115,7 @@ function AbstractMCMC.step(
115115
Trace(model, TracedRNG())
116116
end
117117
end
118-
particles = ParticleContainer(x)
118+
particles = ParticleContainer(x, TracedRNG())
119119

120120
# Perform a particle sweep.
121121
logevidence = sweep!(rng, particles, sampler.resampler, particles.vals[nparticles])

test/rng.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
rand(rng, Distributions.Normal())
88

9-
AdvancedPS.reset_rng!(rng)
9+
AdvancedPS.load_state(rng)
1010
new_vns = rand(rng, Distributions.Normal())
1111
@test new_vns vns
1212
end

test/smc.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,34 @@
150150
@test all(isone(p.trajectory.f.x) for p in chains_pg)
151151
@test mean(x.logevidence for x in chains_pg) -2 * log(2) atol = 0.01
152152
end
153+
154+
@testset "Replay reference" begin
155+
mutable struct Model <: AbstractMCMC.AbstractModel
156+
a::Float64
157+
b::Float64
158+
159+
Model() = new()
160+
end
161+
162+
function (m::Model)(rng)
163+
m.a = rand(rng, Normal())
164+
AdvancedPS.observe(Normal(), m.a)
165+
166+
m.b = rand(rng, Normal())
167+
AdvancedPS.observe(Normal(), m.b)
168+
end
169+
170+
pg = AdvancedPS.PG(1)
171+
first, second = sample(Model(), pg, 2);
172+
173+
first_model = first.trajectory.f
174+
second_model = second.trajectory.f
175+
176+
# Single Particle - must be replaying
177+
@test first_model.a second_model.a
178+
@test first_model.b second_model.b
179+
@test first.logevidence second.logevidence
180+
end
153181
end
154182

155183
# @testset "pmmh.jl" begin

0 commit comments

Comments
 (0)