@@ -305,8 +305,8 @@ function mcmcsample(
305
305
models = [deepcopy (model) for _ in interval]
306
306
samplers = [deepcopy (sampler) for _ in interval]
307
307
308
- # Create a seed for each chunk using the provided random number generator.
309
- seeds = rand (rng, UInt, nchunks )
308
+ # Create a seed for each chain using the provided random number generator.
309
+ seeds = rand (rng, UInt, nchains )
310
310
311
311
# Set up a chains vector.
312
312
chains = Vector {Any} (undef, nchains)
@@ -339,25 +339,22 @@ function mcmcsample(
339
339
340
340
Distributed. @async begin
341
341
try
342
- Distributed. @sync for (i, _rng, seed, _model, _sampler) in zip (1 : nchunks, rngs, seeds, models, samplers)
343
- Threads. @spawn begin
342
+ Distributed. @sync for (i, _rng, _model, _sampler) in zip (1 : nchunks, rngs, models, samplers)
343
+ chainidxs = if i == nchunks
344
+ ((i - 1 ) * chunksize + 1 ): nchains
345
+ else
346
+ ((i - 1 ) * chunksize + 1 ): (i * chunksize)
347
+ end
348
+ Threads. @spawn for chainidx in chainidxs
344
349
# Seed the chunk-specific random number generator with the pre-made seed.
345
- Random. seed! (_rng, seed)
346
-
347
- chainidxs = if i == nchunks
348
- ((i - 1 ) * chunksize + 1 ): nchains
349
- else
350
- ((i - 1 ) * chunksize + 1 ): (i * chunksize)
351
- end
352
-
353
- for chainidx in chainidxs
354
- # Sample a chain and save it to the vector.
355
- chains[chainidx] = StatsBase. sample (_rng, _model, _sampler, N;
356
- progress = false , kwargs... )
357
-
358
- # Update the progress bar.
359
- progress && put! (channel, true )
360
- end
350
+ Random. seed! (_rng, seeds[chainidx])
351
+
352
+ # Sample a chain and save it to the vector.
353
+ chains[chainidx] = StatsBase. sample (_rng, _model, _sampler, N;
354
+ progress = false , kwargs... )
355
+
356
+ # Update the progress bar.
357
+ progress && put! (channel, true )
361
358
end
362
359
end
363
360
finally
@@ -469,12 +466,18 @@ function mcmcsample(
469
466
@warn " Number of chains ($nchains ) is greater than number of samples per chain ($N )"
470
467
end
471
468
469
+ # Create a seed for each chain using the provided random number generator.
470
+ seeds = rand (rng, UInt, nchains)
471
+
472
472
# Sample the chains.
473
- chains = map (
474
- i -> StatsBase. sample (rng, model, sampler, N; progressname = string (progressname, " (Chain " , i, " of " , nchains, " )" ),
475
- kwargs... ),
476
- 1 : nchains
477
- )
473
+ chains = map (enumerate (seeds)) do (i, seed)
474
+ Random. seed! (rng, seed)
475
+ return StatsBase. sample (
476
+ rng, model, sampler, N;
477
+ progressname = string (progressname, " (Chain " , i, " of " , nchains, " )" ),
478
+ kwargs... ,
479
+ )
480
+ end
478
481
479
482
# Concatenate the chains together.
480
483
return chainsstack (tighten_eltype (chains))
0 commit comments