Skip to content

Commit dc43291

Browse files
committed
report proportion of total samples instead
1 parent ab3cf26 commit dc43291

File tree

1 file changed

+34
-40
lines changed

1 file changed

+34
-40
lines changed

src/sample.jl

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ function mcmcsample(
121121
thinning=1,
122122
chain_type::Type=Any,
123123
initial_state=nothing,
124+
_progress_channel=nothing,
124125
kwargs...,
125126
)
126127
# Check the number of requested samples.
@@ -146,10 +147,8 @@ function mcmcsample(
146147
@ifwithprogresslogger progress name = progressname begin
147148
# Determine threshold values for progress logging
148149
# (one update per 0.5% of progress)
149-
if !(progress == false)
150-
threshold = Ntotal ÷ 200
151-
next_update = threshold
152-
end
150+
threshold = Ntotal ÷ 200
151+
next_update = threshold
153152

154153
# Ugly hacky code to reset the start timer if called from a multi-chain
155154
# sampling process
@@ -180,10 +179,10 @@ function mcmcsample(
180179

181180
# Update the progress bar.
182181
itotal = 1
183-
if !(progress == false) && itotal >= next_update
182+
if itotal >= next_update
184183
if progress == true
185184
ProgressLogging.@logprogress itotal / Ntotal
186-
else
185+
elseif progress isa ProgressLogging.Progress
187186
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
188187
progress.id
189188
end
@@ -200,10 +199,11 @@ function mcmcsample(
200199
end
201200

202201
# Update the progress bar.
203-
if !(progress == false) && (itotal += 1) >= next_update
202+
_progress_channel !== nothing && put!(_progress_channel, true)
203+
if (itotal += 1) >= next_update
204204
if progress == true
205205
ProgressLogging.@logprogress itotal / Ntotal
206-
else
206+
elseif progress isa ProgressLogging.Progress
207207
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
208208
progress.id
209209
end
@@ -230,10 +230,10 @@ function mcmcsample(
230230
end
231231

232232
# Update progress bar.
233-
if !(progress == false) && (itotal += 1) >= next_update
233+
if (itotal += 1) >= next_update
234234
if progress == true
235235
ProgressLogging.@logprogress itotal / Ntotal
236-
else
236+
elseif progress isa ProgressLogging.Progress
237237
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
238238
progress.id
239239
end
@@ -256,10 +256,11 @@ function mcmcsample(
256256
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
257257

258258
# Update the progress bar.
259-
if !(progress == false) && (itotal += 1) >= next_update
259+
_progress_channel !== nothing && put!(_progress_channel, true)
260+
if (itotal += 1) >= next_update
260261
if progress == true
261262
ProgressLogging.@logprogress itotal / Ntotal
262-
else
263+
elseif progress isa ProgressLogging.Progress
263264
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
264265
progress.id
265266
end
@@ -451,30 +452,31 @@ function mcmcsample(
451452
channel = Channel{Bool}(length(interval))
452453
end
453454
# Generate nchains independent UUIDs for each progress bar
454-
uuids = [uuid4() for _ in 1:nchains]
455+
# uuids = [uuid4() for _ in 1:nchains]
455456
# Start the progress bars (but in reverse order, because
456457
# ProgressLogging prints from the bottom up, and we want chain 1 to
457458
# show up at the top)
458-
for (i, uuid) in enumerate(reverse(uuids))
459-
ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id =
460-
uuid
461-
end
459+
# for (i, uuid) in enumerate(reverse(uuids))
460+
# ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id =
461+
# uuid
462+
# end
462463

463464
Distributed.@sync begin
464465
if progress
465466
# Update the progress bar.
466467
Distributed.@async begin
467468
# Determine threshold values for progress logging
468469
# (one update per 0.5% of progress)
469-
threshold = nchains ÷ 200
470-
nextprogresschains = threshold
470+
nprogupdates = nchains * N
471+
threshold = nprogupdates ÷ 200
472+
counter = 0
473+
next_update = threshold
471474

472-
progresschains = 0
473475
while take!(channel)
474-
progresschains += 1
475-
if progresschains >= nextprogresschains
476-
ProgressLogging.@logprogress progresschains / nchains
477-
nextprogresschains = progresschains + threshold
476+
counter += 1
477+
if counter >= next_update
478+
ProgressLogging.@logprogress counter / nprogupdates
479+
next_update = next_update + threshold
478480
end
479481
end
480482
end
@@ -499,21 +501,12 @@ function mcmcsample(
499501
Random.seed!(_rng, seeds[chainidx])
500502

501503
# Sample a chain and save it to the vector.
502-
child_progressname = "Chain $chainidx/$nchains"
503-
child_progress = if progress == false
504-
false
505-
else
506-
ProgressLogging.Progress(
507-
uuids[chainidx]; name=child_progressname
508-
)
509-
end
510504
chains[chainidx] = StatsBase.sample(
511505
_rng,
512506
_model,
513507
_sampler,
514508
N;
515-
progress=child_progress,
516-
progressname=child_progressname,
509+
progress=false,
517510
initial_params=if initial_params === nothing
518511
nothing
519512
else
@@ -524,11 +517,12 @@ function mcmcsample(
524517
else
525518
initial_state[chainidx]
526519
end,
520+
_progress_channel=channel,
527521
kwargs...,
528522
)
529523

530-
ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx]
531-
524+
# ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx]
525+
#
532526
# Update the progress bar.
533527
progress && put!(channel, true)
534528
end
@@ -537,10 +531,10 @@ function mcmcsample(
537531
# Stop updating the progress bars (either if sampling is done, or if
538532
# an error occurs).
539533
progress && put!(channel, false)
540-
for (i, uuid) in enumerate(reverse(uuids))
541-
ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id =
542-
uuid
543-
end
534+
# for (i, uuid) in enumerate(reverse(uuids))
535+
# ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id =
536+
# uuid
537+
# end
544538
end
545539
end
546540
end

0 commit comments

Comments
 (0)