Skip to content

Commit 8d66ea0

Browse files
committed
Fix implementation
1 parent ab3cf26 commit 8d66ea0

File tree

3 files changed

+151
-88
lines changed

3 files changed

+151
-88
lines changed

src/AbstractMCMC.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using FillArrays: FillArrays
1313
using Distributed: Distributed
1414
using Logging: Logging
1515
using Random: Random
16+
using UUIDs: UUIDs
1617

1718
# Reexport sample
1819
using StatsBase: sample

src/logging.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# avoid creating a progress bar with @withprogress if progress logging is disabled
22
# and add a custom progress logger if the current logger does not seem to be able to handle
33
# progress logs
4-
macro ifwithprogresslogger(progress, exprs...)
4+
macro single_ifwithprogresslogger(progress, exprs...)
55
return esc(
66
quote
77
if $progress == true
8+
# If progress == true, then we want to create a new logger. Note that
9+
# progress might not be a Bool.
810
if $hasprogresslevel($Logging.current_logger())
911
$ProgressLogging.@withprogress $(exprs...)
1012
else
@@ -13,12 +15,48 @@ macro ifwithprogresslogger(progress, exprs...)
1315
end
1416
end
1517
else
18+
# otherwise, progress isa UUID, or a channel, or false, in
19+
# which case we don't want to create a new logger.
1620
$(exprs[end])
1721
end
1822
end,
1923
)
2024
end
2125

26+
# TODO(penelopeysm): figure out how to not have so much code duplication
27+
macro multi_ifwithprogresslogger(progress, exprs...)
28+
return esc(
29+
quote
30+
if $progress != :none
31+
if $hasprogresslevel($Logging.current_logger())
32+
$ProgressLogging.@withprogress $(exprs...)
33+
else
34+
$with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
35+
$ProgressLogging.@withprogress $(exprs...)
36+
end
37+
end
38+
else
39+
$(exprs[end])
40+
end
41+
end,
42+
)
43+
end
44+
45+
macro log_progress_dispatch(progress, progressname, progress_frac)
46+
return esc(
47+
quote
48+
if $progress == true
49+
$ProgressLogging.@logprogress $progress_frac
50+
elseif $progress isa $UUIDs.UUID
51+
$ProgressLogging.@logprogress $progressname $progress_frac _id = $progress
52+
else
53+
# progress == false, or progress isa Channel, which is handled manually
54+
nothing
55+
end
56+
end,
57+
)
58+
end
59+
2260
# improved checks?
2361
function hasprogresslevel(logger)
2462
return Logging.min_enabled_level(logger) ProgressLogging.ProgressLevel

src/sample.jl

Lines changed: 111 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using UUIDs: uuid4
2-
31
# Default implementations of `sample`.
42
const PROGRESS = Ref(true)
53

@@ -113,7 +111,7 @@ function mcmcsample(
113111
model::AbstractModel,
114112
sampler::AbstractSampler,
115113
N::Integer;
116-
progress=PROGRESS[],
114+
progress::Union{Bool,UUIDs.UUID,Channel{Bool}}=PROGRESS[],
117115
progressname="Sampling",
118116
callback=nothing,
119117
num_warmup::Int=0,
@@ -143,21 +141,21 @@ function mcmcsample(
143141
start = time()
144142
local state
145143

146-
@ifwithprogresslogger progress name = progressname begin
144+
@single_ifwithprogresslogger progress name = progressname begin
147145
# Determine threshold values for progress logging
148146
# (one update per 0.5% of progress)
149-
if !(progress == false)
150-
threshold = Ntotal ÷ 200
151-
next_update = threshold
152-
end
153-
154-
# Ugly hacky code to reset the start timer if called from a multi-chain
155-
# sampling process
156-
# TODO: How to make this better?
157-
if progress isa ProgressLogging.Progress
147+
threshold = Ntotal ÷ 200
148+
next_update = threshold
149+
150+
# Slightly hacky code to reset the start timer if called from a
151+
# multi-chain sampling process. We need this because the progress bar
152+
# is constructed in the multi-chain method, i.e. if we don't do this
153+
# the progress bar shows the time elapsed since _all_ sampling began,
154+
# not since the current chain started.
155+
if progress isa UUIDs.UUID
158156
try
159157
bartrees = Logging.current_logger().loggers[1].logger.bartrees
160-
bar = TerminalLoggers.findbar(bartrees, progress.id).data
158+
bar = TerminalLoggers.findbar(bartrees, progress).data
161159
bar.tfirst = time()
162160
catch
163161
end
@@ -178,17 +176,13 @@ function mcmcsample(
178176
end
179177
end
180178

181-
# Update the progress bar.
179+
# Start the progress bar.
182180
itotal = 1
183-
if !(progress == false) && itotal >= next_update
184-
if progress == true
185-
ProgressLogging.@logprogress itotal / Ntotal
186-
else
187-
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
188-
progress.id
189-
end
181+
if itotal >= next_update
182+
@log_progress_dispatch progress progressname itotal / Ntotal
190183
next_update = itotal + threshold
191184
end
185+
progress isa Channel{Bool} && put!(progress, true)
192186

193187
# Discard initial samples.
194188
for j in 1:discard_initial
@@ -200,13 +194,9 @@ function mcmcsample(
200194
end
201195

202196
# Update the progress bar.
203-
if !(progress == false) && (itotal += 1) >= next_update
204-
if progress == true
205-
ProgressLogging.@logprogress itotal / Ntotal
206-
else
207-
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
208-
progress.id
209-
end
197+
itotal += 1
198+
if itotal >= next_update
199+
@log_progress_dispatch progress progressname itotal / Ntotal
210200
next_update = itotal + threshold
211201
end
212202
end
@@ -230,13 +220,9 @@ function mcmcsample(
230220
end
231221

232222
# Update progress bar.
233-
if !(progress == false) && (itotal += 1) >= next_update
234-
if progress == true
235-
ProgressLogging.@logprogress itotal / Ntotal
236-
else
237-
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
238-
progress.id
239-
end
223+
itotal += 1
224+
if itotal >= next_update
225+
@log_progress_dispatch progress progressname itotal / Ntotal
240226
next_update = itotal + threshold
241227
end
242228
end
@@ -256,15 +242,12 @@ function mcmcsample(
256242
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
257243

258244
# Update the progress bar.
259-
if !(progress == false) && (itotal += 1) >= next_update
260-
if progress == true
261-
ProgressLogging.@logprogress itotal / Ntotal
262-
else
263-
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
264-
progress.id
265-
end
245+
itotal += 1
246+
if itotal >= next_update
247+
@log_progress_dispatch progress progressname itotal / Ntotal
266248
next_update = itotal + threshold
267249
end
250+
progress isa Channel{Bool} && put!(progress, true)
268251
end
269252
end
270253

@@ -316,7 +299,7 @@ function mcmcsample(
316299
start = time()
317300
local state
318301

319-
@ifwithprogresslogger progress name = progressname begin
302+
@single_ifwithprogresslogger progress name = progressname begin
320303
# Obtain the initial sample and state.
321304
sample, state = if num_warmup > 0
322305
if initial_state === nothing
@@ -423,6 +406,14 @@ function mcmcsample(
423406
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
424407
end
425408

409+
# Determine default progress bar style.
410+
if progress == true
411+
progress = nchains > 10 ? :overall : :perchain
412+
elseif progress == false
413+
progress = :none
414+
end
415+
# By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
416+
426417
# Copy the random number generator, model, and sample for each thread
427418
nchunks = min(nchains, Threads.nthreads())
428419
interval = 1:nchunks
@@ -445,36 +436,44 @@ function mcmcsample(
445436
# Set up a chains vector.
446437
chains = Vector{Any}(undef, nchains)
447438

448-
@ifwithprogresslogger progress name = progressname begin
449-
# Create a channel for progress logging.
450-
if progress
451-
channel = Channel{Bool}(length(interval))
452-
end
453-
# Generate nchains independent UUIDs for each progress bar
454-
uuids = [uuid4() for _ in 1:nchains]
455-
# Start the progress bars (but in reverse order, because
456-
# ProgressLogging prints from the bottom up, and we want chain 1 to
457-
# show up at the top)
458-
for (i, uuid) in enumerate(reverse(uuids))
459-
ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id =
460-
uuid
439+
@multi_ifwithprogresslogger progress name = progressname begin
440+
if progress == :perchain
441+
# This is the 'overall' progress bar. We create a channel for each
442+
# chain to report back to when it finishes sampling.
443+
progress_channel = Channel{Bool}()
444+
# These are the per-chain progress bars. We generate `nchains`
445+
# independent UUIDs for each progress bar
446+
uuids = [UUIDs.uuid4() for _ in 1:nchains]
447+
progress_names = ["Chain $i/$nchains" for i in 1:nchains]
448+
# Start the per-chain progress bars (but in reverse order, because
449+
# ProgressLogging prints from the bottom up, and we want chain 1 to
450+
# show up at the top)
451+
for (progress_name, uuid) in reverse(collect(zip(progress_names, uuids)))
452+
ProgressLogging.@logprogress name = progress_name nothing _id = uuid
453+
end
454+
elseif progress == :overall
455+
# Just a single progress bar for the entire sampling, but instead
456+
# of tracking each chain as it comes in, we track each sample as it
457+
# comes in. This allows us to have more granular progress updates.
458+
progress_channel = Channel{Bool}()
461459
end
462460

463461
Distributed.@sync begin
464-
if progress
465-
# Update the progress bar.
462+
if progress != :none
463+
# This task updates the progress bar
466464
Distributed.@async begin
467465
# Determine threshold values for progress logging
468466
# (one update per 0.5% of progress)
469-
threshold = nchains ÷ 200
470-
nextprogresschains = threshold
471-
472-
progresschains = 0
473-
while take!(channel)
474-
progresschains += 1
475-
if progresschains >= nextprogresschains
476-
ProgressLogging.@logprogress progresschains / nchains
477-
nextprogresschains = progresschains + threshold
467+
Ntotal = progress == :overall ? nchains * N : nchains
468+
threshold = Ntotal ÷ 200
469+
next_update = threshold
470+
471+
itotal = 0
472+
while take!(progress_channel)
473+
itotal += 1
474+
if itotal >= next_update
475+
ProgressLogging.@logprogress itotal / Ntotal
476+
next_update = itotal + threshold
478477
end
479478
end
480479
end
@@ -498,15 +497,23 @@ function mcmcsample(
498497
# Seed the chunk-specific random number generator with the pre-made seed.
499498
Random.seed!(_rng, seeds[chainidx])
500499

501-
# Sample a chain and save it to the vector.
502-
child_progressname = "Chain $chainidx/$nchains"
503-
child_progress = if progress == false
504-
false
505-
else
506-
ProgressLogging.Progress(
507-
uuids[chainidx]; name=child_progressname
508-
)
500+
# Determine how to monitor progress for the child chains.
501+
child_progress, child_progressname = if progress == :none
502+
# No need to create a progress bar
503+
false, ""
504+
elseif progress == :overall
505+
# No need to create a new progress bar, but we need to
506+
# pass the channel to the child so that it can log when
507+
# it has finished obtaining each sample.
508+
progress_channel, ""
509+
elseif progress == :perchain
510+
# We need to specify both the ID of the progress bar for
511+
# the child to update, and we also specify the name to use
512+
# for the progress bar.
513+
uuids[chainidx], progress_names[chainidx]
509514
end
515+
516+
# Sample a chain and save it to the vector.
510517
chains[chainidx] = StatsBase.sample(
511518
_rng,
512519
_model,
@@ -527,19 +534,36 @@ function mcmcsample(
527534
kwargs...,
528535
)
529536

530-
ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx]
531-
532-
# Update the progress bar.
533-
progress && put!(channel, true)
537+
# Update the progress bars.
538+
if progress == :perchain
539+
# Tell the 'main' progress bar that this chain is done.
540+
put!(progress_channel, true)
541+
# Conclude the per-chain progress bar.
542+
ProgressLogging.@logprogress progress_names[chainidx] "done" _id = uuids[chainidx]
543+
end
544+
# Note that if progress == :overall, we don't need to do anything
545+
# because progress on that bar is triggered by
546+
# samples being obtained rather than chains being
547+
# completed.
534548
end
535549
end
536550
finally
537-
# Stop updating the progress bars (either if sampling is done, or if
538-
# an error occurs).
539-
progress && put!(channel, false)
540-
for (i, uuid) in enumerate(reverse(uuids))
541-
ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id =
542-
uuid
551+
if progress == :perchain
552+
# Stop updating the main progress bar (either if sampling
553+
# is done, or if an error occurs).
554+
put!(progress_channel, false)
555+
# Additionally stop the per-chain progress bars (but in
556+
# reverse order, because ProgressLogging prints from
557+
# the bottom up, and we want chain 1 to show up at the
558+
# top)
559+
for (progress_name, uuid) in
560+
reverse(collect(zip(progress_names, uuids)))
561+
ProgressLogging.@logprogress progress_name "done" _id = uuid
562+
end
563+
elseif progress == :overall
564+
# Stop updating the main progress bar (either if sampling
565+
# is done, or if an error occurs).
566+
put!(progress_channel, false)
543567
end
544568
end
545569
end
@@ -589,7 +613,7 @@ function mcmcsample(
589613
pool = Distributed.CachingPool(Distributed.workers())
590614

591615
local chains
592-
@ifwithprogresslogger progress name = progressname begin
616+
@single_ifwithprogresslogger progress name = progressname begin
593617
# Create a channel for progress logging.
594618
if progress
595619
channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers()))

0 commit comments

Comments
 (0)