Skip to content

Commit 4f8ec74

Browse files
Merge pull request #5 from FredericWantiez/feature/split
Fix types, import and PR review
2 parents 6979bbe + 8ed2100 commit 4f8ec74

File tree

4 files changed

+38
-57
lines changed

4 files changed

+38
-57
lines changed

src/AdvancedPS.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Distributions: Distributions
55
using Libtask: Libtask
66
using Random: Random
77
using StatsFuns: StatsFuns
8+
using Random123: Random123
89

910
include("resampling.jl")
1011
include("rng.jl")

src/container.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
struct Trace{F}
1+
struct Trace{F,U,N,V<:Random123.AbstractR123{U}}
22
f::F
33
ctask::Libtask.CTask
4-
rng::TracedRNG
4+
rng::TracedRNG{U,N,V}
55
end
66

77
const Particle = Trace
@@ -60,7 +60,7 @@ end
6060
# Create new task and copy randomness
6161
function forkr(trace::Trace)
6262
newf = reset_model(trace.f)
63-
set_counter!(trace.rng, 1)
63+
Random123.set_counter!(trace.rng, 1)
6464

6565
ctask = let f = trace.ctask.task.code
6666
Libtask.CTask() do
@@ -95,13 +95,13 @@ Data structure for particle filters
9595
- normalise!(pc::ParticleContainer)
9696
- consume(pc::ParticleContainer): return incremental likelihood
9797
"""
98-
mutable struct ParticleContainer{T<:Particle}
98+
mutable struct ParticleContainer{T<:Particle,U,N,V<:Random123.AbstractR123{U}}
9999
"Particles."
100100
vals::Vector{T}
101101
"Unnormalized logarithmic weights."
102102
logWs::Vector{Float64}
103103
"Traced RNG to replay the resampling step"
104-
rng::TracedRNG
104+
rng::TracedRNG{U,N,V}
105105
end
106106

107107
function ParticleContainer(particles::Vector{<:Particle})
@@ -205,7 +205,7 @@ function update_keys!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothin
205205
for i in 1:n
206206
pi = pc.vals[i]
207207
k = split(pi.rng.rng.key)
208-
seed!(pi.rng, k[1])
208+
Random.seed!(pi.rng, k[1])
209209
end
210210
end
211211

@@ -255,13 +255,13 @@ function resample_propagate!(
255255
nseeds = isref ? ni - 1 : ni
256256

257257
seeds = split(p.rng.rng.key, nseeds)
258-
!isref && seed!(p.rng, seeds[1])
258+
!isref && Random.seed!(p.rng, seeds[1])
259259

260260
children[j += 1] = p
261261
# fork additional children
262262
for k in 2:ni
263263
part = fork(p, isref)
264-
seed!(part.rng, seeds[k])
264+
Random.seed!(part.rng, seeds[k])
265265
children[j += 1] = part
266266
end
267267
end

src/rng.jl

Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,84 @@
1-
using Random123
2-
using Random
3-
using Distributions
4-
5-
import Base.rand
6-
import Random.seed!
7-
import Random123: set_counter!
8-
91
# Default RNG type for when nothing is specified
10-
_BASE_RNG = Philox2x
2+
const _BASE_RNG = Random123.Philox2x
113

124
"""
13-
TracedRNG{R,T}
5+
TracedRNG{R,N,T}
146
157
Wrapped random number generator from Random123 to keep track of random streams during model evaluation
168
"""
17-
mutable struct TracedRNG{T} <:
18-
Random.AbstractRNG where {T<:(Random123.AbstractR123{R} where {R})}
9+
mutable struct TracedRNG{R,N,T<:Random123.AbstractR123{R}} <: Random.AbstractRNG
1910
"Model step counter"
2011
count::Int
2112
"Inner RNG"
2213
rng::T
2314
"Array of keys"
24-
keys
15+
keys::Array{R,N}
2516
end
2617

2718
"""
28-
TracedRNG(r::Random123.AbstractR123)
29-
30-
Initialize TracedRNG with r as the inner RNG
19+
TracedRNG(r::Random123.AbstractR123=AdvancedPS._BASE_RNG())
20+
Create a `TracedRNG` with `r` as the inner RNG.
3121
"""
32-
function TracedRNG(r::Random123.AbstractR123)
33-
set_counter!(r, 0)
22+
function TracedRNG(r::Random123.AbstractR123=_BASE_RNG())
23+
Random123.set_counter!(r, 0)
3424
return TracedRNG(1, r, typeof(r.key)[])
3525
end
3626

37-
"""
38-
TracedRNG()
39-
40-
Create a default TracedRNG
41-
"""
42-
function TracedRNG()
43-
r = _BASE_RNG()
44-
return TracedRNG(r)
45-
end
46-
4727
# Connect to the Random API
48-
Random.rng_native_52(rng::TracedRNG{U}) where {U} = Random.rng_native_52(rng.rng)
49-
Base.rand(rng::TracedRNG{U}, ::Type{T}) where {U,T} = Base.rand(rng.rng, T)
28+
Random.rng_native_52(rng::TracedRNG) = Random.rng_native_52(rng.rng)
29+
Base.rand(rng::TracedRNG, ::Type{T}) where {T} = Base.rand(rng.rng, T)
5030

5131
"""
52-
split(key::Integer, n::Integer)
32+
split(key::Integer, n::Integer=1)
5333
54-
Split key into n new keys
34+
Split `key` into `n` new keys
5535
"""
56-
function split(key::Integer, n::Integer=1) where {T}
57-
return map(i -> hash(key, convert(UInt, i)), 1:n)
36+
function split(key::Integer, n::Integer=1)
37+
return [hash(key, i) for i in UInt(1):UInt(n)]
5838
end
5939

6040
"""
61-
load_state!(r::TracedRNG, seed)
41+
load_state!(r::TracedRNG)
6242
6343
Load state from current model iteration. Random streams are now replayed
6444
"""
65-
function load_state!(rng::TracedRNG{T}) where {T}
45+
function load_state!(rng::TracedRNG)
6646
key = rng.keys[rng.count]
6747
Random.seed!(rng.rng, key)
68-
return set_counter!(rng.rng, 0)
48+
return Random123.set_counter!(rng.rng, 0)
6949
end
7050

7151
"""
7252
update_rng!(rng::TracedRNG)
7353
74-
Set key and counter of inner RNG to key and the running model step
54+
Set key and counter of inner rng in `rng` to `key` and the running model step to 0
7555
"""
76-
function seed!(rng::TracedRNG{T}, key) where {T}
77-
seed!(rng.rng, key)
78-
return set_counter!(rng.rng, 0)
56+
function Random.seed!(rng::TracedRNG, key)
57+
Random.seed!(rng.rng, key)
58+
return Random123.set_counter!(rng.rng, 0)
7959
end
8060

8161
"""
8262
save_state!(r::TracedRNG)
8363
84-
Track current key of the inner RNG
64+
Add current key of the inner rng in `r` to `keys`.
8565
"""
86-
function save_state!(r::TracedRNG{T}) where {T}
66+
function save_state!(r::TracedRNG)
8767
return push!(r.keys, r.rng.key)
8868
end
8969

90-
Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys))
70+
Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), deepcopy(r.keys))
9171

9272
"""
93-
set_count!(r::TracedRNG, n::Integer)
73+
set_counter!(r::TracedRNG, n::Integer)
9474
95-
Set the counter of the TracedRNG, used to keep track of the current model step
75+
Set the counter of the inner rng in `r`, used to keep track of the current model step
9676
"""
97-
set_counter!(r::TracedRNG, n::Integer) = r.count = n
77+
Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n
9878

9979
"""
10080
inc_counter!(r::TracedRNG, n::Integer=1)
10181
102-
Increase the model step counter by n
82+
Increase the model step counter by `n`
10383
"""
10484
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n

test/rng.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
@test key new_key
2020

21-
AdvancedPS.seed!(rng, new_key)
21+
Random.seed!(rng, new_key)
2222
@test rng.rng.key === new_key
2323
end
2424
end

0 commit comments

Comments
 (0)