Skip to content

Commit ec621b7

Browse files
committed
Fix indexing for chains in different threads
1 parent 5a3b155 commit ec621b7

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.6.0"
6+
version = "5.6.1"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -391,13 +391,24 @@ function mcmcsample(
391391

392392
# Copy the random number generator, model, and sample for each thread
393393
nchunks = min(nchains, Threads.nthreads())
394-
chunksize = cld(nchains, nchunks)
395394
interval = 1:nchunks
396395
# `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899
397396
rngs = [copy(rng) for _ in interval]
398397
models = [deepcopy(model) for _ in interval]
399398
samplers = [deepcopy(sampler) for _ in interval]
400399

400+
# Distribute chains amongst the chunks. If nchains/nchunks = m with
401+
# remainder n, then the first n chunks will have m + 1 chains, and the rest
402+
# will have m chains.
403+
m, n = divrem(nchains, nchunks)
404+
chain_index_groups = UnitRange{Int}[]
405+
current_index = 1
406+
for i in interval
407+
nchains_this_chunk = i <= n ? m + 1 : m
408+
push!(chain_index_groups, current_index:(current_index + nchains_this_chunk - 1))
409+
current_index += nchains_this_chunk
410+
end
411+
401412
# Create a seed for each chain using the provided random number generator.
402413
seeds = rand(rng, UInt, nchains)
403414

@@ -436,13 +447,8 @@ function mcmcsample(
436447

437448
Distributed.@async begin
438449
try
439-
Distributed.@sync for (i, _rng, _model, _sampler) in
440-
zip(1:nchunks, rngs, models, samplers)
441-
chainidxs = if i == nchunks
442-
((i - 1) * chunksize + 1):nchains
443-
else
444-
((i - 1) * chunksize + 1):(i * chunksize)
445-
end
450+
Distributed.@sync for (chainidxs, _rng, _model, _sampler) in
451+
zip(chain_index_groups, rngs, models, samplers)
446452
Threads.@spawn for chainidx in chainidxs
447453
# Seed the chunk-specific random number generator with the pre-made seed.
448454
Random.seed!(_rng, seeds[chainidx])

0 commit comments

Comments
 (0)