Skip to content

Commit 1742b9b

Browse files
committed
Remove resume_from
1 parent 08212a2 commit 1742b9b

File tree

4 files changed

+17
-43
lines changed

4 files changed

+17
-43
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ The separation of these functions was primarily implemented to avoid performing
5454

5555
Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed.
5656

57+
### Removal of `resume_from`
58+
59+
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
60+
`loadstate` is exported from DynamicPPL.
61+
5762
**Other changes**
5863

5964
### `setleafcontext(model, context)`

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ export AbstractVarInfo,
130130
prefix,
131131
returned,
132132
to_submodel,
133+
# Chain save/resume
134+
loadstate,
133135
# Convenience macros
134136
@addlogprob!,
135137
value_iterator_from_chain,

src/sampler.jl

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,15 @@ function AbstractMCMC.sample(
5858
model::Model,
5959
sampler::Sampler,
6060
N::Integer;
61-
chain_type=default_chain_type(sampler),
62-
resume_from=nothing,
6361
initial_params=init_strategy(sampler),
64-
initial_state=loadstate(resume_from),
62+
initial_state=nothing,
6563
kwargs...,
6664
)
6765
if hasproperty(kwargs, :initial_parameters)
6866
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
6967
end
7068
return AbstractMCMC.mcmcsample(
71-
rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs...
69+
rng, model, sampler, N; initial_params, initial_state, kwargs...
7270
)
7371
end
7472

@@ -79,10 +77,8 @@ function AbstractMCMC.sample(
7977
parallel::AbstractMCMC.AbstractMCMCEnsemble,
8078
N::Integer,
8179
nchains::Integer;
82-
chain_type=default_chain_type(sampler),
8380
initial_params=fill(init_strategy(sampler), nchains),
84-
resume_from=nothing,
85-
initial_state=loadstate(resume_from),
81+
initial_state=nothing,
8682
kwargs...,
8783
)
8884
if hasproperty(kwargs, :initial_parameters)
@@ -95,7 +91,6 @@ function AbstractMCMC.sample(
9591
parallel,
9692
N,
9793
nchains;
98-
chain_type,
9994
initial_params,
10095
initial_state,
10196
kwargs...,
@@ -124,20 +119,12 @@ function AbstractMCMC.step(
124119
end
125120

126121
"""
127-
loadstate(data)
122+
loadstate(chain::AbstractChains)
128123
129-
Load sampler state from `data`.
130-
131-
By default, `data` is returned.
132-
"""
133-
loadstate(data) = data
134-
135-
"""
136-
default_chain_type(sampler)
137-
138-
Default type of the chain of posterior samples from `sampler`.
124+
Load sampler state from an `AbstractChains` object. This function should be overloaded by a
125+
concrete Chains implementation.
139126
"""
140-
default_chain_type(::Sampler) = Any
127+
function loadstate end
141128

142129
"""
143130
initialstep(rng, model, sampler, varinfo; kwargs...)

test/sampler.jl

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
@test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any
1313
end
1414

15-
@testset "initial_state and resume_from kwargs" begin
15+
@testset "initial_state" begin
1616
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
1717
# overloaded method.
1818
@model f() = x ~ Normal()
@@ -58,19 +58,10 @@
5858
spl,
5959
N_iters;
6060
progress=false,
61-
initial_state=chn.info.samplerstate,
61+
initial_state=DynamicPPL.loadstate(chn),
6262
chain_type=MCMCChains.Chains,
6363
)
6464
@test all(chn2[:x] .== initial_value)
65-
# using `resume_from`
66-
chn3 = sample(
67-
model,
68-
spl,
69-
N_iters;
70-
progress=false,
71-
resume_from=chn,
72-
chain_type=MCMCChains.Chains,
73-
)
7465
@test all(chn3[:x] .== initial_value)
7566
end
7667

@@ -94,21 +85,10 @@
9485
N_iters,
9586
N_chains;
9687
progress=false,
97-
initial_state=chn.info.samplerstate,
88+
initial_state=DynamicPPL.loadstate(chn),
9889
chain_type=MCMCChains.Chains,
9990
)
10091
@test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters)
101-
# using `resume_from`
102-
chn3 = sample(
103-
model,
104-
spl,
105-
MCMCThreads(),
106-
N_iters,
107-
N_chains;
108-
progress=false,
109-
resume_from=chn,
110-
chain_type=MCMCChains.Chains,
111-
)
11292
@test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters)
11393
end
11494
end

0 commit comments

Comments
 (0)