Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.6.0"
version = "5.6.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
18 changes: 13 additions & 5 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,16 @@ function mcmcsample(

# Copy the random number generator, model, and sample for each thread
nchunks = min(nchains, Threads.nthreads())
chunksize = cld(nchains, nchunks)
interval = 1:nchunks
# `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899
rngs = [copy(rng) for _ in interval]
models = [deepcopy(model) for _ in interval]
samplers = [deepcopy(sampler) for _ in interval]

# If nchains/nchunks = m with remainder n, then the first n chunks will
# have m + 1 chains, and the rest will have m chains.
m, n = divrem(nchains, nchunks)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -437,12 +440,17 @@ function mcmcsample(
Distributed.@async begin
try
Distributed.@sync for (i, _rng, _model, _sampler) in
zip(1:nchunks, rngs, models, samplers)
chainidxs = if i == nchunks
((i - 1) * chunksize + 1):nchains
zip(interval, rngs, models, samplers)
if i <= n
chainidx_hi = i * (m + 1)
nchains_chunk = m + 1
else
((i - 1) * chunksize + 1):(i * chunksize)
chainidx_hi = n * (m + 1) + (i - n) * m
nchains_chunk = m
end
chainidx_lo = chainidx_hi - nchains_chunk + 1
chainidxs = chainidx_lo:chainidx_hi

Threads.@spawn for chainidx in chainidxs
# Seed the chunk-specific random number generator with the pre-made seed.
Random.seed!(_rng, seeds[chainidx])
Expand Down
Loading