@@ -442,11 +442,13 @@ function StatsBase.sample(
442
442
sampler:: AbstractSampler ,
443
443
is_done;
444
444
chain_type:: Type = Any,
445
+ progress = true ,
446
+ progressname = " Convergence sampling" ,
445
447
callback = (args... ; kwargs... ) -> nothing ,
446
448
kwargs...
447
449
)
448
450
# Perform any necessary setup.
449
- sample_init! (rng, model, sampler, N ; kwargs... )
451
+ sample_init! (rng, model, sampler, 1 ; kwargs... )
450
452
451
453
# Obtain the initial transition.
452
454
transition = step! (rng, model, sampler, 1 ; iteration= 1 , kwargs... )
@@ -460,22 +462,24 @@ function StatsBase.sample(
460
462
# Step through the sampler until stopping.
461
463
i = 2
462
464
463
- while ! is_done (rng, model, sampler, transitions, i; kwargs... )
464
- # Obtain the next transition.
465
- transition = step! (rng, model, sampler, 1 , transition; iteration= i, kwargs... )
465
+ @ifwithprogresslogger progress name= progressname begin
466
+ while ! is_done (rng, model, sampler, transitions, i; progress= progress, kwargs... )
467
+ # Obtain the next transition.
468
+ transition = step! (rng, model, sampler, 1 , transition; iteration= i, kwargs... )
466
469
467
- # Run callback.
468
- callback (rng, model, sampler, 1 , i, transition; kwargs... )
470
+ # Run callback.
471
+ callback (rng, model, sampler, 1 , i, transition; kwargs... )
469
472
470
- # Save the transition.
471
- push! (transitions, transition)
473
+ # Save the transition.
474
+ push! (transitions, transition)
472
475
473
- # Increment iteration counter.
474
- i += 1
476
+ # Increment iteration counter.
477
+ i += 1
478
+ end
475
479
end
476
480
477
481
# Wrap up the sampler, if necessary.
478
- sample_end! (rng, model, sampler, N , transitions; kwargs... )
482
+ sample_end! (rng, model, sampler, i , transitions; kwargs... )
479
483
480
484
# Wrap the samples up.
481
485
return bundle_samples (rng, model, sampler, i, transitions, chain_type; kwargs... )
0 commit comments