Skip to content

Commit 51c18d4

Browse files
authored
Merge pull request #23 from FredericWantiez/feature/traced_rng
Fix forkr - Handle rng in Trace
2 parents 350e2c1 + 705d1dc commit 51c18d4

File tree

10 files changed

+245
-45
lines changed

10 files changed

+245
-45
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.2.4"
4+
version = "0.3.0"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
1112
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1213

1314
[compat]
1415
AbstractMCMC = "2, 3"
1516
Distributions = "0.23, 0.24, 0.25"
1617
Libtask = "0.5.3"
18+
Random123 = "1.3"
1719
StatsFuns = "0.9"
1820
julia = "1.3"

src/AdvancedPS.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ using Distributions: Distributions
55
using Libtask: Libtask
66
using Random: Random
77
using StatsFuns: StatsFuns
8+
using Random123: Random123
89

910
include("resampling.jl")
11+
include("rng.jl")
1012
include("container.jl")
1113
include("smc.jl")
1214
include("model.jl")

src/container.jl

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,43 @@
1-
struct Trace{F}
1+
struct Trace{F,U,N,V<:Random123.AbstractR123{U}}
22
f::F
33
ctask::Libtask.CTask
4+
rng::TracedRNG{U,N,V}
45
end
56

67
const Particle = Trace
78

8-
function Trace(f)
9+
function Trace(f, rng::TracedRNG)
910
ctask = let f = f
1011
Libtask.CTask() do
11-
res = f()
12+
res = f(rng)
1213
Libtask.produce(nothing)
1314
return res
1415
end
1516
end
1617

1718
# add backward reference
18-
newtrace = Trace(f, ctask)
19+
newtrace = Trace(f, ctask, rng)
1920
addreference!(ctask.task, newtrace)
2021

2122
return newtrace
2223
end
2324

24-
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))
25+
function Trace(f, ctask::Libtask.CTask)
26+
return Trace(f, ctask, TracedRNG())
27+
end
28+
29+
# Copy task
30+
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng))
2531

2632
# step to the next observe statement and
2733
# return the log probability of the transition (or nothing if done)
28-
advance!(t::Trace) = Libtask.consume(t.ctask)
34+
function advance!(t::Trace, isref::Bool)
35+
isref ? load_state!(t.rng) : save_state!(t.rng)
36+
inc_counter!(t.rng)
37+
38+
# Move to next step
39+
return Libtask.consume(t.ctask)
40+
end
2941

3042
# reset log probability
3143
reset_logprob!(t::Trace) = nothing
@@ -48,16 +60,18 @@ end
4860
# Create new task and copy randomness
4961
function forkr(trace::Trace)
5062
newf = reset_model(trace.f)
63+
Random123.set_counter!(trace.rng, 1)
64+
5165
ctask = let f = trace.ctask.task.code
5266
Libtask.CTask() do
53-
res = f()
67+
res = f()(trace.rng)
5468
Libtask.produce(nothing)
5569
return res
5670
end
5771
end
5872

5973
# add backward reference
60-
newtrace = Trace(newf, ctask)
74+
newtrace = Trace(newf, ctask, trace.rng)
6175
addreference!(ctask.task, newtrace)
6276

6377
return newtrace
@@ -81,15 +95,21 @@ Data structure for particle filters
8195
- normalise!(pc::ParticleContainer)
8296
- consume(pc::ParticleContainer): return incremental likelihood
8397
"""
84-
mutable struct ParticleContainer{T<:Particle}
98+
mutable struct ParticleContainer{T<:Particle,U,N,V<:Random123.AbstractR123{U}}
8599
"Particles."
86100
vals::Vector{T}
87101
"Unnormalized logarithmic weights."
88102
logWs::Vector{Float64}
103+
"Traced RNG to replay the resampling step"
104+
rng::TracedRNG{U,N,V}
89105
end
90106

91107
function ParticleContainer(particles::Vector{<:Particle})
92-
return ParticleContainer(particles, zeros(length(particles)))
108+
return ParticleContainer(particles, zeros(length(particles)), TracedRNG())
109+
end
110+
111+
function ParticleContainer(particles::Vector{<:Particle}, r::TracedRNG)
112+
return ParticleContainer(particles, zeros(length(particles)), r)
93113
end
94114

95115
Base.collect(pc::ParticleContainer) = pc.vals
@@ -116,7 +136,10 @@ function Base.copy(pc::ParticleContainer)
116136
# copy weights
117137
logWs = copy(pc.logWs)
118138

119-
return ParticleContainer(vals, logWs)
139+
# Copy rng and states
140+
rng = copy(pc.rng)
141+
142+
return ParticleContainer(vals, logWs, rng)
120143
end
121144

122145
"""
@@ -170,6 +193,22 @@ function effectiveSampleSize(pc::ParticleContainer)
170193
return inv(sum(abs2, Ws))
171194
end
172195

196+
"""
197+
update_keys!(pc::ParticleContainer)
198+
199+
Create new unique keys for the particles in the ParticleContainer
200+
"""
201+
function update_keys!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing)
202+
# Update keys to new particle ids
203+
nparticles = length(pc)
204+
n = ref === nothing ? nparticles : nparticles - 1
205+
for i in 1:n
206+
pi = pc.vals[i]
207+
k = split(pi.rng.rng.key)
208+
Random.seed!(pi.rng, k[1])
209+
end
210+
end
211+
173212
"""
174213
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
175214
ref = nothing; weights = getweights(pc)])
@@ -213,11 +252,17 @@ function resample_propagate!(
213252
pi = particles[i]
214253
isref = pi === ref
215254
p = isref ? fork(pi, isref) : pi
216-
children[j += 1] = p
255+
nseeds = isref ? ni - 1 : ni
256+
257+
seeds = split(p.rng.rng.key, nseeds)
258+
!isref && Random.seed!(p.rng, seeds[1])
217259

260+
children[j += 1] = p
218261
# fork additional children
219-
for _ in 2:ni
220-
children[j += 1] = fork(p, isref)
262+
for k in 2:ni
263+
part = fork(p, isref)
264+
Random.seed!(part.rng, seeds[k])
265+
children[j += 1] = part
221266
end
222267
end
223268
end
@@ -247,6 +292,8 @@ function resample_propagate!(
247292

248293
if ess resampler.threshold * length(pc)
249294
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights)
295+
else
296+
update_keys!(pc, ref)
250297
end
251298

252299
return pc
@@ -258,7 +305,7 @@ end
258305
Check if the final time step is reached, and otherwise reweight the particles by
259306
considering the next observation.
260307
"""
261-
function reweight!(pc::ParticleContainer)
308+
function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing)
262309
n = length(pc)
263310

264311
particles = collect(pc)
@@ -270,7 +317,8 @@ function reweight!(pc::ParticleContainer)
270317
# the execution of the model is finished.
271318
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
272319
# ``θᵢ`` are variables of other samplers.
273-
score = advance!(p)
320+
isref = p === ref
321+
score = advance!(p, isref)
274322

275323
if score === nothing
276324
numdone += 1
@@ -321,7 +369,6 @@ function sweep!(
321369
ref::Union{Particle,Nothing}=nothing,
322370
)
323371
# Initial step:
324-
325372
# Resample and propagate particles.
326373
resample_propagate!(rng, pc, resampler, ref)
327374

@@ -333,7 +380,7 @@ function sweep!(
333380
logZ0 = logZ(pc)
334381

335382
# Reweight the particles by including the first observation ``y₁``.
336-
isdone = reweight!(pc)
383+
isdone = reweight!(pc, ref)
337384

338385
# Compute the normalizing constant ``Z₁`` after reweighting.
339386
logZ1 = logZ(pc)
@@ -351,7 +398,7 @@ function sweep!(
351398
logZ0 = logZ(pc)
352399

353400
# Reweight the particles by including the next observation ``yₜ``.
354-
isdone = reweight!(pc)
401+
isdone = reweight!(pc, ref)
355402

356403
# Compute the normalizing constant ``Z₁`` after reweighting.
357404
logZ1 = logZ(pc)

src/rng.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Default RNG type for when nothing is specified
2+
const _BASE_RNG = Random123.Philox2x
3+
4+
"""
5+
TracedRNG{R,N,T}
6+
7+
Wrapped random number generator from Random123 to keep track of random streams during model evaluation
8+
"""
9+
mutable struct TracedRNG{R,N,T<:Random123.AbstractR123{R}} <: Random.AbstractRNG
10+
"Model step counter"
11+
count::Int
12+
"Inner RNG"
13+
rng::T
14+
"Array of keys"
15+
keys::Array{R,N}
16+
end
17+
18+
"""
19+
TracedRNG(r::Random123.AbstractR123=AdvancedPS._BASE_RNG())
20+
Create a `TracedRNG` with `r` as the inner RNG.
21+
"""
22+
function TracedRNG(r::Random123.AbstractR123=_BASE_RNG())
23+
Random123.set_counter!(r, 0)
24+
return TracedRNG(1, r, typeof(r.key)[])
25+
end
26+
27+
# Connect to the Random API
28+
Random.rng_native_52(rng::TracedRNG) = Random.rng_native_52(rng.rng)
29+
Base.rand(rng::TracedRNG, ::Type{T}) where {T} = Base.rand(rng.rng, T)
30+
31+
"""
32+
split(key::Integer, n::Integer=1)
33+
34+
Split `key` into `n` new keys
35+
"""
36+
function split(key::Integer, n::Integer=1)
37+
T = typeof(key) # Make sure the type of `key` is consistent on W32 and W64 systems.
38+
return T[hash(key, i) for i in UInt(1):UInt(n)]
39+
end
40+
41+
"""
42+
load_state!(r::TracedRNG)
43+
44+
Load state from current model iteration. Random streams are now replayed
45+
"""
46+
function load_state!(rng::TracedRNG)
47+
key = rng.keys[rng.count]
48+
Random.seed!(rng.rng, key)
49+
return Random123.set_counter!(rng.rng, 0)
50+
end
51+
52+
"""
53+
update_rng!(rng::TracedRNG)
54+
55+
Set key and counter of inner rng in `rng` to `key` and the running model step to 0
56+
"""
57+
function Random.seed!(rng::TracedRNG, key)
58+
Random.seed!(rng.rng, key)
59+
return Random123.set_counter!(rng.rng, 0)
60+
end
61+
62+
"""
63+
save_state!(r::TracedRNG)
64+
65+
Add current key of the inner rng in `r` to `keys`.
66+
"""
67+
function save_state!(r::TracedRNG)
68+
return push!(r.keys, r.rng.key)
69+
end
70+
71+
Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), deepcopy(r.keys))
72+
73+
"""
74+
set_counter!(r::TracedRNG, n::Integer)
75+
76+
Set the counter of the inner rng in `r`, used to keep track of the current model step
77+
"""
78+
Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n
79+
80+
"""
81+
inc_counter!(r::TracedRNG, n::Integer=1)
82+
83+
Increase the model step counter by `n`
84+
"""
85+
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n

src/smc.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ function AbstractMCMC.sample(
3838
end
3939

4040
# Create a set of particles.
41-
particles = ParticleContainer([Trace(model) for _ in 1:(sampler.nparticles)])
41+
particles = ParticleContainer(
42+
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG()
43+
)
4244

4345
# Perform particle sweep.
4446
logevidence = sweep!(rng, particles, sampler.resampler)
@@ -83,7 +85,9 @@ function AbstractMCMC.step(
8385
rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::PG; kwargs...
8486
)
8587
# Create a new set of particles.
86-
particles = ParticleContainer([Trace(model) for _ in 1:(sampler.nparticles)])
88+
particles = ParticleContainer(
89+
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG()
90+
)
8791

8892
# Perform a particle sweep.
8993
logevidence = sweep!(rng, particles, sampler.resampler)
@@ -108,10 +112,10 @@ function AbstractMCMC.step(
108112
# Create reference trajectory.
109113
forkr(state.trajectory)
110114
else
111-
Trace(model)
115+
Trace(model, TracedRNG())
112116
end
113117
end
114-
particles = ParticleContainer(x)
118+
particles = ParticleContainer(x, TracedRNG())
115119

116120
# Perform a particle sweep.
117121
logevidence = sweep!(rng, particles, sampler.resampler, particles.vals[nparticles])

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
44
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
55
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
78

89
[compat]
910
AbstractMCMC = "2, 3"
1011
Distributions = "0.24, 0.25"
1112
Libtask = "0.5"
1213
julia = "1.3"
14+
Random123 = "1.3"

0 commit comments

Comments
 (0)