Skip to content

Commit 651fe75

Browse files
committed
Wrap progress in a logger
1 parent 687d2eb commit 651fe75

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

src/AbstractMCMC.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -442,11 +442,13 @@ function StatsBase.sample(
442442
sampler::AbstractSampler,
443443
is_done;
444444
chain_type::Type=Any,
445+
progress = true,
446+
progressname = "Convergence sampling",
445447
callback = (args...; kwargs...) -> nothing,
446448
kwargs...
447449
)
448450
# Perform any necessary setup.
449-
sample_init!(rng, model, sampler, N; kwargs...)
451+
sample_init!(rng, model, sampler, 1; kwargs...)
450452

451453
# Obtain the initial transition.
452454
transition = step!(rng, model, sampler, 1; iteration=1, kwargs...)
@@ -460,22 +462,24 @@ function StatsBase.sample(
460462
# Step through the sampler until stopping.
461463
i = 2
462464

463-
while !is_done(rng, model, sampler, transitions, i; kwargs...)
464-
# Obtain the next transition.
465-
transition = step!(rng, model, sampler, 1, transition; iteration=i, kwargs...)
465+
@ifwithprogresslogger progress name=progressname begin
466+
while !is_done(rng, model, sampler, transitions, i; progress=progress, kwargs...)
467+
# Obtain the next transition.
468+
transition = step!(rng, model, sampler, 1, transition; iteration=i, kwargs...)
466469

467-
# Run callback.
468-
callback(rng, model, sampler, 1, i, transition; kwargs...)
470+
# Run callback.
471+
callback(rng, model, sampler, 1, i, transition; kwargs...)
469472

470-
# Save the transition.
471-
push!(transitions, transition)
473+
# Save the transition.
474+
push!(transitions, transition)
472475

473-
# Increment iteration counter.
474-
i += 1
476+
# Increment iteration counter.
477+
i += 1
478+
end
475479
end
476480

477481
# Wrap up the sampler, if necessary.
478-
sample_end!(rng, model, sampler, N, transitions; kwargs...)
482+
sample_end!(rng, model, sampler, i, transitions; kwargs...)
479483

480484
# Wrap the samples up.
481485
return bundle_samples(rng, model, sampler, i, transitions, chain_type; kwargs...)

0 commit comments

Comments
 (0)