diff --git a/HISTORY.md b/HISTORY.md index f69c4a6fd..29bc56493 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -54,6 +54,11 @@ The separation of these functions was primarily implemented to avoid performing Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed. +### Removal of `resume_from` + +The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead. +`loadstate` is exported from DynamicPPL. + **Other changes** ### `setleafcontext(model, context)` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 31adadb55..43180b091 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -130,6 +130,8 @@ export AbstractVarInfo, prefix, returned, to_submodel, + # Chain save/resume + loadstate, # Convenience macros @addlogprob!, value_iterator_from_chain, diff --git a/src/sampler.jl b/src/sampler.jl index c598e13f5..01f056053 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,17 +58,15 @@ function AbstractMCMC.sample( model::Model, sampler::Sampler, N::Integer; - chain_type=default_chain_type(sampler), - resume_from=nothing, initial_params=init_strategy(sampler), - initial_state=loadstate(resume_from), + initial_state=nothing, kwargs..., ) if hasproperty(kwargs, :initial_parameters) @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." end return AbstractMCMC.mcmcsample( - rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs... + rng, model, sampler, N; initial_params, initial_state, kwargs... ) end @@ -79,26 +77,15 @@ function AbstractMCMC.sample( parallel::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, nchains::Integer; - chain_type=default_chain_type(sampler), initial_params=fill(init_strategy(sampler), nchains), - resume_from=nothing, - initial_state=loadstate(resume_from), + initial_state=nothing, kwargs..., ) if hasproperty(kwargs, :initial_parameters) @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." end return AbstractMCMC.mcmcsample( - rng, - model, - sampler, - parallel, - N, - nchains; - chain_type, - initial_params, - initial_state, - kwargs..., + rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs... ) end @@ -124,20 +111,12 @@ function AbstractMCMC.step( end """ - loadstate(data) + loadstate(chain::AbstractChains) -Load sampler state from `data`. - -By default, `data` is returned. -""" -loadstate(data) = data - -""" - default_chain_type(sampler) - -Default type of the chain of posterior samples from `sampler`. +Load sampler state from an `AbstractChains` object. This function should be overloaded by a +concrete Chains implementation. """ -default_chain_type(::Sampler) = Any +function loadstate end """ initialstep(rng, model, sampler, varinfo; kwargs...) diff --git a/test/sampler.jl b/test/sampler.jl index 5380ad17e..8be54901d 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -12,7 +12,7 @@ @test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any end - @testset "initial_state and resume_from kwargs" begin + @testset "initial_state" begin # Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our # overloaded method. @model f() = x ~ Normal() @@ -52,26 +52,15 @@ chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains) initial_value = chn[:x][1] @test all(chn[:x] .== initial_value) # sanity check - # using `initial_state` chn2 = sample( model, spl, N_iters; progress=false, - initial_state=chn.info.samplerstate, + initial_state=DynamicPPL.loadstate(chn), chain_type=MCMCChains.Chains, ) @test all(chn2[:x] .== initial_value) - # using `resume_from` - chn3 = sample( - model, - spl, - N_iters; - progress=false, - resume_from=chn, - chain_type=MCMCChains.Chains, - ) - @test all(chn3[:x] .== initial_value) end @testset "multiple-chain sampling" begin @@ -86,7 +75,6 @@ ) initial_value = chn[:x][1, :] @test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check - # using `initial_state` chn2 = sample( model, spl, @@ -94,22 +82,10 @@ N_iters, N_chains; progress=false, - initial_state=chn.info.samplerstate, + initial_state=DynamicPPL.loadstate(chn), chain_type=MCMCChains.Chains, ) @test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters) - # using `resume_from` - chn3 = sample( - model, - spl, - MCMCThreads(), - N_iters, - N_chains; - progress=false, - resume_from=chn, - chain_type=MCMCChains.Chains, - ) - @test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters) end end