@@ -102,10 +102,10 @@ may be provided as keyword argument `callback`. It is called after every samplin
102
102
function StatsBase. sample (
103
103
model:: AbstractModel ,
104
104
sampler:: AbstractSampler ,
105
- N :: Integer ;
105
+ arg ;
106
106
kwargs...
107
107
)
108
- return sample (GLOBAL_RNG, model, sampler, N ; kwargs... )
108
+ return sample (GLOBAL_RNG, model, sampler, arg ; kwargs... )
109
109
end
110
110
111
111
function StatsBase. sample (
247
247
248
248
"""
249
249
transitions_init(transition, model, sampler, N[; kwargs...])
250
+ transitions_init(transition, model, sampler[; kwargs...])
250
251
251
252
Generate a container for the `N` transitions of the MCMC `sampler` for the provided
252
- `model`, whose first transition is `transition`.
253
+ `model`, whose first transition is `transition`. Can be called with and without a predefined size `N`.
253
254
"""
254
255
function transitions_init (
255
256
transition,
@@ -261,11 +262,21 @@ function transitions_init(
261
262
return Vector {typeof(transition)} (undef, N)
262
263
end
263
264
265
+ function transitions_init (
266
+ transition,
267
+ :: AbstractModel ,
268
+ :: AbstractSampler ;
269
+ kwargs...
270
+ )
271
+ return [transition]
272
+ end
273
+
264
274
"""
265
275
transitions_save!(transitions, iteration, transition, model, sampler, N[; kwargs...])
276
+ transitions_save!(transitions, iteration, transition, model, sampler[; kwargs...])
266
277
267
278
Save the `transition` of the MCMC `sampler` at the current `iteration` in the container of
268
- `transitions`.
279
+ `transitions`. Can be called with and without a predefined size `N`.
269
280
"""
270
281
function transitions_save! (
271
282
transitions:: AbstractVector ,
@@ -280,6 +291,19 @@ function transitions_save!(
280
291
return
281
292
end
282
293
294
+
295
+ function transitions_save! (
296
+ transitions:: AbstractVector ,
297
+ iteration:: Integer ,
298
+ transition,
299
+ :: AbstractModel ,
300
+ :: AbstractSampler ;
301
+ kwargs...
302
+ )
303
+ push! (transitions, transition)
304
+ return
305
+ end
306
+
283
307
"""
284
308
psample([rng::AbstractRNG, ]model::AbstractModel, sampler::AbstractSampler, N::Integer,
285
309
nchains::Integer; kwargs...)
@@ -417,4 +441,71 @@ function steps!(
417
441
return Stepper (rng, model, s, kwargs)
418
442
end
419
443
444
+ # #################################
445
+ # Sample-until-convergence tools #
446
+ # #################################
447
+
448
+ """
449
+ sample([rng::AbstractRNG, ]model::AbstractModel, s::AbstractSampler, is_done::Function; kwargs...)
450
+
451
+ `sample` will continuously draw samples without defining a maximum number of samples until
452
+ a convergence criteria defined by a user-defined function `is_done` returns `true`.
453
+
454
+ `is_done` is a function `f` that returns a `Bool`, with the signature
455
+
456
+ ```julia
457
+ f(rng::AbstractRNG, model::AbstractModel, s::AbstractSampler, transitions::Vector, iteration::Int; kwargs...)
458
+ ```
459
+
460
+ `is_done` should return `true` when sampling should end, and `false` otherwise.
461
+ """
462
+ function StatsBase. sample (
463
+ rng:: AbstractRNG ,
464
+ model:: AbstractModel ,
465
+ sampler:: AbstractSampler ,
466
+ is_done;
467
+ chain_type:: Type = Any,
468
+ progress = true ,
469
+ progressname = " Convergence sampling" ,
470
+ callback = (args... ; kwargs... ) -> nothing ,
471
+ kwargs...
472
+ )
473
+ # Perform any necessary setup.
474
+ sample_init! (rng, model, sampler, 1 ; kwargs... )
475
+
476
+ @ifwithprogresslogger progress name= progressname begin
477
+ # Obtain the initial transition.
478
+ transition = step! (rng, model, sampler, 1 ; iteration= 1 , kwargs... )
479
+
480
+ # Run callback.
481
+ callback (rng, model, sampler, 1 , 1 , transition; kwargs... )
482
+
483
+ # Save the transition.
484
+ transitions = transitions_init (transition, model, sampler; kwargs... )
485
+
486
+ # Step through the sampler until stopping.
487
+ i = 2
488
+
489
+ while ! is_done (rng, model, sampler, transitions, i; progress= progress, kwargs... )
490
+ # Obtain the next transition.
491
+ transition = step! (rng, model, sampler, 1 , transition; iteration= i, kwargs... )
492
+
493
+ # Run callback.
494
+ callback (rng, model, sampler, 1 , i, transition; kwargs... )
495
+
496
+ # Save the transition.
497
+ transitions_save! (transitions, i, transition, model, sampler; kwargs... )
498
+
499
+ # Increment iteration counter.
500
+ i += 1
501
+ end
502
+ end
503
+
504
+ # Wrap up the sampler, if necessary.
505
+ sample_end! (rng, model, sampler, i, transitions; kwargs... )
506
+
507
+ # Wrap the samples up.
508
+ return bundle_samples (rng, model, sampler, i, transitions, chain_type; kwargs... )
509
+ end
510
+
420
511
end # module AbstractMCMC
0 commit comments