Skip to content

Commit 5c5b912

Browse files
committed
Make :perchain use the richer overall progress bar
1 parent 284741f commit 5c5b912

File tree

2 files changed

+62
-53
lines changed

2 files changed

+62
-53
lines changed

src/logging.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ struct NoLogging <: AbstractProgressKwarg end
3737
init_progress!(::NoLogging) = nothing
3838
update_progress!(::NoLogging, ::Any) = nothing
3939
finish_progress!(::NoLogging) = nothing
40+
get_n_updates(::NoLogging) = 200
4041

4142
"""
4243
ExistingProgressBar
43-
4444
Use an existing progress bar to log progress. This is used for tracking
4545
progress in a progress bar that has been previously generated elsewhere,
46-
specifically, when `sample(..., MCMCThreads(), ...; progress=:perchain)` is
47-
called. In this case we can use `@logprogress name progress_frac _id = uuid` to
48-
log progress.
46+
specifically, during multi-threaded sampling where per-chain progress
47+
bars are requested. In this case we can use `@logprogress name progress_frac
48+
_id = uuid` to log progress.
4949
"""
5050
struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg
5151
name::S
@@ -71,13 +71,13 @@ end
7171
function finish_progress!(p::ExistingProgressBar)
7272
ProgressLogging.@logprogress p.name "done" _id = p.uuid
7373
end
74+
get_n_updates(::ExistingProgressBar) = 200
7475

7576
"""
7677
ChannelProgress
7778
78-
Use a `Channel` to log progress. This is used for 'reporting' progress back
79-
to the main thread or worker when using `progress=:overall` with MCMCThreads or
80-
MCMCDistributed.
79+
Use a `Channel` to log progress. This is used for 'reporting' progress back to
80+
the main thread or worker when using multi-threaded or distributed sampling.
8181
8282
n_updates is the number of updates that each child thread is expected to report
8383
back to the main thread.
@@ -92,6 +92,34 @@ update_progress!(p::ChannelProgress, ::Any) = put!(p.channel, true)
9292
# Note: We don't want to `put!(p.channel, false)`, because that would stop the
9393
# channel from being used for further updates e.g. from other chains.
9494
finish_progress!(::ChannelProgress) = nothing
95+
get_n_updates(p::ChannelProgress) = p.n_updates
96+
97+
"""
98+
ChannelPlusExistingProgress
99+
100+
Send updates to two places: a `Channel` as well as an existing progress bar.
101+
"""
102+
struct ChannelPlusExistingProgress{C<:ChannelProgress,E<:ExistingProgressBar} <:
103+
AbstractProgressKwarg
104+
channel_progress::C
105+
existing_progress::E
106+
end
107+
function init_progress!(p::ChannelPlusExistingProgress)
108+
init_progress!(p.channel_progress)
109+
init_progress!(p.existing_progress)
110+
return nothing
111+
end
112+
function update_progress!(p::ChannelPlusExistingProgress, progress_frac)
113+
update_progress!(p.channel_progress, progress_frac)
114+
update_progress!(p.existing_progress, progress_frac)
115+
return nothing
116+
end
117+
function finish_progress!(p::ChannelPlusExistingProgress)
118+
finish_progress!(p.channel_progress)
119+
finish_progress!(p.existing_progress)
120+
return nothing
121+
end
122+
get_n_updates(p::ChannelPlusExistingProgress) = get_n_updates(p.channel_progress)
95123

96124
# Add a custom progress logger if the current logger does not seem to be able to handle
97125
# progress logs.

src/sample.jl

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ function mcmcsample(
153153
# Determine threshold values for progress logging (by default, one
154154
# update per 0.5% of progress, unless this has been passed in
155155
# explicitly)
156-
n_updates = progress isa ChannelProgress ? progress.n_updates : 200
156+
n_updates = get_n_updates(progress)
157157
threshold = Ntotal / n_updates
158158
next_update = threshold
159159

@@ -445,30 +445,8 @@ function mcmcsample(
445445
chains = Vector{Any}(undef, nchains)
446446

447447
@maybewithricherlogger begin
448-
if progress == :perchain
449-
# Create a channel for each chain to report back to when it
450-
# finishes sampling.
451-
progress_channel = Channel{Bool}(nchunks)
452-
# This is the 'overall' progress bar which tracks the number of
453-
# chains that have completed. Note that this progress bar is backed
454-
# by a channel, but it is not itself a ChannelProgress (because
455-
# ChannelProgress doesn't come with a progress bar).
456-
overall_progress_bar = CreateNewProgressBar(progressname)
457-
init_progress!(overall_progress_bar)
458-
# These are the per-chain progress bars. We generate `nchains`
459-
# independent UUIDs for each progress bar
460-
child_progresses = [
461-
ExistingProgressBar("Chain $i/$nchains", UUIDs.uuid4()) for i in 1:nchains
462-
]
463-
# Start the per-chain progress bars (but in reverse order, because
464-
# ProgressLogging prints from the bottom up, and we want chain 1 to
465-
# show up at the top)
466-
for child_progress in reverse(child_progresses)
467-
init_progress!(child_progress)
468-
end
469-
updates_per_chain = nothing
470-
elseif progress == :overall
471-
# Just a single progress bar for the entire sampling, but instead
448+
if progress == :perchain || progress == :overall
449+
# Create a single progress bar for the entire sampling, but instead
472450
# of tracking each chain as it comes in, we track each sample as it
473451
# comes in. This allows us to have more granular progress updates.
474452
progress_channel = Channel{Bool}(nchains)
@@ -481,14 +459,28 @@ function mcmcsample(
481459
# twice the amount needed per chain, which doesn't cause a real
482460
# performance hit.
483461
updates_per_chain = max(1, 400 ÷ nchains)
462+
init_progress!(overall_progress_bar)
463+
end
464+
if progress == :perchain
465+
# Additionally, we create per-chain progress bars. We generate `nchains`
466+
# independent UUIDs for each progress bar
467+
child_progresses = [
468+
ExistingProgressBar("Chain $i/$nchains", UUIDs.uuid4()) for i in 1:nchains
469+
]
470+
# Start the per-chain progress bars (but in reverse order, because
471+
# ProgressLogging prints from the bottom up, and we want chain 1 to
472+
# show up at the top)
473+
for child_progress in reverse(child_progresses)
474+
init_progress!(child_progress)
475+
end
484476
end
485477

486478
Distributed.@sync begin
487-
if progress != :none
488-
# This task updates the progress bar
479+
if progress == :overall || progress == :perchain
480+
# This task updates the overall progress bar
489481
Distributed.@async begin
490482
# Total number of updates (across all chains)
491-
Ntotal = progress == :overall ? nchains * updates_per_chain : nchains
483+
Ntotal = nchains * updates_per_chain
492484
# Determine threshold values for progress logging
493485
# (one update per 0.5% of progress)
494486
threshold = Ntotal / 200
@@ -530,7 +522,10 @@ function mcmcsample(
530522
elseif progress == :overall
531523
ChannelProgress(progress_channel, updates_per_chain)
532524
elseif progress == :perchain
533-
child_progresses[chainidx] # <- isa ExistingProgressBar
525+
chan_prog = ChannelProgress(progress_channel, updates_per_chain)
526+
ChannelPlusExistingProgress(
527+
chan_prog, child_progresses[chainidx]
528+
)
534529
end
535530

536531
# Sample a chain and save it to the vector.
@@ -552,33 +547,19 @@ function mcmcsample(
552547
end,
553548
kwargs...,
554549
)
555-
556-
# Update the progress bars.
557-
if progress == :perchain
558-
# Tell the 'main' progress bar that this chain is done.
559-
put!(progress_channel, true)
560-
# Conclude the per-chain progress bar.
561-
finish_progress!(child_progresses[chainidx])
562-
end
563-
# Note that if progress == :overall, we don't need to do anything
564-
# because progress on that bar is triggered by
565-
# samples being obtained rather than chains being
566-
# completed.
567550
end
568551
end
569552
finally
570-
if progress == :perchain
553+
if progress == :overall || progress == :perchain
571554
# Stop updating the main progress bar (either if sampling
572555
# is done, or if an error occurs).
573556
put!(progress_channel, false)
557+
end
558+
if progress == :perchain
574559
# Additionally stop the per-chain progress bars
575560
for child_progress in child_progresses
576561
finish_progress!(child_progress)
577562
end
578-
elseif progress == :overall
579-
# Stop updating the main progress bar (either if sampling
580-
# is done, or if an error occurs).
581-
put!(progress_channel, false)
582563
end
583564
end
584565
end

0 commit comments

Comments
 (0)