@@ -153,7 +153,7 @@ function mcmcsample(
153
153
# Determine threshold values for progress logging (by default, one
154
154
# update per 0.5% of progress, unless this has been passed in
155
155
# explicitly)
156
- n_updates = progress isa ChannelProgress ? progress . n_updates : 200
156
+ n_updates = get_n_updates ( progress)
157
157
threshold = Ntotal / n_updates
158
158
next_update = threshold
159
159
@@ -445,30 +445,8 @@ function mcmcsample(
445
445
chains = Vector {Any} (undef, nchains)
446
446
447
447
@maybewithricherlogger begin
448
- if progress == :perchain
449
- # Create a channel for each chain to report back to when it
450
- # finishes sampling.
451
- progress_channel = Channel {Bool} (nchunks)
452
- # This is the 'overall' progress bar which tracks the number of
453
- # chains that have completed. Note that this progress bar is backed
454
- # by a channel, but it is not itself a ChannelProgress (because
455
- # ChannelProgress doesn't come with a progress bar).
456
- overall_progress_bar = CreateNewProgressBar (progressname)
457
- init_progress! (overall_progress_bar)
458
- # These are the per-chain progress bars. We generate `nchains`
459
- # independent UUIDs for each progress bar
460
- child_progresses = [
461
- ExistingProgressBar (" Chain $i /$nchains " , UUIDs. uuid4 ()) for i in 1 : nchains
462
- ]
463
- # Start the per-chain progress bars (but in reverse order, because
464
- # ProgressLogging prints from the bottom up, and we want chain 1 to
465
- # show up at the top)
466
- for child_progress in reverse (child_progresses)
467
- init_progress! (child_progress)
468
- end
469
- updates_per_chain = nothing
470
- elseif progress == :overall
471
- # Just a single progress bar for the entire sampling, but instead
448
+ if progress == :perchain || progress == :overall
449
+ # Create a single progress bar for the entire sampling, but instead
472
450
# of tracking each chain as it comes in, we track each sample as it
473
451
# comes in. This allows us to have more granular progress updates.
474
452
progress_channel = Channel {Bool} (nchains)
@@ -481,14 +459,28 @@ function mcmcsample(
481
459
# twice the amount needed per chain, which doesn't cause a real
482
460
# performance hit.
483
461
updates_per_chain = max (1 , 400 ÷ nchains)
462
+ init_progress! (overall_progress_bar)
463
+ end
464
+ if progress == :perchain
465
+ # Additionally, we create per-chain progress bars. We generate `nchains`
466
+ # independent UUIDs for each progress bar
467
+ child_progresses = [
468
+ ExistingProgressBar (" Chain $i /$nchains " , UUIDs. uuid4 ()) for i in 1 : nchains
469
+ ]
470
+ # Start the per-chain progress bars (but in reverse order, because
471
+ # ProgressLogging prints from the bottom up, and we want chain 1 to
472
+ # show up at the top)
473
+ for child_progress in reverse (child_progresses)
474
+ init_progress! (child_progress)
475
+ end
484
476
end
485
477
486
478
Distributed. @sync begin
487
- if progress != :none
488
- # This task updates the progress bar
479
+ if progress == :overall || progress == :perchain
480
+ # This task updates the overall progress bar
489
481
Distributed. @async begin
490
482
# Total number of updates (across all chains)
491
- Ntotal = progress == :overall ? nchains * updates_per_chain : nchains
483
+ Ntotal = nchains * updates_per_chain
492
484
# Determine threshold values for progress logging
493
485
# (one update per 0.5% of progress)
494
486
threshold = Ntotal / 200
@@ -530,7 +522,10 @@ function mcmcsample(
530
522
elseif progress == :overall
531
523
ChannelProgress (progress_channel, updates_per_chain)
532
524
elseif progress == :perchain
533
- child_progresses[chainidx] # <- isa ExistingProgressBar
525
+ chan_prog = ChannelProgress (progress_channel, updates_per_chain)
526
+ ChannelPlusExistingProgress (
527
+ chan_prog, child_progresses[chainidx]
528
+ )
534
529
end
535
530
536
531
# Sample a chain and save it to the vector.
@@ -552,33 +547,19 @@ function mcmcsample(
552
547
end ,
553
548
kwargs... ,
554
549
)
555
-
556
- # Update the progress bars.
557
- if progress == :perchain
558
- # Tell the 'main' progress bar that this chain is done.
559
- put! (progress_channel, true )
560
- # Conclude the per-chain progress bar.
561
- finish_progress! (child_progresses[chainidx])
562
- end
563
- # Note that if progress == :overall, we don't need to do anything
564
- # because progress on that bar is triggered by
565
- # samples being obtained rather than chains being
566
- # completed.
567
550
end
568
551
end
569
552
finally
570
- if progress == :perchain
553
+ if progress == :overall || progress == : perchain
571
554
# Stop updating the main progress bar (either if sampling
572
555
# is done, or if an error occurs).
573
556
put! (progress_channel, false )
557
+ end
558
+ if progress == :perchain
574
559
# Additionally stop the per-chain progress bars
575
560
for child_progress in child_progresses
576
561
finish_progress! (child_progress)
577
562
end
578
- elseif progress == :overall
579
- # Stop updating the main progress bar (either if sampling
580
- # is done, or if an error occurs).
581
- put! (progress_channel, false )
582
563
end
583
564
end
584
565
end
0 commit comments