Skip to content

Commit e15d26c

Browse files
authored
Pass num_warmup as kwarg to step_warmup (#178)
* Pass `num_warmup` kwarg to `step_warmup()` * Mention the num_warmup kwarg in step_warmup's docstring
1 parent ae9760d commit e15d26c

File tree

4 files changed

+18
-11
lines changed

4 files changed

+18
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.8.1"
6+
version = "5.8.2"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/interface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ When sampling using [`sample`](@ref), this takes the place of [`AbstractMCMC.ste
8383
This is useful if the sampler has an initial "warmup"-stage that is different from the
8484
standard iteration.
8585
86+
The total number of warmup steps requested in sampling will be passed to the `step_warmup`
87+
function as the `num_warmup` keyword argument. This allows implementations of `step_warmup`
88+
to customise their behavior based on this information.
89+
8690
By default, this simply calls [`AbstractMCMC.step`](@ref).
8791
"""
8892
step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...)

src/sample.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,9 @@ function mcmcsample(
179179
# Obtain the initial sample and state.
180180
sample, state = if num_warmup > 0
181181
if initial_state === nothing
182-
step_warmup(rng, model, sampler; kwargs...)
182+
step_warmup(rng, model, sampler; num_warmup, kwargs...)
183183
else
184-
step_warmup(rng, model, sampler, initial_state; kwargs...)
184+
step_warmup(rng, model, sampler, initial_state; num_warmup, kwargs...)
185185
end
186186
else
187187
if initial_state === nothing
@@ -202,7 +202,7 @@ function mcmcsample(
202202
for j in 1:discard_initial
203203
# Obtain the next sample and state.
204204
sample, state = if j num_warmup
205-
step_warmup(rng, model, sampler, state; kwargs...)
205+
step_warmup(rng, model, sampler, state; num_warmup, kwargs...)
206206
else
207207
step(rng, model, sampler, state; kwargs...)
208208
end
@@ -229,7 +229,7 @@ function mcmcsample(
229229
for _ in 1:(thinning - 1)
230230
# Obtain the next sample and state.
231231
sample, state = if i keep_from_warmup
232-
step_warmup(rng, model, sampler, state; kwargs...)
232+
step_warmup(rng, model, sampler, state; num_warmup, kwargs...)
233233
else
234234
step(rng, model, sampler, state; kwargs...)
235235
end
@@ -244,7 +244,7 @@ function mcmcsample(
244244

245245
# Obtain the next sample and state.
246246
sample, state = if i keep_from_warmup
247-
step_warmup(rng, model, sampler, state; kwargs...)
247+
step_warmup(rng, model, sampler, state; num_warmup, kwargs...)
248248
else
249249
step(rng, model, sampler, state; kwargs...)
250250
end
@@ -328,9 +328,9 @@ function mcmcsample(
328328
# Obtain the initial sample and state.
329329
sample, state = if num_warmup > 0
330330
if initial_state === nothing
331-
step_warmup(rng, model, sampler; kwargs...)
331+
step_warmup(rng, model, sampler; num_warmup, kwargs...)
332332
else
333-
step_warmup(rng, model, sampler, initial_state; kwargs...)
333+
step_warmup(rng, model, sampler, initial_state; num_warmup, kwargs...)
334334
end
335335
else
336336
if initial_state === nothing
@@ -344,7 +344,7 @@ function mcmcsample(
344344
for j in 1:discard_initial
345345
# Obtain the next sample and state.
346346
sample, state = if j num_warmup
347-
step_warmup(rng, model, sampler, state; kwargs...)
347+
step_warmup(rng, model, sampler, state; num_warmup, kwargs...)
348348
else
349349
step(rng, model, sampler, state; kwargs...)
350350
end
@@ -364,15 +364,15 @@ function mcmcsample(
364364
for _ in 1:(thinning - 1)
365365
# Obtain the next sample and state.
366366
sample, state = if i keep_from_warmup
367-
step_warmup(rng, model, sampler, state; kwargs...)
367+
step_warmup(rng, model, sampler, state; num_warmup, kwargs...)
368368
else
369369
step(rng, model, sampler, state; kwargs...)
370370
end
371371
end
372372

373373
# Obtain the next sample and state.
374374
sample, state = if i keep_from_warmup
375-
step_warmup(rng, model, sampler, state; kwargs...)
375+
step_warmup(rng, model, sampler, state; num_warmup, kwargs...)
376376
else
377377
step(rng, model, sampler, state; kwargs...)
378378
end

test/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ function AbstractMCMC.step_warmup(
2626
state::Union{Nothing,Integer}=nothing;
2727
loggers=false,
2828
initial_params=nothing,
29+
num_warmup,
2930
kwargs...,
3031
)
32+
num_warmup isa Integer ||
33+
error("num_warmup should have been passed as a keyword argument to step_warmup")
3134
transition, state = AbstractMCMC.step(
3235
rng, model, sampler, state; loggers, initial_params, kwargs...
3336
)

0 commit comments

Comments
 (0)