@@ -131,7 +131,7 @@ function mcmcsample(
131
131
model:: AbstractModel ,
132
132
sampler:: AbstractSampler ,
133
133
N:: Integer ;
134
- progress:: Union{Bool,UUIDs.UUID,Channel{Bool},Distributed.RemoteChannel{Channel{Bool}} } = PROGRESS[],
134
+ progress:: Union{Bool,<:AbstractProgressKwarg } = PROGRESS[],
135
135
progressname= " Sampling" ,
136
136
callback= nothing ,
137
137
num_warmup:: Int = 0 ,
@@ -152,6 +152,14 @@ function mcmcsample(
152
152
ArgumentError (" number of warm-up samples exceeds the total number of samples" )
153
153
)
154
154
155
+ # Initialise progress bar
156
+ if progress === true
157
+ progress = CreateNewProgressBar (progressname)
158
+ elseif progress === false
159
+ progress = NoLogging ()
160
+ end
161
+ init_progress (progress)
162
+
155
163
# Determine how many samples to drop from `num_warmup` and the
156
164
# main sampling process before we start saving samples.
157
165
discard_from_warmup = min (num_warmup, discard_initial)
@@ -161,10 +169,7 @@ function mcmcsample(
161
169
start = time ()
162
170
local state
163
171
164
- # Only create a new progress bar if progress is explicitly equal to true, i.e.
165
- # it's not a UUID (the progress bar already exists), a channel (there's no need
166
- # for a new progress bar), or false (no progress bar).
167
- @ifwithprogresslogger (progress == true ) name = progressname begin
172
+ try
168
173
# Determine threshold values for progress logging
169
174
# (one update per 0.5% of progress)
170
175
threshold = Ntotal ÷ 200
@@ -175,10 +180,10 @@ function mcmcsample(
175
180
# is constructed in the multi-chain method, i.e. if we don't do this
176
181
# the progress bar shows the time elapsed since _all_ sampling began,
177
182
# not since the current chain started.
178
- if progress isa UUIDs . UUID
183
+ if progress isa ExistingProgressBar
179
184
try
180
185
bartrees = Logging. current_logger (). loggers[1 ]. logger. bartrees
181
- bar = TerminalLoggers. findbar (bartrees, progress). data
186
+ bar = TerminalLoggers. findbar (bartrees, progress. uuid ). data
182
187
bar. tfirst = time ()
183
188
catch
184
189
end
@@ -202,13 +207,9 @@ function mcmcsample(
202
207
# Start the progress bar.
203
208
itotal = 1
204
209
if itotal >= next_update
205
- @log_progress_dispatch progress progressname itotal / Ntotal
210
+ update_progress ( progress, itotal / Ntotal, true )
206
211
next_update = itotal + threshold
207
212
end
208
- if progress isa Channel{Bool} ||
209
- progress isa Distributed. RemoteChannel{Channel{Bool}}
210
- put! (progress, true )
211
- end
212
213
213
214
# Discard initial samples.
214
215
for j in 1 : discard_initial
@@ -222,7 +223,7 @@ function mcmcsample(
222
223
# Update the progress bar.
223
224
itotal += 1
224
225
if itotal >= next_update
225
- @log_progress_dispatch progress progressname itotal / Ntotal
226
+ update_progress ( progress, itotal / Ntotal, false )
226
227
next_update = itotal + threshold
227
228
end
228
229
end
@@ -248,7 +249,7 @@ function mcmcsample(
248
249
# Update progress bar.
249
250
itotal += 1
250
251
if itotal >= next_update
251
- @log_progress_dispatch progress progressname itotal / Ntotal
252
+ update_progress ( progress, itotal / Ntotal, false )
252
253
next_update = itotal + threshold
253
254
end
254
255
end
@@ -270,14 +271,17 @@ function mcmcsample(
270
271
# Update the progress bar.
271
272
itotal += 1
272
273
if itotal >= next_update
273
- @log_progress_dispatch progress progressname itotal / Ntotal
274
+ update_progress ( progress, itotal / Ntotal, true )
274
275
next_update = itotal + threshold
275
276
end
276
- if progress isa Channel{Bool} ||
277
- progress isa Distributed. RemoteChannel{Channel{Bool}}
278
- put! (progress, true )
279
- end
280
277
end
278
+ catch e
279
+ # If an error occurs, we still want to finish the progress bar.
280
+ finish_progress (progress)
281
+ rethrow (e)
282
+ finally
283
+ # Finish the progress bar.
284
+ finish_progress (progress)
281
285
end
282
286
283
287
# Get the sample stop time.
@@ -473,7 +477,7 @@ function mcmcsample(
473
477
if progress == :perchain
474
478
# This is the 'overall' progress bar. We create a channel for each
475
479
# chain to report back to when it finishes sampling.
476
- progress_channel = Channel {Bool} ()
480
+ progress_channel = Channel {Bool} (nchunks )
477
481
# These are the per-chain progress bars. We generate `nchains`
478
482
# independent UUIDs for each progress bar
479
483
uuids = [UUIDs. uuid4 () for _ in 1 : nchains]
0 commit comments