1
+ using UUIDs: uuid4
2
+
1
3
# Default implementations of `sample`.
2
4
const PROGRESS = Ref (true )
3
5
@@ -144,11 +146,22 @@ function mcmcsample(
144
146
@ifwithprogresslogger progress name = progressname begin
145
147
# Determine threshold values for progress logging
146
148
# (one update per 0.5% of progress)
147
- if (progress == true || progress === nothing )
149
+ if ! (progress == false )
148
150
threshold = Ntotal ÷ 200
149
151
next_update = threshold
150
152
end
151
153
154
+ # Ugly hacky code to reset the start timer if called from a multi-chain
155
+ # sampling process
156
+ if progress isa ProgressLogging. Progress
157
+ try
158
+ bartrees = Logging. current_logger (). loggers[1 ]. logger. bartrees
159
+ bar = TerminalLoggers. findbar (bartrees, progress. id). data
160
+ bar. tfirst = time ()
161
+ catch
162
+ end
163
+ end
164
+
152
165
# Obtain the initial sample and state.
153
166
sample, state = if num_warmup > 0
154
167
if initial_state === nothing
@@ -170,7 +183,8 @@ function mcmcsample(
170
183
if progress == true
171
184
ProgressLogging. @logprogress itotal / Ntotal
172
185
else
173
- ProgressLogging. @logprogress itotal / Ntotal _id = " hello"
186
+ ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
187
+ progress. id
174
188
end
175
189
next_update = itotal + threshold
176
190
end
@@ -189,7 +203,8 @@ function mcmcsample(
189
203
if progress == true
190
204
ProgressLogging. @logprogress itotal / Ntotal
191
205
else
192
- ProgressLogging. @logprogress itotal / Ntotal _id = " hello"
206
+ ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
207
+ progress. id
193
208
end
194
209
next_update = itotal + threshold
195
210
end
@@ -218,7 +233,8 @@ function mcmcsample(
218
233
if progress == true
219
234
ProgressLogging. @logprogress itotal / Ntotal
220
235
else
221
- ProgressLogging. @logprogress itotal / Ntotal _id = " hello"
236
+ ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
237
+ progress. id
222
238
end
223
239
next_update = itotal + threshold
224
240
end
@@ -243,7 +259,8 @@ function mcmcsample(
243
259
if progress == true
244
260
ProgressLogging. @logprogress itotal / Ntotal
245
261
else
246
- ProgressLogging. @logprogress itotal / Ntotal _id = " hello"
262
+ ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
263
+ progress. id
247
264
end
248
265
next_update = itotal + threshold
249
266
end
@@ -432,6 +449,18 @@ function mcmcsample(
432
449
if progress
433
450
channel = Channel {Bool} (length (interval))
434
451
end
452
+ # Generate nchains independent UUIDs for each progress bar
453
+ uuids = [uuid4 () for _ in 1 : nchains]
454
+ # Start the progress bars (but in reverse order, because
455
+ # ProgressLogging prints from the bottom up, and we want chain 1 to
456
+ # show up at the top)
457
+ # TODO : This has an unintended effect that the 'time' field in the
458
+ # progress bar shows the total time since the beginning of sampling,
459
+ # even if the specific chain doesn't start sampling until later on.
460
+ for (i, uuid) in enumerate (reverse (uuids))
461
+ ProgressLogging. @logprogress name = " Chain $(nchains- i+ 1 ) /$nchains " nothing _id =
462
+ uuid
463
+ end
435
464
436
465
Distributed. @sync begin
437
466
if progress
@@ -472,17 +501,21 @@ function mcmcsample(
472
501
Random. seed! (_rng, seeds[chainidx])
473
502
474
503
# Sample a chain and save it to the vector.
504
+ child_progressname = " Chain $chainidx /$nchains "
475
505
child_progress = if progress == false
476
506
false
477
507
else
478
- nothing
508
+ ProgressLogging. Progress (
509
+ uuids[chainidx]; name= child_progressname
510
+ )
479
511
end
480
- @ifwithprogresslogger progress chains[chainidx] = StatsBase. sample (
512
+ chains[chainidx] = StatsBase. sample (
481
513
_rng,
482
514
_model,
483
515
_sampler,
484
516
N;
485
517
progress= child_progress,
518
+ progressname= child_progressname,
486
519
initial_params= if initial_params === nothing
487
520
nothing
488
521
else
@@ -496,6 +529,8 @@ function mcmcsample(
496
529
kwargs... ,
497
530
)
498
531
532
+ ProgressLogging. @logprogress name = child_progressname " done" _id = uuids[chainidx]
533
+
499
534
# Update the progress bar.
500
535
progress && put! (channel, true )
501
536
end
0 commit comments