Skip to content

Commit d194f86

Browse files
committed
Fix indexing for chains in different threads
1 parent 5a3b155 commit d194f86

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.6.0"
6+
version = "5.6.1"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ function mcmcsample(
391391

392392
# Copy the random number generator, model, and sample for each thread
393393
nchunks = min(nchains, Threads.nthreads())
394-
chunksize = cld(nchains, nchunks)
395394
interval = 1:nchunks
396395
# `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899
397396
rngs = [copy(rng) for _ in interval]
@@ -434,15 +433,24 @@ function mcmcsample(
434433
end
435434
end
436435

436+
# If nchains/nchunks = m with remainder n, then the first n chunks
437+
# will have m + 1 chains, and the rest will have m chains
438+
m, n = divrem(nchains, nchunks)
439+
chain_index_groups = UnitRange{Int}[]
440+
current_index = 1
441+
for i in interval
442+
nchains_this_chunk = i <= n ? m + 1 : m
443+
push!(
444+
chain_index_groups,
445+
current_index:(current_index + nchains_this_chunk - 1),
446+
)
447+
current_index += nchains_this_chunk
448+
end
449+
437450
Distributed.@async begin
438451
try
439-
Distributed.@sync for (i, _rng, _model, _sampler) in
440-
zip(1:nchunks, rngs, models, samplers)
441-
chainidxs = if i == nchunks
442-
((i - 1) * chunksize + 1):nchains
443-
else
444-
((i - 1) * chunksize + 1):(i * chunksize)
445-
end
452+
Distributed.@sync for (chainidxs, _rng, _model, _sampler) in
453+
zip(chain_index_groups, rngs, models, samplers)
446454
Threads.@spawn for chainidx in chainidxs
447455
# Seed the chunk-specific random number generator with the pre-made seed.
448456
Random.seed!(_rng, seeds[chainidx])

0 commit comments

Comments
 (0)