Skip to content

Commit 44f5d1a

Browse files
Seed particles from top-level RNG (#41)
* Set seed from parent RNG * Decouple seed and sampler * No-allocation * Update Project.toml Co-authored-by: Hong Ge <[email protected]>
1 parent 44f9fd8 commit 44f5d1a

File tree

6 files changed

+96
-17
lines changed

6 files changed

+96
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1414
[compat]
1515
AbstractMCMC = "2, 3"
1616
Distributions = "0.23, 0.24, 0.25"
17-
Libtask = "0.6"
17+
Libtask = "0.6.7"
1818
Random123 = "1.3"
1919
StatsFuns = "0.9"
2020
julia = "1.3"

src/container.jl

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,15 @@ function ParticleContainer(particles::Vector{<:Particle})
9494
return ParticleContainer(particles, zeros(length(particles)), TracedRNG())
9595
end
9696

97-
function ParticleContainer(particles::Vector{<:Particle}, r::TracedRNG)
98-
return ParticleContainer(particles, zeros(length(particles)), r)
97+
function ParticleContainer(particles::Vector{<:Particle}, rng::TracedRNG)
98+
return ParticleContainer(particles, zeros(length(particles)), rng)
99+
end
100+
101+
function ParticleContainer(
102+
particles::Vector{<:Particle}, trng::TracedRNG, rng::Random.AbstractRNG
103+
)
104+
pc = ParticleContainer(particles, trng)
105+
return seed_from_rng!(pc, rng)
99106
end
100107

101108
Base.collect(pc::ParticleContainer) = pc.vals
@@ -190,12 +197,35 @@ function update_keys!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothin
190197
n = ref === nothing ? nparticles : nparticles - 1
191198
for i in 1:n
192199
pi = pc.vals[i]
193-
k = split(pi.rng.rng.key)
200+
k = split(state(pi.rng.rng))
194201
Random.seed!(pi.rng, k[1])
195202
end
196203
return nothing
197204
end
198205

206+
"""
207+
seed_from_rng!(pc::ParticleContainer, rng::Random.AbstractRNG, ref::Union{Particle,Nothing}=nothing)
208+
209+
Set seeds of particle rng from user-provided `rng`
210+
"""
211+
function seed_from_rng!(
212+
pc::ParticleContainer{T,<:TracedRNG{R,N,<:Random123.AbstractR123{I}}},
213+
rng::Random.AbstractRNG,
214+
ref::Union{Particle,Nothing}=nothing,
215+
) where {T,R,N,I}
216+
n = length(pc.vals)
217+
nseeds = isnothing(ref) ? n : n - 1
218+
219+
sampler = Random.Sampler(rng, I)
220+
for i in 1:nseeds
221+
subrng = pc.vals[i].rng
222+
Random.seed!(subrng, gen_seed(rng, subrng, sampler))
223+
end
224+
Random.seed!(pc.rng, gen_seed(rng, pc.rng, sampler))
225+
226+
return pc
227+
end
228+
199229
"""
200230
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
201231
ref = nothing; weights = getweights(pc)])

src/rng.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Default RNG type for when nothing is specified
2-
const _BASE_RNG = Random123.Philox2x
2+
const _BASE_RNG = Random123.Philox2x # Rng with state bigger than 1 are broken because of Split
33

44
"""
55
TracedRNG{R,N,T}
66
77
Wrapped random number generator from Random123 to keep track of random streams during model evaluation
88
"""
9-
mutable struct TracedRNG{R,N,T<:Random123.AbstractR123{R}} <: Random.AbstractRNG
9+
mutable struct TracedRNG{R,N,T<:Random123.AbstractR123} <: Random.AbstractRNG
1010
"Model step counter"
1111
count::Int
1212
"Inner RNG"
@@ -21,7 +21,7 @@ Create a `TracedRNG` with `r` as the inner RNG.
2121
"""
2222
function TracedRNG(r::Random123.AbstractR123=_BASE_RNG())
2323
Random123.set_counter!(r, 0)
24-
return TracedRNG(1, r, typeof(r.key)[])
24+
return TracedRNG(1, r, Random123.seed_type(r)[])
2525
end
2626

2727
# Connect to the Random API
@@ -50,7 +50,7 @@ function load_state!(rng::TracedRNG)
5050
end
5151

5252
"""
53-
update_rng!(rng::TracedRNG)
53+
Random.seed!(rng::TracedRNG, key)
5454
5555
Set key and counter of inner rng in `rng` to `key` and the running model step to 0
5656
"""
@@ -59,15 +59,33 @@ function Random.seed!(rng::TracedRNG, key)
5959
return Random123.set_counter!(rng.rng, 0)
6060
end
6161

62+
"""
63+
gen_seed(rng::Random.AbstractRNG, subrng::TracedRNG, sampler::Random.Sampler)
64+
65+
Generate a `seed` for the subrng based on top-level `rng` and `sampler`
66+
"""
67+
function gen_seed(rng::Random.AbstractRNG, ::TracedRNG{<:Integer}, sampler::Random.Sampler)
68+
return rand(rng, sampler)
69+
end
70+
71+
function gen_seed(
72+
rng::Random.AbstractRNG, ::TracedRNG{<:NTuple{N}}, sampler::Random.Sampler
73+
) where {N}
74+
return Tuple(rand(rng, sampler, N))
75+
end
76+
6277
"""
6378
save_state!(r::TracedRNG)
6479
6580
Add current key of the inner rng in `r` to `keys`.
6681
"""
6782
function save_state!(r::TracedRNG)
68-
return push!(r.keys, r.rng.key)
83+
return push!(r.keys, state(r.rng))
6984
end
7085

86+
state(rng::Random123.Philox2x) = rng.key
87+
state(rng::Random123.Philox4x) = (rng.key1, rng.key2)
88+
7189
Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), deepcopy(r.keys))
7290

7391
"""

src/smc.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function AbstractMCMC.sample(
3939

4040
# Create a set of particles.
4141
particles = ParticleContainer(
42-
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG()
42+
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG(), rng
4343
)
4444

4545
# Perform particle sweep.
@@ -86,7 +86,7 @@ function AbstractMCMC.step(
8686
)
8787
# Create a new set of particles.
8888
particles = ParticleContainer(
89-
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG()
89+
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG(), rng
9090
)
9191

9292
# Perform a particle sweep.
@@ -107,6 +107,7 @@ function AbstractMCMC.step(
107107
)
108108
# Create a new set of particles.
109109
nparticles = sampler.nparticles
110+
110111
x = map(1:nparticles) do i
111112
if i == nparticles
112113
# Create reference trajectory.
@@ -115,10 +116,12 @@ function AbstractMCMC.step(
115116
Trace(model, TracedRNG())
116117
end
117118
end
118-
particles = ParticleContainer(x, TracedRNG())
119+
120+
reference = x[end]
121+
particles = ParticleContainer(x, TracedRNG(), rng)
119122

120123
# Perform a particle sweep.
121-
logevidence = sweep!(rng, particles, sampler.resampler, particles.vals[nparticles])
124+
logevidence = sweep!(rng, particles, sampler.resampler, reference)
122125

123126
# Pick a particle to be retained.
124127
newtrajectory = rand(rng, particles)

test/container.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,32 @@
121121
@test consume(tr.ctask) == 2
122122
@test consume(a.ctask) == 4
123123
end
124+
125+
@testset "seed container" begin
126+
function dummy(rng) end
127+
128+
seed = 1
129+
n = 3
130+
rng = Random.MersenneTwister(seed)
131+
132+
particles = [AdvancedPS.Trace(dummy, AdvancedPS.TracedRNG()) for _ in 1:n]
133+
pc = AdvancedPS.ParticleContainer(particles, AdvancedPS.TracedRNG())
134+
135+
AdvancedPS.seed_from_rng!(pc, rng)
136+
old_seeds = vcat([part.rng.rng.key for part in pc.vals], [pc.rng.rng.key])
137+
138+
Random.seed!(rng, seed)
139+
AdvancedPS.seed_from_rng!(pc, rng)
140+
new_seeds = vcat([part.rng.rng.key for part in pc.vals], [pc.rng.rng.key])
141+
142+
# Check if we reset the seeds properly
143+
@test old_seeds new_seeds
144+
145+
Random.seed!(rng, 2)
146+
AdvancedPS.seed_from_rng!(pc, rng, pc.vals[n])
147+
ref_seeds = vcat([part.rng.rng.key for part in pc.vals], [pc.rng.rng.key])
148+
149+
# Dont reset reference particle
150+
@test ref_seeds[n] new_seeds[n]
151+
end
124152
end

test/smc.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
sampler = AdvancedPS.SMC(15, 0.6)
88
@test sampler.nparticles == 15
99
@test sampler.resampler ===
10-
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6)
10+
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6)
1111

1212
sampler = AdvancedPS.SMC(20, AdvancedPS.resample_multinomial, 0.6)
1313
@test sampler.nparticles == 20
1414
@test sampler.resampler ===
15-
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6)
15+
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6)
1616

1717
sampler = AdvancedPS.SMC(25, AdvancedPS.resample_systematic)
1818
@test sampler.nparticles == 25
@@ -105,12 +105,12 @@
105105
sampler = AdvancedPS.PG(60, 0.6)
106106
@test sampler.nparticles == 60
107107
@test sampler.resampler ===
108-
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6)
108+
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6)
109109

110110
sampler = AdvancedPS.PG(80, AdvancedPS.resample_multinomial, 0.6)
111111
@test sampler.nparticles == 80
112112
@test sampler.resampler ===
113-
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6)
113+
AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6)
114114

115115
sampler = AdvancedPS.PG(100, AdvancedPS.resample_systematic)
116116
@test sampler.nparticles == 100

0 commit comments

Comments
 (0)