Skip to content

Commit 6979bbe

Browse files
Merge pull request #4 from FredericWantiez/feature/split
Fix naming
2 parents d77876f + d709319 commit 6979bbe

File tree

3 files changed

+41
-33
lines changed

3 files changed

+41
-33
lines changed

src/container.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng))
3232
# step to the next observe statement and
3333
# return the log probability of the transition (or nothing if done)
3434
function advance!(t::Trace, isref::Bool)
35-
isref ? load_state(t.rng) : save_state!(t.rng)
36-
inc_count!(t.rng)
35+
isref ? load_state!(t.rng) : save_state!(t.rng)
36+
inc_counter!(t.rng)
3737

3838
# Move to next step
3939
return Libtask.consume(t.ctask)
@@ -60,7 +60,7 @@ end
6060
# Create new task and copy randomness
6161
function forkr(trace::Trace)
6262
newf = reset_model(trace.f)
63-
set_count!(trace.rng, 1)
63+
set_counter!(trace.rng, 1)
6464

6565
ctask = let f = trace.ctask.task.code
6666
Libtask.CTask() do
@@ -100,7 +100,7 @@ mutable struct ParticleContainer{T<:Particle}
100100
vals::Vector{T}
101101
"Unnormalized logarithmic weights."
102102
logWs::Vector{Float64}
103-
"Traced RNG"
103+
"Traced RNG to replay the resampling step"
104104
rng::TracedRNG
105105
end
106106

@@ -204,8 +204,8 @@ function update_keys!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothin
204204
n = ref === nothing ? nparticles : nparticles - 1
205205
for i in 1:n
206206
pi = pc.vals[i]
207-
k = split(pi.rng, 1)
208-
update_rng!(pi.rng, k[1])
207+
k = split(pi.rng.rng.key)
208+
seed!(pi.rng, k[1])
209209
end
210210
end
211211

@@ -252,15 +252,16 @@ function resample_propagate!(
252252
pi = particles[i]
253253
isref = pi === ref
254254
p = isref ? fork(pi, isref) : pi
255+
nseeds = isref ? ni - 1 : ni
255256

256-
seeds = split(p.rng, ni)
257-
!isref && update_rng!(p.rng, seeds[1])
257+
seeds = split(p.rng.rng.key, nseeds)
258+
!isref && seed!(p.rng, seeds[1])
258259

259260
children[j += 1] = p
260261
# fork additional children
261262
for k in 2:ni
262263
part = fork(p, isref)
263-
update_rng!(part.rng, seeds[k])
264+
seed!(part.rng, seeds[k])
264265
children[j += 1] = part
265266
end
266267
end

src/rng.jl

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Distributions
44

55
import Base.rand
66
import Random.seed!
7+
import Random123: set_counter!
78

89
# Default RNG type for when nothing is specified
910
_BASE_RNG = Philox2x
@@ -29,7 +30,7 @@ end
2930
Initialize TracedRNG with r as the inner RNG
3031
"""
3132
function TracedRNG(r::Random123.AbstractR123)
32-
set_counter!(r, 1)
33+
set_counter!(r, 0)
3334
return TracedRNG(1, r, typeof(r.key)[])
3435
end
3536

@@ -48,42 +49,33 @@ Random.rng_native_52(rng::TracedRNG{U}) where {U} = Random.rng_native_52(rng.rng
4849
Base.rand(rng::TracedRNG{U}, ::Type{T}) where {U,T} = Base.rand(rng.rng, T)
4950

5051
"""
51-
split(r::TracedRNG, n::Integer)
52+
split(key::Integer, n::Integer)
5253
53-
Split inner RNG into n new TracedRNG
54+
Split key into n new keys
5455
"""
55-
function split(r::TracedRNG{T}, n::Integer) where {T}
56-
return map(i -> hash(r.rng.key, convert(UInt, r.rng.ctr1 + i)), 1:n)
56+
function split(key::Integer, n::Integer=1) where {T}
57+
return map(i -> hash(key, convert(UInt, i)), 1:n)
5758
end
5859

5960
"""
60-
seed!(r::TracedRNG, seed::Number)
61-
62-
Set the key of the inner RNG as `seed`
63-
"""
64-
function seed!(r::TracedRNG{T}, seed) where {T}
65-
return seed!(r.rng, seed)
66-
end
67-
68-
"""
69-
load_state(r::TracedRNG, seed)
61+
load_state!(r::TracedRNG, seed)
7062
7163
Load state from current model iteration. Random streams are now replayed
7264
"""
73-
function load_state(rng::TracedRNG{T}) where {T}
65+
function load_state!(rng::TracedRNG{T}) where {T}
7466
key = rng.keys[rng.count]
7567
Random.seed!(rng.rng, key)
76-
return set_counter!(rng.rng, rng.count)
68+
return set_counter!(rng.rng, 0)
7769
end
7870

7971
"""
8072
update_rng!(rng::TracedRNG)
8173
8274
Set key and counter of inner RNG to key and the running model step
8375
"""
84-
function update_rng!(rng::TracedRNG{T}, key) where {T}
85-
seed!(rng, key)
86-
return set_counter!(rng.rng, rng.count)
76+
function seed!(rng::TracedRNG{T}, key) where {T}
77+
seed!(rng.rng, key)
78+
return set_counter!(rng.rng, 0)
8779
end
8880

8981
"""
@@ -102,7 +94,11 @@ Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.ke
10294
10395
Set the counter of the TracedRNG, used to keep track of the current model step
10496
"""
105-
set_count!(r::TracedRNG, n::Integer) = r.count = n
97+
set_counter!(r::TracedRNG, n::Integer) = r.count = n
98+
99+
"""
100+
inc_counter!(r::TracedRNG, n::Integer=1)
106101
107-
inc_count!(r::TracedRNG, n::Integer) = r.count += n
108-
inc_count!(r::TracedRNG) = inc_count!(r, 1)
102+
Increase the model step counter by n
103+
"""
104+
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n

test/rng.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,19 @@
66

77
rand(rng, Distributions.Normal())
88

9-
AdvancedPS.load_state(rng)
9+
AdvancedPS.load_state!(rng)
1010
new_vns = rand(rng, Distributions.Normal())
1111
@test new_vns vns
1212
end
13+
14+
@testset "split" begin
15+
rng = AdvancedPS.TracedRNG()
16+
key = rng.rng.key
17+
new_key, = AdvancedPS.split(key, 1)
18+
19+
@test key new_key
20+
21+
AdvancedPS.seed!(rng, new_key)
22+
@test rng.rng.key === new_key
23+
end
1324
end

0 commit comments

Comments
 (0)