@@ -121,6 +121,7 @@ function mcmcsample(
121
121
thinning= 1 ,
122
122
chain_type:: Type = Any,
123
123
initial_state= nothing ,
124
+ _progress_channel= nothing ,
124
125
kwargs... ,
125
126
)
126
127
# Check the number of requested samples.
@@ -146,10 +147,8 @@ function mcmcsample(
146
147
@ifwithprogresslogger progress name = progressname begin
147
148
# Determine threshold values for progress logging
148
149
# (one update per 0.5% of progress)
149
- if ! (progress == false )
150
- threshold = Ntotal ÷ 200
151
- next_update = threshold
152
- end
150
+ threshold = Ntotal ÷ 200
151
+ next_update = threshold
153
152
154
153
# Ugly hacky code to reset the start timer if called from a multi-chain
155
154
# sampling process
@@ -180,10 +179,10 @@ function mcmcsample(
180
179
181
180
# Update the progress bar.
182
181
itotal = 1
183
- if ! (progress == false ) && itotal >= next_update
182
+ if itotal >= next_update
184
183
if progress == true
185
184
ProgressLogging. @logprogress itotal / Ntotal
186
- else
185
+ elseif progress isa ProgressLogging . Progress
187
186
ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
188
187
progress. id
189
188
end
@@ -200,10 +199,11 @@ function mcmcsample(
200
199
end
201
200
202
201
# Update the progress bar.
203
- if ! (progress == false ) && (itotal += 1 ) >= next_update
202
+ _progress_channel != = nothing && put! (_progress_channel, true )
203
+ if (itotal += 1 ) >= next_update
204
204
if progress == true
205
205
ProgressLogging. @logprogress itotal / Ntotal
206
- else
206
+ elseif progress isa ProgressLogging . Progress
207
207
ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
208
208
progress. id
209
209
end
@@ -230,10 +230,10 @@ function mcmcsample(
230
230
end
231
231
232
232
# Update progress bar.
233
- if ! (progress == false ) && (itotal += 1 ) >= next_update
233
+ if (itotal += 1 ) >= next_update
234
234
if progress == true
235
235
ProgressLogging. @logprogress itotal / Ntotal
236
- else
236
+ elseif progress isa ProgressLogging . Progress
237
237
ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
238
238
progress. id
239
239
end
@@ -256,10 +256,11 @@ function mcmcsample(
256
256
samples = save!! (samples, sample, i, model, sampler, N; kwargs... )
257
257
258
258
# Update the progress bar.
259
- if ! (progress == false ) && (itotal += 1 ) >= next_update
259
+ _progress_channel != = nothing && put! (_progress_channel, true )
260
+ if (itotal += 1 ) >= next_update
260
261
if progress == true
261
262
ProgressLogging. @logprogress itotal / Ntotal
262
- else
263
+ elseif progress isa ProgressLogging . Progress
263
264
ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
264
265
progress. id
265
266
end
@@ -451,30 +452,31 @@ function mcmcsample(
451
452
channel = Channel {Bool} (length (interval))
452
453
end
453
454
# Generate nchains independent UUIDs for each progress bar
454
- uuids = [uuid4 () for _ in 1 : nchains]
455
+ # uuids = [uuid4() for _ in 1:nchains]
455
456
# Start the progress bars (but in reverse order, because
456
457
# ProgressLogging prints from the bottom up, and we want chain 1 to
457
458
# show up at the top)
458
- for (i, uuid) in enumerate (reverse (uuids))
459
- ProgressLogging. @logprogress name = " Chain $(nchains- i+ 1 ) /$nchains " nothing _id =
460
- uuid
461
- end
459
+ # for (i, uuid) in enumerate(reverse(uuids))
460
+ # ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id =
461
+ # uuid
462
+ # end
462
463
463
464
Distributed. @sync begin
464
465
if progress
465
466
# Update the progress bar.
466
467
Distributed. @async begin
467
468
# Determine threshold values for progress logging
468
469
# (one update per 0.5% of progress)
469
- threshold = nchains ÷ 200
470
- nextprogresschains = threshold
470
+ nprogupdates = nchains * N
471
+ threshold = nprogupdates ÷ 200
472
+ counter = 0
473
+ next_update = threshold
471
474
472
- progresschains = 0
473
475
while take! (channel)
474
- progresschains += 1
475
- if progresschains >= nextprogresschains
476
- ProgressLogging. @logprogress progresschains / nchains
477
- nextprogresschains = progresschains + threshold
476
+ counter += 1
477
+ if counter >= next_update
478
+ ProgressLogging. @logprogress counter / nprogupdates
479
+ next_update = next_update + threshold
478
480
end
479
481
end
480
482
end
@@ -499,21 +501,12 @@ function mcmcsample(
499
501
Random. seed! (_rng, seeds[chainidx])
500
502
501
503
# Sample a chain and save it to the vector.
502
- child_progressname = " Chain $chainidx /$nchains "
503
- child_progress = if progress == false
504
- false
505
- else
506
- ProgressLogging. Progress (
507
- uuids[chainidx]; name= child_progressname
508
- )
509
- end
510
504
chains[chainidx] = StatsBase. sample (
511
505
_rng,
512
506
_model,
513
507
_sampler,
514
508
N;
515
- progress= child_progress,
516
- progressname= child_progressname,
509
+ progress= false ,
517
510
initial_params= if initial_params === nothing
518
511
nothing
519
512
else
@@ -524,11 +517,12 @@ function mcmcsample(
524
517
else
525
518
initial_state[chainidx]
526
519
end ,
520
+ _progress_channel= channel,
527
521
kwargs... ,
528
522
)
529
523
530
- ProgressLogging. @logprogress name = child_progressname " done" _id = uuids[chainidx]
531
-
524
+ # ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx]
525
+ #
532
526
# Update the progress bar.
533
527
progress && put! (channel, true )
534
528
end
@@ -537,10 +531,10 @@ function mcmcsample(
537
531
# Stop updating the progress bars (either if sampling is done, or if
538
532
# an error occurs).
539
533
progress && put! (channel, false )
540
- for (i, uuid) in enumerate (reverse (uuids))
541
- ProgressLogging. @logprogress name = " Chain $(nchains- i+ 1 ) /$nchains " " done" _id =
542
- uuid
543
- end
534
+ # for (i, uuid) in enumerate(reverse(uuids))
535
+ # ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id =
536
+ # uuid
537
+ # end
544
538
end
545
539
end
546
540
end
0 commit comments