@@ -298,16 +298,15 @@ function mcmcsample(
298
298
end
299
299
300
300
# Copy the random number generator, model, and sample for each thread
301
- # NOTE: As of May 17, 2020, this relies on Julia's thread scheduling functionality
302
- # that distributes a for loop into equal-sized blocks and allocates them
303
- # to each thread. If this changes, we may need to rethink things here.
301
+ nchunks = min (nchains, Threads. nthreads ())
302
+ chunksize = cld (nchains, nchunks)
304
303
interval = 1 : min (nchains, Threads. nthreads ())
305
304
rngs = [deepcopy (rng) for _ in interval]
306
305
models = [deepcopy (model) for _ in interval]
307
306
samplers = [deepcopy (sampler) for _ in interval]
308
307
309
- # Create a seed for each chain using the provided random number generator.
310
- seeds = rand (rng, UInt, nchains )
308
+ # Create a seed for each chunk using the provided random number generator.
309
+ seeds = rand (rng, UInt, nchunks )
311
310
312
311
# Set up a chains vector.
313
312
chains = Vector {Any} (undef, nchains)
@@ -340,20 +339,26 @@ function mcmcsample(
340
339
341
340
Distributed. @async begin
342
341
try
343
- Threads. @threads for i in 1 : nchains
344
- # Obtain the ID of the current thread.
345
- id = Threads. threadid ()
346
-
347
- # Seed the thread-specific random number generator with the pre-made seed.
348
- subrng = rngs[id]
349
- Random. seed! (subrng, seeds[i])
350
-
351
- # Sample a chain and save it to the vector.
352
- chains[i] = StatsBase. sample (subrng, models[id], samplers[id], N;
353
- progress = false , kwargs... )
354
-
355
- # Update the progress bar.
356
- progress && put! (channel, true )
342
+ Distributed. @sync for (i, _rng, seed, _model, _sampler) in zip (1 : nchunks, rngs, seeds, models, samplers)
343
+ Threads. @spawn begin
344
+ # Seed the chunk-specific random number generator with the pre-made seed.
345
+ Random. seed! (_rng, seed)
346
+
347
+ chainidxs = if i == nchunks
348
+ ((i - 1 ) * chunksize + 1 ): nchains
349
+ else
350
+ ((i - 1 ) * chunksize + 1 ): (i * chunksize)
351
+ end
352
+
353
+ for chainidx in chainidxs
354
+ # Sample a chain and save it to the vector.
355
+ chains[chainidx] = StatsBase. sample (_rng, _model, _sampler, N;
356
+ progress = false , kwargs... )
357
+
358
+ # Update the progress bar.
359
+ progress && put! (channel, true )
360
+ end
361
+ end
357
362
end
358
363
finally
359
364
# Stop updating the progress bar.
0 commit comments