@@ -544,34 +544,50 @@ function mcmcsample(
544
544
pool = Distributed. CachingPool (Distributed. workers ())
545
545
546
546
local chains
547
- # Create a channel for progress logging.
548
- if progress
549
- channel = Distributed. RemoteChannel (() -> Channel {Bool} (Distributed. nworkers ()))
547
+
548
+ # Create overall progress logging object (tracks number of chains completed)
549
+ overall_progress_obj = ProgressMeter. Progress (
550
+ nchains; desc= progressname, dt= 0.0 , enabled= progress
551
+ )
552
+ # ProgressMeter doesn't start printing until the second iteration or so. This
553
+ # forces it to start printing an empty progress bar immediately.
554
+ # https://github.com/timholy/ProgressMeter.jl/issues/288
555
+ ProgressMeter. update! (overall_progress_obj, 0 ; force= true )
556
+ # Create per-chain progress logging objects
557
+ progress_objs = [
558
+ ProgressMeter. Progress (
559
+ N; desc= " Chain $i /$nchains " , dt= 0.01 , enabled= progress, offset= i
560
+ ) for i in 1 : nchains
561
+ ]
562
+ for obj in progress_objs
563
+ ProgressMeter. update! (obj, 0 ; force= true )
550
564
end
565
+ # Create a channel to synchronise progress updates.
566
+ channel = Distributed. RemoteChannel (() -> Channel {Tuple{Int,Bool}} (), 1 )
551
567
552
568
Distributed. @sync begin
553
- if progress
554
- # Update the progress bar.
555
- Distributed . @async begin
556
- # Determine threshold values for progress logging
557
- # (one update per 0.5% of progress)
558
- threshold = nchains ÷ 200
559
- nextprogresschains = threshold
560
-
561
- progresschains = 0
562
- while take! (channel)
563
- progresschains += 1
564
- if progresschains >= nextprogresschains
565
- # ProgressLogging.@logprogress progresschains / nchains
566
- nextprogresschains = progresschains + threshold
567
- end
569
+ Distributed . @async begin
570
+ while true
571
+ i, res = take! (channel)
572
+ # i == 0 means the overall progress bar; i > 0 means the
573
+ # progress bar for chain i.
574
+ prog_obj = if i == 0
575
+ overall_progress_obj
576
+ else
577
+ progress_objs[i]
578
+ end
579
+ if res # true = a chain / sample finished
580
+ ProgressMeter . next! (prog_obj)
581
+ else # false = all chains / samples finished (or one failed)
582
+ ProgressMeter . finish! (prog_obj)
583
+ break
568
584
end
569
585
end
570
586
end
571
587
572
588
Distributed. @async begin
573
589
try
574
- function sample_chain (seed, initial_params, initial_state)
590
+ function sample_chain (i, seed, initial_params, initial_state)
575
591
# Seed a new random number generator with the pre-made seed.
576
592
Random. seed! (rng, seed)
577
593
@@ -584,21 +600,23 @@ function mcmcsample(
584
600
progress= false ,
585
601
initial_params= initial_params,
586
602
initial_state= initial_state,
603
+ _chain_idx= i,
604
+ _progress_channel= channel,
587
605
kwargs... ,
588
606
)
589
607
590
608
# Update the progress bar.
591
- progress && put! (channel, true )
609
+ progress && put! (channel, ( 0 , true ) )
592
610
593
611
# Return the new chain.
594
612
return chain
595
613
end
596
614
chains = Distributed. pmap (
597
- sample_chain, pool, seeds, _initial_params, _initial_state
615
+ sample_chain, pool, 1 : nchains, seeds, _initial_params, _initial_state
598
616
)
599
617
finally
600
618
# Stop updating the progress bar.
601
- progress && put! (channel, false )
619
+ progress && put! (channel, ( 0 , false ) )
602
620
end
603
621
end
604
622
end
0 commit comments