Skip to content

Commit cd12c8c

Browse files
Format
1 parent 156afff commit cd12c8c

File tree

5 files changed

+19
-14
lines changed

5 files changed

+19
-14
lines changed

src/container.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,17 @@ end
66

77
const Particle = Trace
88

9-
function Trace(f, rng::Random.AbstractRNG)
10-
trng = TracedRNG(rng)
11-
9+
function Trace(f, rng::TracedRNG)
1210
ctask = let f = f
1311
Libtask.CTask() do
14-
res = f(trng)
12+
res = f(rng)
1513
Libtask.produce(nothing)
1614
return res
1715
end
1816
end
1917

2018
# add backward reference
21-
newtrace = Trace(f, ctask, trng)
19+
newtrace = Trace(f, ctask, rng)
2220
addreference!(ctask.task, newtrace)
2321

2422
return newtrace

src/rng.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@ struct TracedRNG{T} <: Random.AbstractRNG where {T<:Random.AbstractRNG}
88
seed::Any
99
end
1010

11-
1211
# Set seed manually, for init ?
13-
Random.seed!(rng::TracedRNG, seed) = Random.seed!(rng.rng, seed)
12+
function Random.seed!(rng::TracedRNG, seed)
13+
rng.rng.seed = seed
14+
return Random.seed!(rng.rng, seed)
15+
end
16+
1417
# Reset the rng to the initial seed
1518
Random.seed!(rng::TracedRNG) = Random.seed!(rng.rng, rng.seed)
1619

src/smc.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ function AbstractMCMC.sample(
3535
)
3636
if !isempty(kwargs)
3737
@warn "keyword arguments $(keys(kwargs)) are not supported by `SMC`"
38-
end
38+
end
3939

4040
# Create a set of particles.
41-
particles = ParticleContainer([Trace(model, TracedRNG()) for _ in 1:sampler.nparticles])
41+
particles = ParticleContainer([
42+
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
43+
])
4244

4345
# Perform particle sweep.
4446
logevidence = sweep!(rng, particles, sampler.resampler)
@@ -83,7 +85,9 @@ function AbstractMCMC.step(
8385
rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::PG; kwargs...
8486
)
8587
# Create a new set of particles.
86-
particles = ParticleContainer([Trace(model, TracedRNG()) for _ in 1:sampler.nparticles])
88+
particles = ParticleContainer([
89+
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
90+
])
8791

8892
# Perform a particle sweep.
8993
logevidence = sweep!(rng, particles, sampler.resampler)

test/container.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
# Create particle container.
2424
logps = [0.0, -1.0, -2.0]
25-
particles = [AdvancedPS.Trace(fpc(logp), Random.MersenneTwister()) for logp in logps]
25+
particles = [AdvancedPS.Trace(fpc(logp), AdvancedPS.TracedRNG()) for logp in logps]
2626
pc = AdvancedPS.ParticleContainer(particles)
2727

2828
# Initial state.
@@ -94,7 +94,7 @@
9494
end
9595

9696
# Test task copy version of trace
97-
tr = AdvancedPS.Trace(f2, Random.MersenneTwister())
97+
tr = AdvancedPS.Trace(f2, AdvancedPS.TracedRNG())
9898

9999
consume(tr.ctask)
100100
consume(tr.ctask)

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Test
1515
@testset "SMC and PG tests" begin
1616
include("smc.jl")
1717
end
18-
@testset "RNG tests" begin
19-
include("rng.jl")
18+
@testset "RNG tests" begin
19+
include("rng.jl")
2020
end
2121
end

0 commit comments

Comments
 (0)