@@ -391,13 +391,24 @@ function mcmcsample(
391
391
392
392
# Copy the random number generator, model, and sample for each thread
393
393
nchunks = min (nchains, Threads. nthreads ())
394
- chunksize = cld (nchains, nchunks)
395
394
interval = 1 : nchunks
396
395
# `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899
397
396
rngs = [copy (rng) for _ in interval]
398
397
models = [deepcopy (model) for _ in interval]
399
398
samplers = [deepcopy (sampler) for _ in interval]
400
399
400
+ # Distribute chains amongst the chunks. If nchains/nchunks = m with
401
+ # remainder n, then the first n chunks will have m + 1 chains, and the rest
402
+ # will have m chains.
403
+ m, n = divrem (nchains, nchunks)
404
+ chain_index_groups = UnitRange{Int}[]
405
+ current_index = 1
406
+ for i in interval
407
+ nchains_this_chunk = i <= n ? m + 1 : m
408
+ push! (chain_index_groups, current_index: (current_index + nchains_this_chunk - 1 ))
409
+ current_index += nchains_this_chunk
410
+ end
411
+
401
412
# Create a seed for each chain using the provided random number generator.
402
413
seeds = rand (rng, UInt, nchains)
403
414
@@ -436,13 +447,8 @@ function mcmcsample(
436
447
437
448
Distributed. @async begin
438
449
try
439
- Distributed. @sync for (i, _rng, _model, _sampler) in
440
- zip (1 : nchunks, rngs, models, samplers)
441
- chainidxs = if i == nchunks
442
- ((i - 1 ) * chunksize + 1 ): nchains
443
- else
444
- ((i - 1 ) * chunksize + 1 ): (i * chunksize)
445
- end
450
+ Distributed. @sync for (chainidxs, _rng, _model, _sampler) in
451
+ zip (chain_index_groups, rngs, models, samplers)
446
452
Threads. @spawn for chainidx in chainidxs
447
453
# Seed the chunk-specific random number generator with the pre-made seed.
448
454
Random. seed! (_rng, seeds[chainidx])
0 commit comments