Skip to content

Commit f83d087

Browse files
committed
Don't duplicate macro
1 parent 1539c2a commit f83d087

File tree

2 files changed

+11
-26
lines changed

2 files changed

+11
-26
lines changed

src/logging.jl

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# avoid creating a progress bar with @withprogress if progress logging is disabled
22
# and add a custom progress logger if the current logger does not seem to be able to handle
33
# progress logs
4-
macro single_ifwithprogresslogger(progress, exprs...)
4+
macro ifwithprogresslogger(cond, exprs...)
55
return esc(
66
quote
7-
if $progress == true
7+
if $cond
88
# If progress == true, then we want to create a new logger. Note that
99
# progress might not be a Bool.
1010
if $hasprogresslevel($Logging.current_logger())
@@ -23,25 +23,6 @@ macro single_ifwithprogresslogger(progress, exprs...)
2323
)
2424
end
2525

26-
# TODO(penelopeysm): figure out how to not have so much code duplication
27-
macro multi_ifwithprogresslogger(progress, exprs...)
28-
return esc(
29-
quote
30-
if $progress != :none
31-
if $hasprogresslevel($Logging.current_logger())
32-
$ProgressLogging.@withprogress $(exprs...)
33-
else
34-
$with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
35-
$ProgressLogging.@withprogress $(exprs...)
36-
end
37-
end
38-
else
39-
$(exprs[end])
40-
end
41-
end,
42-
)
43-
end
44-
4526
macro log_progress_dispatch(progress, progressname, progress_frac)
4627
return esc(
4728
quote
@@ -50,7 +31,8 @@ macro log_progress_dispatch(progress, progressname, progress_frac)
5031
elseif $progress isa $UUIDs.UUID
5132
$ProgressLogging.@logprogress $progressname $progress_frac _id = $progress
5233
else
53-
# progress == false, or progress isa Channel, which is handled manually
34+
# progress == false, or progress isa Channel, both of which are
35+
# handled manually
5436
nothing
5537
end
5638
end,

src/sample.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ function mcmcsample(
161161
start = time()
162162
local state
163163

164-
@single_ifwithprogresslogger progress name = progressname begin
164+
# Only create a new progress bar if progress is explicitly equal to true, i.e.
165+
# it's not a UUID (the progress bar already exists), a channel (there's no need
166+
# for a new progress bar), or false (no progress bar).
167+
@ifwithprogresslogger (progress == true) name = progressname begin
165168
# Determine threshold values for progress logging
166169
# (one update per 0.5% of progress)
167170
threshold = Ntotal ÷ 200
@@ -319,7 +322,7 @@ function mcmcsample(
319322
start = time()
320323
local state
321324

322-
@single_ifwithprogresslogger progress name = progressname begin
325+
@ifwithprogresslogger (progress == true) name = progressname begin
323326
# Obtain the initial sample and state.
324327
sample, state = if num_warmup > 0
325328
if initial_state === nothing
@@ -456,7 +459,7 @@ function mcmcsample(
456459
# Set up a chains vector.
457460
chains = Vector{Any}(undef, nchains)
458461

459-
@multi_ifwithprogresslogger progress name = progressname begin
462+
@ifwithprogresslogger (progress != :none) name = progressname begin
460463
if progress == :perchain
461464
# This is the 'overall' progress bar. We create a channel for each
462465
# chain to report back to when it finishes sampling.
@@ -633,7 +636,7 @@ function mcmcsample(
633636
pool = Distributed.CachingPool(Distributed.workers())
634637

635638
local chains
636-
@single_ifwithprogresslogger progress name = progressname begin
639+
@ifwithprogresslogger (progress == true) name = progressname begin
637640
# Create a channel for progress logging.
638641
if progress
639642
channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers()))

0 commit comments

Comments
 (0)