Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ The separation of these functions was primarily implemented to avoid performing

Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed.

### Removal of `resume_from`

The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
`loadstate` is exported from DynamicPPL.

**Other changes**

### `setleafcontext(model, context)`
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ export AbstractVarInfo,
prefix,
returned,
to_submodel,
# Chain save/resume
loadstate,
# Convenience macros
@addlogprob!,
value_iterator_from_chain,
Expand Down
37 changes: 8 additions & 29 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,15 @@ function AbstractMCMC.sample(
model::Model,
sampler::Sampler,
N::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
initial_params=init_strategy(sampler),
initial_state=loadstate(resume_from),
initial_state=nothing,
kwargs...,
)
if hasproperty(kwargs, :initial_parameters)
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
end
return AbstractMCMC.mcmcsample(
rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs...
rng, model, sampler, N; initial_params, initial_state, kwargs...
)
end

Expand All @@ -79,26 +77,15 @@ function AbstractMCMC.sample(
parallel::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
chain_type=default_chain_type(sampler),
initial_params=fill(init_strategy(sampler), nchains),
resume_from=nothing,
initial_state=loadstate(resume_from),
initial_state=nothing,
kwargs...,
)
if hasproperty(kwargs, :initial_parameters)
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
end
return AbstractMCMC.mcmcsample(
rng,
model,
sampler,
parallel,
N,
nchains;
chain_type,
initial_params,
initial_state,
kwargs...,
rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs...
)
end

Expand All @@ -124,20 +111,12 @@ function AbstractMCMC.step(
end

"""
loadstate(data)
loadstate(chain::AbstractChains)
Load sampler state from `data`.
By default, `data` is returned.
"""
loadstate(data) = data

"""
default_chain_type(sampler)
Default type of the chain of posterior samples from `sampler`.
Load sampler state from an `AbstractChains` object. This function should be overloaded by a
concrete Chains implementation.
"""
default_chain_type(::Sampler) = Any
function loadstate end

"""
initialstep(rng, model, sampler, varinfo; kwargs...)
Expand Down
30 changes: 3 additions & 27 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any
end

@testset "initial_state and resume_from kwargs" begin
@testset "initial_state" begin
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
# overloaded method.
@model f() = x ~ Normal()
Expand Down Expand Up @@ -52,26 +52,15 @@
chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains)
initial_value = chn[:x][1]
@test all(chn[:x] .== initial_value) # sanity check
# using `initial_state`
chn2 = sample(
model,
spl,
N_iters;
progress=false,
initial_state=chn.info.samplerstate,
initial_state=DynamicPPL.loadstate(chn),
chain_type=MCMCChains.Chains,
)
@test all(chn2[:x] .== initial_value)
# using `resume_from`
chn3 = sample(
model,
spl,
N_iters;
progress=false,
resume_from=chn,
chain_type=MCMCChains.Chains,
)
@test all(chn3[:x] .== initial_value)
end

@testset "multiple-chain sampling" begin
Expand All @@ -86,30 +75,17 @@
)
initial_value = chn[:x][1, :]
@test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check
# using `initial_state`
chn2 = sample(
model,
spl,
MCMCThreads(),
N_iters,
N_chains;
progress=false,
initial_state=chn.info.samplerstate,
initial_state=DynamicPPL.loadstate(chn),
chain_type=MCMCChains.Chains,
)
@test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters)
# using `resume_from`
chn3 = sample(
model,
spl,
MCMCThreads(),
N_iters,
N_chains;
progress=false,
resume_from=chn,
chain_type=MCMCChains.Chains,
)
@test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters)
end
end

Expand Down
Loading