Skip to content

Commit cc3352a

Browse files
committed
Add new signatures for transition init/save
1 parent 651fe75 commit cc3352a

File tree

1 file changed

+36
-13
lines changed

1 file changed

+36
-13
lines changed

src/AbstractMCMC.jl

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,10 @@ end
247247

248248
"""
249249
transitions_init(transition, model, sampler, N[; kwargs...])
250+
transitions_init(transition, model, sampler[; kwargs...])
250251
251252
Generate a container for the `N` transitions of the MCMC `sampler` for the provided
252-
`model`, whose first transition is `transition`.
253+
`model`, whose first transition is `transition`. Can be called with an without a predefined size `N`.
253254
"""
254255
function transitions_init(
255256
transition,
@@ -261,11 +262,21 @@ function transitions_init(
261262
return Vector{typeof(transition)}(undef, N)
262263
end
263264

265+
function transitions_init(
266+
transition,
267+
::AbstractModel,
268+
::AbstractSampler;
269+
kwargs...
270+
)
271+
return [transition]
272+
end
273+
264274
"""
265275
transitions_save!(transitions, iteration, transition, model, sampler, N[; kwargs...])
276+
transitions_save!(transitions, iteration, transition, model, sampler[; kwargs...])
266277
267278
Save the `transition` of the MCMC `sampler` at the current `iteration` in the container of
268-
`transitions`.
279+
`transitions`. Can be called with an without a predefined size `N`.
269280
"""
270281
function transitions_save!(
271282
transitions::AbstractVector,
@@ -280,6 +291,19 @@ function transitions_save!(
280291
return
281292
end
282293

294+
295+
function transitions_save!(
296+
transitions::AbstractVector,
297+
iteration::Integer,
298+
transition,
299+
::AbstractModel,
300+
::AbstractSampler;
301+
kwargs...
302+
)
303+
push!(transitions, transition)
304+
return
305+
end
306+
283307
"""
284308
psample([rng::AbstractRNG, ]model::AbstractModel, sampler::AbstractSampler, N::Integer,
285309
nchains::Integer; kwargs...)
@@ -421,7 +445,6 @@ end
421445
# Sample-until-convergence tools #
422446
##################################
423447

424-
425448
"""
426449
sample([rng::AbstractRNG, ]model::AbstractModel, s::AbstractSampler, is_done::Function; kwargs...)
427450
@@ -450,19 +473,19 @@ function StatsBase.sample(
450473
# Perform any necessary setup.
451474
sample_init!(rng, model, sampler, 1; kwargs...)
452475

453-
# Obtain the initial transition.
454-
transition = step!(rng, model, sampler, 1; iteration=1, kwargs...)
476+
@ifwithprogresslogger progress name=progressname begin
477+
# Obtain the initial transition.
478+
transition = step!(rng, model, sampler, 1; iteration=1, kwargs...)
455479

456-
# Run callback.
457-
callback(rng, model, sampler, 1, 1, transition; kwargs...)
480+
# Run callback.
481+
callback(rng, model, sampler, 1, 1, transition; kwargs...)
458482

459-
# Save the transition.
460-
transitions = [transition]
483+
# Save the transition.
484+
transitions = transitions_init(transition, model, sampler; kwargs...)
461485

462-
# Step through the sampler until stopping.
463-
i = 2
486+
# Step through the sampler until stopping.
487+
i = 2
464488

465-
@ifwithprogresslogger progress name=progressname begin
466489
while !is_done(rng, model, sampler, transitions, i; progress=progress, kwargs...)
467490
# Obtain the next transition.
468491
transition = step!(rng, model, sampler, 1, transition; iteration=i, kwargs...)
@@ -471,7 +494,7 @@ function StatsBase.sample(
471494
callback(rng, model, sampler, 1, i, transition; kwargs...)
472495

473496
# Save the transition.
474-
push!(transitions, transition)
497+
transitions_save!(transitions, i, transition, model, sampler; kwargs...)
475498

476499
# Increment iteration counter.
477500
i += 1

0 commit comments

Comments
 (0)