Skip to content

Commit 05f91cf

Browse files
authored
Fix reproducibility of ensemble sampling
1 parent 6dae58b commit 05f91cf

File tree

1 file changed

+28
-25
lines changed

1 file changed

+28
-25
lines changed

src/sample.jl

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,8 @@ function mcmcsample(
305305
models = [deepcopy(model) for _ in interval]
306306
samplers = [deepcopy(sampler) for _ in interval]
307307

308-
# Create a seed for each chunk using the provided random number generator.
309-
seeds = rand(rng, UInt, nchunks)
308+
# Create a seed for each chain using the provided random number generator.
309+
seeds = rand(rng, UInt, nchains)
310310

311311
# Set up a chains vector.
312312
chains = Vector{Any}(undef, nchains)
@@ -339,25 +339,22 @@ function mcmcsample(
339339

340340
Distributed.@async begin
341341
try
342-
Distributed.@sync for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers)
343-
Threads.@spawn begin
342+
Distributed.@sync for (i, _rng, _model, _sampler) in zip(1:nchunks, rngs, models, samplers)
343+
chainidxs = if i == nchunks
344+
((i - 1) * chunksize + 1):nchains
345+
else
346+
((i - 1) * chunksize + 1):(i * chunksize)
347+
end
348+
Threads.@spawn for chainidx in chainidxs
344349
# Seed the chunk-specific random number generator with the pre-made seed.
345-
Random.seed!(_rng, seed)
346-
347-
chainidxs = if i == nchunks
348-
((i - 1) * chunksize + 1):nchains
349-
else
350-
((i - 1) * chunksize + 1):(i * chunksize)
351-
end
352-
353-
for chainidx in chainidxs
354-
# Sample a chain and save it to the vector.
355-
chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N;
356-
progress = false, kwargs...)
357-
358-
# Update the progress bar.
359-
progress && put!(channel, true)
360-
end
350+
Random.seed!(_rng, seeds[chainidx])
351+
352+
# Sample a chain and save it to the vector.
353+
chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N;
354+
progress = false, kwargs...)
355+
356+
# Update the progress bar.
357+
progress && put!(channel, true)
361358
end
362359
end
363360
finally
@@ -469,12 +466,18 @@ function mcmcsample(
469466
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
470467
end
471468

469+
# Create a seed for each chain using the provided random number generator.
470+
seeds = rand(rng, UInt, nchains)
471+
472472
# Sample the chains.
473-
chains = map(
474-
i -> StatsBase.sample(rng, model, sampler, N; progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
475-
kwargs...),
476-
1:nchains
477-
)
473+
chains = map(enumerate(seeds)) do (i, seed)
474+
Random.seed!(rng, seed)
475+
return StatsBase.sample(
476+
rng, model, sampler, N;
477+
progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
478+
kwargs...,
479+
)
480+
end
478481

479482
# Concatenate the chains together.
480483
return chainsstack(tighten_eltype(chains))

0 commit comments

Comments
 (0)