Skip to content

Commit 81e0fe7

Browse files
committed
Calculate chainidxs inside the loop
1 parent ec621b7 commit 81e0fe7

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

src/sample.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -397,17 +397,9 @@ function mcmcsample(
397397
models = [deepcopy(model) for _ in interval]
398398
samplers = [deepcopy(sampler) for _ in interval]
399399

400-
# Distribute chains amongst the chunks. If nchains/nchunks = m with
401-
# remainder n, then the first n chunks will have m + 1 chains, and the rest
402-
# will have m chains.
400+
# If nchains/nchunks = m with remainder n, then the first n chunks will
401+
# have m + 1 chains, and the rest will have m chains.
403402
m, n = divrem(nchains, nchunks)
404-
chain_index_groups = UnitRange{Int}[]
405-
current_index = 1
406-
for i in interval
407-
nchains_this_chunk = i <= n ? m + 1 : m
408-
push!(chain_index_groups, current_index:(current_index + nchains_this_chunk - 1))
409-
current_index += nchains_this_chunk
410-
end
411403

412404
# Create a seed for each chain using the provided random number generator.
413405
seeds = rand(rng, UInt, nchains)
@@ -447,8 +439,18 @@ function mcmcsample(
447439

448440
Distributed.@async begin
449441
try
450-
Distributed.@sync for (chainidxs, _rng, _model, _sampler) in
451-
zip(chain_index_groups, rngs, models, samplers)
442+
Distributed.@sync for (i, _rng, _model, _sampler) in
443+
zip(interval, rngs, models, samplers)
444+
if i <= n
445+
chainidx_hi = i * (m + 1)
446+
nchains_chunk = m + 1
447+
else
448+
chainidx_hi = n * (m + 1) + (i - n) * m
449+
nchains_chunk = m
450+
end
451+
chainidx_lo = chainidx_hi - nchains_chunk + 1
452+
chainidxs = chainidx_lo:chainidx_hi
453+
452454
Threads.@spawn for chainidx in chainidxs
453455
# Seed the chunk-specific random number generator with the pre-made seed.
454456
Random.seed!(_rng, seeds[chainidx])

0 commit comments

Comments
 (0)