Skip to content

Commit 5cbbe40

Browse files
Track Particle container
1 parent cd12c8c commit 5cbbe40

File tree

3 files changed

+26
-16
lines changed

3 files changed

+26
-16
lines changed

src/container.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,23 @@ Data structure for particle filters
8989
- normalise!(pc::ParticleContainer)
9090
- consume(pc::ParticleContainer): return incremental likelihood
9191
"""
92-
mutable struct ParticleContainer{T<:Particle}
92+
mutable struct ParticleContainer{T<:Particle,R<:Random.AbstractRNG}
9393
"Particles."
9494
vals::Vector{T}
9595
"Unnormalized logarithmic weights."
9696
logWs::Vector{Float64}
97+
"TracedRNG to track the resampling step"
98+
rng::TracedRNG{R}
9799
end
98100

99101
function ParticleContainer(particles::Vector{<:Particle})
100-
return ParticleContainer(particles, zeros(length(particles)))
102+
return ParticleContainer(particles, zeros(length(particles)), TracedRNG())
103+
end
104+
105+
function ParticleContainer(
106+
particles::Vector{<:Particle}, rng::T
107+
) where {T<:Random.AbstractRNG}
108+
return ParticleContainer(particles, zeros(length(particles)), TracedRNG(rng))
101109
end
102110

103111
Base.collect(pc::ParticleContainer) = pc.vals
@@ -124,7 +132,7 @@ function Base.copy(pc::ParticleContainer)
124132
# copy weights
125133
logWs = copy(pc.logWs)
126134

127-
return ParticleContainer(vals, logWs)
135+
return ParticleContainer(vals, logWs, pc.rng)
128136
end
129137

130138
"""

src/rng.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
Data structure to keep track of the history of the random stream
33
produced by RNG.
44
"""
5-
struct TracedRNG{T} <: Random.AbstractRNG where {T<:Random.AbstractRNG}
5+
mutable struct TracedRNG{T} <: Random.AbstractRNG where {T<:Random.AbstractRNG}
66
count::Base.RefValue{Int}
77
rng::T
8-
seed::Any
8+
seed::Array
9+
states::Array{T}
910
end
1011

1112
# Set seed manually, for init ?
@@ -18,7 +19,7 @@ end
1819
Random.seed!(rng::TracedRNG) = Random.seed!(rng.rng, rng.seed)
1920

2021
TracedRNG() = TracedRNG(Random.MersenneTwister()) # Pick up an explicit RNG from Random
21-
TracedRNG(rng::Random.AbstractRNG) = TracedRNG(Ref(0), rng, rng.seed)
22+
TracedRNG(rng::Random.AbstractRNG) = TracedRNG(Ref(0), rng, rng.seed, [rng])
2223
TracedRNG(rng::Random._GLOBAL_RNG) = TracedRNG(Random.default_rng())
2324

2425
# Intercept rand
@@ -28,6 +29,7 @@ Random.rng_native_52(r::TracedRNG) = UInt64
2829
function Base.rand(rng::TracedRNG, ::Type{T}) where {T}
2930
res = Base.rand(rng.rng, T)
3031
inc_count!(rng, length(res))
32+
push!(rng.states, copy(rng.rng))
3133
return res
3234
end
3335

src/smc.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ function AbstractMCMC.sample(
3838
end
3939

4040
# Create a set of particles.
41-
particles = ParticleContainer([
42-
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
43-
])
41+
particles = ParticleContainer(
42+
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], rng
43+
)
4444

4545
# Perform particle sweep.
46-
logevidence = sweep!(rng, particles, sampler.resampler)
46+
logevidence = sweep!(particles.rng, particles, sampler.resampler)
4747

4848
return SMCSample(collect(particles), getweights(particles), logevidence)
4949
end
@@ -85,12 +85,12 @@ function AbstractMCMC.step(
8585
rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::PG; kwargs...
8686
)
8787
# Create a new set of particles.
88-
particles = ParticleContainer([
89-
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
90-
])
88+
particles = ParticleContainer(
89+
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], rng
90+
)
9191

9292
# Perform a particle sweep.
93-
logevidence = sweep!(rng, particles, sampler.resampler)
93+
logevidence = sweep!(particles.rng, particles, sampler.resampler)
9494

9595
# Pick a particle to be retained.
9696
trajectory = rand(rng, particles)
@@ -115,10 +115,10 @@ function AbstractMCMC.step(
115115
Trace(model, TracedRNG())
116116
end
117117
end
118-
particles = ParticleContainer(x)
118+
particles = ParticleContainer(x, rng)
119119

120120
# Perform a particle sweep.
121-
logevidence = sweep!(rng, particles, sampler.resampler)
121+
logevidence = sweep!(particles.rng, particles, sampler.resampler)
122122

123123
# Pick a particle to be retained.
124124
newtrajectory = rand(rng, particles)

0 commit comments

Comments
 (0)