Skip to content

Commit 8134751

Browse files
Merge pull request #2 from FredericWantiez/feature/split
Reset rng
2 parents 22fa456 + 5b01908 commit 8134751

File tree

6 files changed

+115
-78
lines changed

6 files changed

+115
-78
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
1112
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1213

1314
[compat]
1415
AbstractMCMC = "2, 3"
1516
Distributions = "0.23, 0.24, 0.25"
1617
Libtask = "0.5.3"
18+
Random123 = "1.3"
1719
StatsFuns = "0.9"
1820
julia = "1.3"

src/container.jl

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,21 @@ function Trace(f, rng::TracedRNG)
2323
end
2424

2525
function Trace(f, ctask::Libtask.CTask)
26-
rng = TracedRNG()
27-
return Trace(f, ctask, rng)
26+
return Trace(f, ctask, TracedRNG())
2827
end
2928

30-
# Copy task and RNG
31-
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))
29+
# Copy task
30+
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng))
3231

3332
# step to the next observe statement and
3433
# return the log probability of the transition (or nothing if done)
35-
advance!(t::Trace) = Libtask.consume(t.ctask)
34+
function advance!(t::Trace, isref::Bool)
35+
isref ? reset_rng!(t.rng) : save_state!(t.rng)
36+
inc_count!(t.rng)
37+
38+
# Move to next step
39+
return Libtask.consume(t.ctask)
40+
end
3641

3742
# reset log probability
3843
reset_logprob!(t::Trace) = nothing
@@ -55,6 +60,7 @@ end
5560
# Create new task and copy randomness
5661
function forkr(trace::Trace)
5762
newf = reset_model(trace.f)
63+
set_count!(trace.rng, 1)
5864

5965
ctask = let f = trace.ctask.task.code
6066
Libtask.CTask() do
@@ -89,23 +95,15 @@ Data structure for particle filters
8995
- normalise!(pc::ParticleContainer)
9096
- consume(pc::ParticleContainer): return incremental likelihood
9197
"""
92-
mutable struct ParticleContainer{T<:Particle,R<:Random.AbstractRNG}
98+
mutable struct ParticleContainer{T<:Particle}
9399
"Particles."
94100
vals::Vector{T}
95101
"Unnormalized logarithmic weights."
96102
logWs::Vector{Float64}
97-
"TracedRNG to track the resampling step"
98-
rng::TracedRNG{R}
99103
end
100104

101105
function ParticleContainer(particles::Vector{<:Particle})
102-
return ParticleContainer(particles, zeros(length(particles)), TracedRNG())
103-
end
104-
105-
function ParticleContainer(
106-
particles::Vector{<:Particle}, rng::T
107-
) where {T<:Random.AbstractRNG}
108-
return ParticleContainer(particles, zeros(length(particles)), TracedRNG(rng))
106+
return ParticleContainer(particles, zeros(length(particles)))
109107
end
110108

111109
Base.collect(pc::ParticleContainer) = pc.vals
@@ -132,7 +130,7 @@ function Base.copy(pc::ParticleContainer)
132130
# copy weights
133131
logWs = copy(pc.logWs)
134132

135-
return ParticleContainer(vals, logWs, pc.rng)
133+
return ParticleContainer(vals, logWs)
136134
end
137135

138136
"""
@@ -231,9 +229,12 @@ function resample_propagate!(
231229
p = isref ? fork(pi, isref) : pi
232230
children[j += 1] = p
233231

232+
seeds = split(pi.rng, ni)
234233
# fork additional children
235-
for _ in 2:ni
236-
children[j += 1] = fork(p, isref)
234+
for k in 2:ni
235+
part = fork(p, isref)
236+
seed!(part.rng, seeds[k])
237+
children[j += 1] = part
237238
end
238239
end
239240
end
@@ -274,7 +275,7 @@ end
274275
Check if the final time step is reached, and otherwise reweight the particles by
275276
considering the next observation.
276277
"""
277-
function reweight!(pc::ParticleContainer)
278+
function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing)
278279
n = length(pc)
279280

280281
particles = collect(pc)
@@ -286,7 +287,8 @@ function reweight!(pc::ParticleContainer)
286287
# the execution of the model is finished.
287288
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
288289
# ``θᵢ`` are variables of other samplers.
289-
score = advance!(p)
290+
isref = p === ref
291+
score = advance!(p, isref)
290292

291293
if score === nothing
292294
numdone += 1
@@ -337,7 +339,6 @@ function sweep!(
337339
ref::Union{Particle,Nothing}=nothing,
338340
)
339341
# Initial step:
340-
341342
# Resample and propagate particles.
342343
resample_propagate!(rng, pc, resampler, ref)
343344

@@ -349,7 +350,7 @@ function sweep!(
349350
logZ0 = logZ(pc)
350351

351352
# Reweight the particles by including the first observation ``y₁``.
352-
isdone = reweight!(pc)
353+
isdone = reweight!(pc, ref)
353354

354355
# Compute the normalizing constant ``Z₁`` after reweighting.
355356
logZ1 = logZ(pc)
@@ -367,7 +368,7 @@ function sweep!(
367368
logZ0 = logZ(pc)
368369

369370
# Reweight the particles by including the next observation ``yₜ``.
370-
isdone = reweight!(pc)
371+
isdone = reweight!(pc, ref)
371372

372373
# Compute the normalizing constant ``Z₁`` after reweighting.
373374
logZ1 = logZ(pc)

src/rng.jl

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,88 @@
1+
using Random123
2+
using Random
3+
using Distributions
4+
5+
import Base.rand
6+
import Random.seed!
7+
8+
# Use Philox2x for now
9+
BASE_RNG = Philox2x
10+
111
"""
2-
Data structure to keep track of the history of the random stream
3-
produced by RNG.
12+
TracedRNG{R,T}
13+
14+
Wrapped random number generator from Random123 to keep track of random streams during model evaluation
415
"""
5-
mutable struct TracedRNG{T} <: Random.AbstractRNG where {T<:Random.AbstractRNG}
6-
count::Base.RefValue{Int}
16+
mutable struct TracedRNG{T} <:
17+
Random.AbstractRNG where {T<:(Random123.AbstractR123{R} where {R})}
18+
count::Int
719
rng::T
8-
seed::Array
9-
states::Array{T}
20+
keys
21+
counters
1022
end
1123

12-
# Set seed manually, for init ?
13-
function Random.seed!(rng::TracedRNG, seed)
14-
rng.rng.seed = seed
15-
return Random.seed!(rng.rng, seed)
24+
function TracedRNG(r::Random123.AbstractR123)
25+
return TracedRNG(1, r, typeof(r.key)[], typeof(r.ctr1)[])
1626
end
1727

18-
# Reset the rng to the initial seed
19-
Random.seed!(rng::TracedRNG) = Random.seed!(rng.rng, rng.seed)
28+
"""
29+
TracedRNG()
2030
21-
TracedRNG() = TracedRNG(Random.MersenneTwister()) # Pick up an explicit RNG from Random
22-
TracedRNG(rng::Random.AbstractRNG) = TracedRNG(Ref(0), rng, rng.seed, [rng])
23-
TracedRNG(rng::Random._GLOBAL_RNG) = TracedRNG(Random.default_rng())
31+
Create a default TracedRNG
32+
"""
33+
function TracedRNG()
34+
r = BASE_RNG()
35+
return TracedRNG(r)
36+
end
2437

25-
# Intercept rand
26-
# https://github.com/JuliaLang/julia/issues/30732
27-
Random.rng_native_52(r::TracedRNG) = UInt64
38+
# Plug into Random
39+
Random.rng_native_52(rng::TracedRNG{U}) where {U} = Random.rng_native_52(rng.rng)
40+
Base.rand(rng::TracedRNG{U}, ::Type{T}) where {U,T} = Base.rand(rng.rng, T)
2841

29-
function Base.rand(rng::TracedRNG, ::Type{T}) where {T}
30-
res = Base.rand(rng.rng, T)
31-
inc_count!(rng, length(res))
32-
push!(rng.states, copy(rng.rng))
33-
return res
42+
"""
43+
split(r::TracedRNG, n::Integer)
44+
45+
Split keys of the internal Philox2x into n distinct seeds
46+
"""
47+
function split(r::TracedRNG{T}, n::Integer) where {T}
48+
n == 1 && return [r.rng.key]
49+
return map(i -> hash(r.rng.key, convert(UInt, r.rng.ctr1 + i)), 1:n)
3450
end
3551

36-
inc_count!(rng::TracedRNG) = inc_count!(rng, 1)
52+
"""
53+
update_rng!(r::TracedRNG, seed::Number)
3754
38-
inc_count!(rng::TracedRNG, n::Int) = rng.count[] += n
55+
Set the key of the wrapped Philox2x rng
56+
"""
57+
function seed!(r::TracedRNG{T}, seed) where {T}
58+
return seed!(r.rng, seed)
59+
end
60+
61+
"""
62+
reset_rng(r::TracedRNG, seed)
63+
64+
Reset the rng to the running model step
65+
"""
66+
function reset_rng!(rng::TracedRNG{T}) where {T}
67+
key = rng.keys[rng.count]
68+
ctr = rng.counters[rng.count]
69+
Random.seed!(rng.rng, key)
70+
return set_counter!(rng.rng, ctr)
71+
end
72+
73+
function save_state!(r::TracedRNG{T}) where {T}
74+
push!(r.keys, r.rng.key)
75+
return push!(r.counters, r.rng.ctr1)
76+
end
77+
78+
Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys))
79+
80+
"""
81+
set_count!(r::TracedRNG, n::Integer)
82+
83+
Set the counter of the TracedRNG, used to keep track of the current model step
84+
"""
85+
set_count!(r::TracedRNG, n::Integer) = r.count = n
3986

40-
curr_count(t::TracedRNG) = t.count[]
87+
inc_count!(r::TracedRNG, n::Integer) = r.count += n
88+
inc_count!(r::TracedRNG) = inc_count!(r, 1)

src/smc.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ function AbstractMCMC.sample(
3838
end
3939

4040
# Create a set of particles.
41-
particles = ParticleContainer(
42-
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], rng
43-
)
41+
particles = ParticleContainer([
42+
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
43+
])
4444

4545
# Perform particle sweep.
46-
logevidence = sweep!(particles.rng, particles, sampler.resampler)
46+
logevidence = sweep!(rng, particles, sampler.resampler)
4747

4848
return SMCSample(collect(particles), getweights(particles), logevidence)
4949
end
@@ -85,12 +85,12 @@ function AbstractMCMC.step(
8585
rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::PG; kwargs...
8686
)
8787
# Create a new set of particles.
88-
particles = ParticleContainer(
89-
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], rng
90-
)
88+
particles = ParticleContainer([
89+
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
90+
])
9191

9292
# Perform a particle sweep.
93-
logevidence = sweep!(particles.rng, particles, sampler.resampler)
93+
logevidence = sweep!(rng, particles, sampler.resampler)
9494

9595
# Pick a particle to be retained.
9696
trajectory = rand(rng, particles)
@@ -115,12 +115,10 @@ function AbstractMCMC.step(
115115
Trace(model, TracedRNG())
116116
end
117117
end
118-
particles = ParticleContainer(x, rng)
118+
particles = ParticleContainer(x)
119119

120120
# Perform a particle sweep.
121-
logevidence = sweep!(
122-
particles.rng, particles, sampler.resampler, particles.vals[nparticles]
123-
)
121+
logevidence = sweep!(rng, particles, sampler.resampler, particles.vals[nparticles])
124122

125123
# Pick a particle to be retained.
126124
newtrajectory = rand(rng, particles)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
44
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
55
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
78

89
[compat]
910
AbstractMCMC = "2, 3"
1011
Distributions = "0.24, 0.25"
1112
Libtask = "0.5"
1213
julia = "1.3"
14+
Random123 = "1.3"

test/rng.jl

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,12 @@
22
@testset "sample distribution" begin
33
rng = AdvancedPS.TracedRNG()
44
vns = rand(rng, Distributions.Normal())
5-
6-
@test AdvancedPS.curr_count(rng) === 1
5+
AdvancedPS.save_state!(rng)
76

87
rand(rng, Distributions.Normal())
9-
Random.seed!(rng)
8+
9+
AdvancedPS.reset_rng!(rng)
1010
new_vns = rand(rng, Distributions.Normal())
1111
@test new_vns vns
1212
end
13-
14-
@testset "inc count" begin
15-
rng = AdvancedPS.TracedRNG()
16-
AdvancedPS.inc_count!(rng)
17-
@test AdvancedPS.curr_count(rng) == 1
18-
19-
AdvancedPS.inc_count!(rng, 2)
20-
@test AdvancedPS.curr_count(rng) == 3
21-
end
22-
23-
@testset "curr count" begin
24-
rng = AdvancedPS.TracedRNG()
25-
@test AdvancedPS.curr_count(rng) == 0
26-
end
2713
end

0 commit comments

Comments
 (0)