@@ -391,7 +391,6 @@ function mcmcsample(
391
391
392
392
# Copy the random number generator, model, and sample for each thread
393
393
nchunks = min (nchains, Threads. nthreads ())
394
- chunksize = cld (nchains, nchunks)
395
394
interval = 1 : nchunks
396
395
# `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899
397
396
rngs = [copy (rng) for _ in interval]
@@ -434,15 +433,24 @@ function mcmcsample(
434
433
end
435
434
end
436
435
436
+ # If nchains/nchunks = m with remainder n, then the first n chunks
437
+ # will have m + 1 chains, and the rest will have m chains
438
+ m, n = divrem (nchains, nchunks)
439
+ chain_index_groups = UnitRange{Int}[]
440
+ current_index = 1
441
+ for i in interval
442
+ nchains_this_chunk = i <= n ? m + 1 : m
443
+ push! (
444
+ chain_index_groups,
445
+ current_index: (current_index + nchains_this_chunk - 1 ),
446
+ )
447
+ current_index += nchains_this_chunk
448
+ end
449
+
437
450
Distributed. @async begin
438
451
try
439
- Distributed. @sync for (i, _rng, _model, _sampler) in
440
- zip (1 : nchunks, rngs, models, samplers)
441
- chainidxs = if i == nchunks
442
- ((i - 1 ) * chunksize + 1 ): nchains
443
- else
444
- ((i - 1 ) * chunksize + 1 ): (i * chunksize)
445
- end
452
+ Distributed. @sync for (chainidxs, _rng, _model, _sampler) in
453
+ zip (chain_index_groups, rngs, models, samplers)
446
454
Threads. @spawn for chainidx in chainidxs
447
455
# Seed the chunk-specific random number generator with the pre-made seed.
448
456
Random. seed! (_rng, seeds[chainidx])
0 commit comments