1
- using UUIDs: uuid4
2
-
3
1
# Default implementations of `sample`.
4
2
const PROGRESS = Ref (true )
5
3
@@ -113,7 +111,7 @@ function mcmcsample(
113
111
model:: AbstractModel ,
114
112
sampler:: AbstractSampler ,
115
113
N:: Integer ;
116
- progress= PROGRESS[],
114
+ progress:: Union{Bool,UUIDs.UUID,Channel{Bool}} = PROGRESS[],
117
115
progressname= " Sampling" ,
118
116
callback= nothing ,
119
117
num_warmup:: Int = 0 ,
@@ -143,21 +141,21 @@ function mcmcsample(
143
141
start = time ()
144
142
local state
145
143
146
- @ifwithprogresslogger progress name = progressname begin
144
+ @single_ifwithprogresslogger progress name = progressname begin
147
145
# Determine threshold values for progress logging
148
146
# (one update per 0.5% of progress)
149
- if ! (progress == false )
150
- threshold = Ntotal ÷ 200
151
- next_update = threshold
152
- end
153
-
154
- # Ugly hacky code to reset the start timer if called from a multi-chain
155
- # sampling process
156
- # TODO : How to make this better?
157
- if progress isa ProgressLogging . Progress
147
+ threshold = Ntotal ÷ 200
148
+ next_update = threshold
149
+
150
+ # Slightly hacky code to reset the start timer if called from a
151
+ # multi-chain sampling process. We need this because the progress bar
152
+ # is constructed in the multi-chain method, i.e. if we don't do this
153
+ # the progress bar shows the time elapsed since _all_ sampling began,
154
+ # not since the current chain started.
155
+ if progress isa UUIDs . UUID
158
156
try
159
157
bartrees = Logging. current_logger (). loggers[1 ]. logger. bartrees
160
- bar = TerminalLoggers. findbar (bartrees, progress. id ). data
158
+ bar = TerminalLoggers. findbar (bartrees, progress). data
161
159
bar. tfirst = time ()
162
160
catch
163
161
end
@@ -178,17 +176,13 @@ function mcmcsample(
178
176
end
179
177
end
180
178
181
- # Update the progress bar.
179
+ # Start the progress bar.
182
180
itotal = 1
183
- if ! (progress == false ) && itotal >= next_update
184
- if progress == true
185
- ProgressLogging. @logprogress itotal / Ntotal
186
- else
187
- ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
188
- progress. id
189
- end
181
+ if itotal >= next_update
182
+ @log_progress_dispatch progress progressname itotal / Ntotal
190
183
next_update = itotal + threshold
191
184
end
185
+ progress isa Channel{Bool} && put! (progress, true )
192
186
193
187
# Discard initial samples.
194
188
for j in 1 : discard_initial
@@ -200,13 +194,9 @@ function mcmcsample(
200
194
end
201
195
202
196
# Update the progress bar.
203
- if ! (progress == false ) && (itotal += 1 ) >= next_update
204
- if progress == true
205
- ProgressLogging. @logprogress itotal / Ntotal
206
- else
207
- ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
208
- progress. id
209
- end
197
+ itotal += 1
198
+ if itotal >= next_update
199
+ @log_progress_dispatch progress progressname itotal / Ntotal
210
200
next_update = itotal + threshold
211
201
end
212
202
end
@@ -230,13 +220,9 @@ function mcmcsample(
230
220
end
231
221
232
222
# Update progress bar.
233
- if ! (progress == false ) && (itotal += 1 ) >= next_update
234
- if progress == true
235
- ProgressLogging. @logprogress itotal / Ntotal
236
- else
237
- ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
238
- progress. id
239
- end
223
+ itotal += 1
224
+ if itotal >= next_update
225
+ @log_progress_dispatch progress progressname itotal / Ntotal
240
226
next_update = itotal + threshold
241
227
end
242
228
end
@@ -256,15 +242,12 @@ function mcmcsample(
256
242
samples = save!! (samples, sample, i, model, sampler, N; kwargs... )
257
243
258
244
# Update the progress bar.
259
- if ! (progress == false ) && (itotal += 1 ) >= next_update
260
- if progress == true
261
- ProgressLogging. @logprogress itotal / Ntotal
262
- else
263
- ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
264
- progress. id
265
- end
245
+ itotal += 1
246
+ if itotal >= next_update
247
+ @log_progress_dispatch progress progressname itotal / Ntotal
266
248
next_update = itotal + threshold
267
249
end
250
+ progress isa Channel{Bool} && put! (progress, true )
268
251
end
269
252
end
270
253
@@ -316,7 +299,7 @@ function mcmcsample(
316
299
start = time ()
317
300
local state
318
301
319
- @ifwithprogresslogger progress name = progressname begin
302
+ @single_ifwithprogresslogger progress name = progressname begin
320
303
# Obtain the initial sample and state.
321
304
sample, state = if num_warmup > 0
322
305
if initial_state === nothing
@@ -423,6 +406,14 @@ function mcmcsample(
423
406
@warn " Number of chains ($nchains ) is greater than number of samples per chain ($N )"
424
407
end
425
408
409
+ # Determine default progress bar style.
410
+ if progress == true
411
+ progress = nchains > 10 ? :overall : :perchain
412
+ elseif progress == false
413
+ progress = :none
414
+ end
415
+ # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
416
+
426
417
# Copy the random number generator, model, and sample for each thread
427
418
nchunks = min (nchains, Threads. nthreads ())
428
419
interval = 1 : nchunks
@@ -445,36 +436,44 @@ function mcmcsample(
445
436
# Set up a chains vector.
446
437
chains = Vector {Any} (undef, nchains)
447
438
448
- @ifwithprogresslogger progress name = progressname begin
449
- # Create a channel for progress logging.
450
- if progress
451
- channel = Channel {Bool} (length (interval))
452
- end
453
- # Generate nchains independent UUIDs for each progress bar
454
- uuids = [uuid4 () for _ in 1 : nchains]
455
- # Start the progress bars (but in reverse order, because
456
- # ProgressLogging prints from the bottom up, and we want chain 1 to
457
- # show up at the top)
458
- for (i, uuid) in enumerate (reverse (uuids))
459
- ProgressLogging. @logprogress name = " Chain $(nchains- i+ 1 ) /$nchains " nothing _id =
460
- uuid
439
+ @multi_ifwithprogresslogger progress name = progressname begin
440
+ if progress == :perchain
441
+ # This is the 'overall' progress bar. We create a channel for each
442
+ # chain to report back to when it finishes sampling.
443
+ progress_channel = Channel {Bool} ()
444
+ # These are the per-chain progress bars. We generate `nchains`
445
+ # independent UUIDs for each progress bar
446
+ uuids = [UUIDs. uuid4 () for _ in 1 : nchains]
447
+ progress_names = [" Chain $i /$nchains " for i in 1 : nchains]
448
+ # Start the per-chain progress bars (but in reverse order, because
449
+ # ProgressLogging prints from the bottom up, and we want chain 1 to
450
+ # show up at the top)
451
+ for (progress_name, uuid) in reverse (collect (zip (progress_names, uuids)))
452
+ ProgressLogging. @logprogress name = progress_name nothing _id = uuid
453
+ end
454
+ elseif progress == :overall
455
+ # Just a single progress bar for the entire sampling, but instead
456
+ # of tracking each chain as it comes in, we track each sample as it
457
+ # comes in. This allows us to have more granular progress updates.
458
+ progress_channel = Channel {Bool} ()
461
459
end
462
460
463
461
Distributed. @sync begin
464
- if progress
465
- # Update the progress bar.
462
+ if progress != :none
463
+ # This task updates the progress bar
466
464
Distributed. @async begin
467
465
# Determine threshold values for progress logging
468
466
# (one update per 0.5% of progress)
469
- threshold = nchains ÷ 200
470
- nextprogresschains = threshold
471
-
472
- progresschains = 0
473
- while take! (channel)
474
- progresschains += 1
475
- if progresschains >= nextprogresschains
476
- ProgressLogging. @logprogress progresschains / nchains
477
- nextprogresschains = progresschains + threshold
467
+ Ntotal = progress == :overall ? nchains * N : nchains
468
+ threshold = Ntotal ÷ 200
469
+ next_update = threshold
470
+
471
+ itotal = 0
472
+ while take! (progress_channel)
473
+ itotal += 1
474
+ if itotal >= next_update
475
+ ProgressLogging. @logprogress itotal / Ntotal
476
+ next_update = itotal + threshold
478
477
end
479
478
end
480
479
end
@@ -498,15 +497,23 @@ function mcmcsample(
498
497
# Seed the chunk-specific random number generator with the pre-made seed.
499
498
Random. seed! (_rng, seeds[chainidx])
500
499
501
- # Sample a chain and save it to the vector.
502
- child_progressname = " Chain $chainidx /$nchains "
503
- child_progress = if progress == false
504
- false
505
- else
506
- ProgressLogging. Progress (
507
- uuids[chainidx]; name= child_progressname
508
- )
500
+ # Determine how to monitor progress for the child chains.
501
+ child_progress, child_progressname = if progress == :none
502
+ # No need to create a progress bar
503
+ false , " "
504
+ elseif progress == :overall
505
+ # No need to create a new progress bar, but we need to
506
+ # pass the channel to the child so that it can log when
507
+ # it has finished obtaining each sample.
508
+ progress_channel, " "
509
+ elseif progress == :perchain
510
+ # We need to specify both the ID of the progress bar for
511
+ # the child to update, and we also specify the name to use
512
+ # for the progress bar.
513
+ uuids[chainidx], progress_names[chainidx]
509
514
end
515
+
516
+ # Sample a chain and save it to the vector.
510
517
chains[chainidx] = StatsBase. sample (
511
518
_rng,
512
519
_model,
@@ -527,19 +534,36 @@ function mcmcsample(
527
534
kwargs... ,
528
535
)
529
536
530
- ProgressLogging. @logprogress name = child_progressname " done" _id = uuids[chainidx]
531
-
532
- # Update the progress bar.
533
- progress && put! (channel, true )
537
+ # Update the progress bars.
538
+ if progress == :perchain
539
+ # Tell the 'main' progress bar that this chain is done.
540
+ put! (progress_channel, true )
541
+ # Conclude the per-chain progress bar.
542
+ ProgressLogging. @logprogress progress_names[chainidx] " done" _id = uuids[chainidx]
543
+ end
544
+ # Note that if progress == :overall, we don't need to do anything
545
+ # because progress on that bar is triggered by
546
+ # samples being obtained rather than chains being
547
+ # completed.
534
548
end
535
549
end
536
550
finally
537
- # Stop updating the progress bars (either if sampling is done, or if
538
- # an error occurs).
539
- progress && put! (channel, false )
540
- for (i, uuid) in enumerate (reverse (uuids))
541
- ProgressLogging. @logprogress name = " Chain $(nchains- i+ 1 ) /$nchains " " done" _id =
542
- uuid
551
+ if progress == :perchain
552
+ # Stop updating the main progress bar (either if sampling
553
+ # is done, or if an error occurs).
554
+ put! (progress_channel, false )
555
+ # Additionally stop the per-chain progress bars (but in
556
+ # reverse order, because ProgressLogging prints from
557
+ # the bottom up, and we want chain 1 to show up at the
558
+ # top)
559
+ for (progress_name, uuid) in
560
+ reverse (collect (zip (progress_names, uuids)))
561
+ ProgressLogging. @logprogress progress_name " done" _id = uuid
562
+ end
563
+ elseif progress == :overall
564
+ # Stop updating the main progress bar (either if sampling
565
+ # is done, or if an error occurs).
566
+ put! (progress_channel, false )
543
567
end
544
568
end
545
569
end
@@ -589,7 +613,7 @@ function mcmcsample(
589
613
pool = Distributed. CachingPool (Distributed. workers ())
590
614
591
615
local chains
592
- @ifwithprogresslogger progress name = progressname begin
616
+ @single_ifwithprogresslogger progress name = progressname begin
593
617
# Create a channel for progress logging.
594
618
if progress
595
619
channel = Distributed. RemoteChannel (() -> Channel {Bool} (Distributed. nworkers ()))
0 commit comments