Skip to content

Commit cefafb0

Browse files
committed
Attempt to use proper types for logging
1 parent 5b2577f commit cefafb0

File tree

2 files changed

+110
-38
lines changed

2 files changed

+110
-38
lines changed

src/logging.jl

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,89 @@
1+
"""
2+
AbstractProgressKwarg
3+
4+
Abstract type representing the values that the `progress` keyword argument can
5+
internally take for single-chain sampling.
6+
"""
7+
abstract type AbstractProgressKwarg end
8+
9+
"""
10+
CreateNewProgressBar
11+
12+
Create a new logger for progress logging.
13+
"""
14+
struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg
15+
name::S
16+
uuid::UUIDs.UUID
17+
18+
function CreateNewProgressBar(name::AbstractString)
19+
return new{typeof{name}}(name, UUIDs.uuid4())
20+
end
21+
end
22+
function init_progress(p::CreateNewProgressBar)
23+
if hasprogresslevel(Logging.current_logger())
24+
ProgressLogging.@withprogress $(exprs...)
25+
else
26+
$with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
27+
$ProgressLogging.@withprogress $(exprs...)
28+
end
29+
end
30+
ProgressLogging.@logprogress p.name nothing _id = p.uuid
31+
end
32+
function update_progress(p::CreateNewProgressBar, progress_frac, ::Bool)
33+
ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
34+
end
35+
finish_progress(::CreateNewProgressBar) = ProgressLogging.@logprogress "done"
36+
37+
"""
38+
NoLogging
39+
40+
Do not log progress at all.
41+
"""
42+
struct NoLogging <: AbstractProgressKwarg end
43+
init_progress(::NoLogging) = nothing
44+
update_progress(::NoLogging, ::Any, ::Bool) = nothing
45+
finish_progress(::NoLogging) = nothing
46+
47+
"""
48+
ExistingProgressBar
49+
50+
Use an existing progress bar to log progress. This is used for tracking
51+
progress in a progress bar that has been previously generated elsewhere,
52+
specifically, when `sample(..., MCMCThreads(), ...; progress=:perchain)` is
53+
called. In this case we can use `@logprogress name progress_frac _id = uuid` to
54+
log progress.
55+
"""
56+
struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg
57+
name::S
58+
uuid::UUIDs.UUID
59+
end
60+
init_progress(::ExistingProgressBar) = nothing
61+
function update_progress(p::ExistingProgressBar, progress_frac, ::Bool)
62+
ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
63+
end
64+
function finish_progress(p::ExistingProgressBar)
65+
ProgressLogging.@logprogress p.name "done" _id = p.uuid
66+
end
67+
68+
"""
69+
ChannelProgress
70+
71+
Use a `Channel` to log progress. This is used for 'reporting' progress back
72+
to the main thread or worker when using `progress=:overall` with MCMCThreads or
73+
MCMCDistributed.
74+
"""
75+
struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}} <:
76+
AbstractProgressKwarg
77+
channel::T
78+
end
79+
init_progress(::ChannelProgress) = nothing
80+
function update_progress(p::ChannelProgress, ::Any, update_channel::Bool)
81+
return update_channel && put!(p.channel, true)
82+
end
83+
# Note: We don't want to `put!(p.channel, false)`, because that would stop the
84+
# channel from being used for further updates e.g. from other chains.
85+
finish_progress(::ChannelProgress) = nothing
86+
187
# avoid creating a progress bar with @withprogress if progress logging is disabled
288
# and add a custom progress logger if the current logger does not seem to be able to handle
389
# progress logs
@@ -23,24 +109,6 @@ macro ifwithprogresslogger(cond, exprs...)
23109
)
24110
end
25111

26-
macro log_progress_dispatch(progress, progressname, progress_frac)
27-
return esc(
28-
quote
29-
if $progress == true
30-
# Use global logger
31-
$ProgressLogging.@logprogress $progress_frac
32-
elseif $progress isa $UUIDs.UUID
33-
# Use the logger with this specific UUID
34-
$ProgressLogging.@logprogress $progressname $progress_frac _id = $progress
35-
else
36-
# progress == false, or progress isa Channel, both of which are
37-
# handled manually
38-
nothing
39-
end
40-
end,
41-
)
42-
end
43-
44112
# improved checks?
45113
function hasprogresslevel(logger)
46114
return Logging.min_enabled_level(logger) ProgressLogging.ProgressLevel

src/sample.jl

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function mcmcsample(
131131
model::AbstractModel,
132132
sampler::AbstractSampler,
133133
N::Integer;
134-
progress::Union{Bool,UUIDs.UUID,Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}=PROGRESS[],
134+
progress::Union{Bool,<:AbstractProgressKwarg}=PROGRESS[],
135135
progressname="Sampling",
136136
callback=nothing,
137137
num_warmup::Int=0,
@@ -152,6 +152,14 @@ function mcmcsample(
152152
ArgumentError("number of warm-up samples exceeds the total number of samples")
153153
)
154154

155+
# Initialise progress bar
156+
if progress === true
157+
progress = CreateNewProgressBar(progressname)
158+
elseif progress === false
159+
progress = NoLogging()
160+
end
161+
init_progress(progress)
162+
155163
# Determine how many samples to drop from `num_warmup` and the
156164
# main sampling process before we start saving samples.
157165
discard_from_warmup = min(num_warmup, discard_initial)
@@ -161,10 +169,7 @@ function mcmcsample(
161169
start = time()
162170
local state
163171

164-
# Only create a new progress bar if progress is explicitly equal to true, i.e.
165-
# it's not a UUID (the progress bar already exists), a channel (there's no need
166-
# for a new progress bar), or false (no progress bar).
167-
@ifwithprogresslogger (progress == true) name = progressname begin
172+
try
168173
# Determine threshold values for progress logging
169174
# (one update per 0.5% of progress)
170175
threshold = Ntotal ÷ 200
@@ -175,10 +180,10 @@ function mcmcsample(
175180
# is constructed in the multi-chain method, i.e. if we don't do this
176181
# the progress bar shows the time elapsed since _all_ sampling began,
177182
# not since the current chain started.
178-
if progress isa UUIDs.UUID
183+
if progress isa ExistingProgressBar
179184
try
180185
bartrees = Logging.current_logger().loggers[1].logger.bartrees
181-
bar = TerminalLoggers.findbar(bartrees, progress).data
186+
bar = TerminalLoggers.findbar(bartrees, progress.uuid).data
182187
bar.tfirst = time()
183188
catch
184189
end
@@ -202,13 +207,9 @@ function mcmcsample(
202207
# Start the progress bar.
203208
itotal = 1
204209
if itotal >= next_update
205-
@log_progress_dispatch progress progressname itotal / Ntotal
210+
update_progress(progress, itotal / Ntotal, true)
206211
next_update = itotal + threshold
207212
end
208-
if progress isa Channel{Bool} ||
209-
progress isa Distributed.RemoteChannel{Channel{Bool}}
210-
put!(progress, true)
211-
end
212213

213214
# Discard initial samples.
214215
for j in 1:discard_initial
@@ -222,7 +223,7 @@ function mcmcsample(
222223
# Update the progress bar.
223224
itotal += 1
224225
if itotal >= next_update
225-
@log_progress_dispatch progress progressname itotal / Ntotal
226+
update_progress(progress, itotal / Ntotal, false)
226227
next_update = itotal + threshold
227228
end
228229
end
@@ -248,7 +249,7 @@ function mcmcsample(
248249
# Update progress bar.
249250
itotal += 1
250251
if itotal >= next_update
251-
@log_progress_dispatch progress progressname itotal / Ntotal
252+
update_progress(progress, itotal / Ntotal, false)
252253
next_update = itotal + threshold
253254
end
254255
end
@@ -270,14 +271,17 @@ function mcmcsample(
270271
# Update the progress bar.
271272
itotal += 1
272273
if itotal >= next_update
273-
@log_progress_dispatch progress progressname itotal / Ntotal
274+
update_progress(progress, itotal / Ntotal, true)
274275
next_update = itotal + threshold
275276
end
276-
if progress isa Channel{Bool} ||
277-
progress isa Distributed.RemoteChannel{Channel{Bool}}
278-
put!(progress, true)
279-
end
280277
end
278+
catch e
279+
# If an error occurs, we still want to finish the progress bar.
280+
finish_progress(progress)
281+
rethrow(e)
282+
finally
283+
# Finish the progress bar.
284+
finish_progress(progress)
281285
end
282286

283287
# Get the sample stop time.
@@ -473,7 +477,7 @@ function mcmcsample(
473477
if progress == :perchain
474478
# This is the 'overall' progress bar. We create a channel for each
475479
# chain to report back to when it finishes sampling.
476-
progress_channel = Channel{Bool}()
480+
progress_channel = Channel{Bool}(nchunks)
477481
# These are the per-chain progress bars. We generate `nchains`
478482
# independent UUIDs for each progress bar
479483
uuids = [UUIDs.uuid4() for _ in 1:nchains]

0 commit comments

Comments
 (0)