@@ -397,17 +397,9 @@ function mcmcsample(
397
397
models = [deepcopy (model) for _ in interval]
398
398
samplers = [deepcopy (sampler) for _ in interval]
399
399
400
- # Distribute chains amongst the chunks. If nchains/nchunks = m with
401
- # remainder n, then the first n chunks will have m + 1 chains, and the rest
402
- # will have m chains.
400
+ # If nchains/nchunks = m with remainder n, then the first n chunks will
401
+ # have m + 1 chains, and the rest will have m chains.
403
402
m, n = divrem (nchains, nchunks)
404
- chain_index_groups = UnitRange{Int}[]
405
- current_index = 1
406
- for i in interval
407
- nchains_this_chunk = i <= n ? m + 1 : m
408
- push! (chain_index_groups, current_index: (current_index + nchains_this_chunk - 1 ))
409
- current_index += nchains_this_chunk
410
- end
411
403
412
404
# Create a seed for each chain using the provided random number generator.
413
405
seeds = rand (rng, UInt, nchains)
@@ -447,8 +439,18 @@ function mcmcsample(
447
439
448
440
Distributed. @async begin
449
441
try
450
- Distributed. @sync for (chainidxs, _rng, _model, _sampler) in
451
- zip (chain_index_groups, rngs, models, samplers)
442
+ Distributed. @sync for (i, _rng, _model, _sampler) in
443
+ zip (interval, rngs, models, samplers)
444
+ if i <= n
445
+ chainidx_hi = i * (m + 1 )
446
+ nchains_chunk = m + 1
447
+ else
448
+ chainidx_hi = n * (m + 1 ) + (i - n) * m
449
+ nchains_chunk = m
450
+ end
451
+ chainidx_lo = chainidx_hi - nchains_chunk + 1
452
+ chainidxs = chainidx_lo: chainidx_hi
453
+
452
454
Threads. @spawn for chainidx in chainidxs
453
455
# Seed the chunk-specific random number generator with the pre-made seed.
454
456
Random. seed! (_rng, seeds[chainidx])
0 commit comments