Skip to content

Commit c0201f5

Browse files
committed
Remove resume_from
1 parent 7311465 commit c0201f5

File tree

4 files changed

+18
-45
lines changed

4 files changed

+18
-45
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ Other functions such as `tilde_assume` and `assume` (and their `observe` counter
5050
Note that this was effectively already the case in DynamicPPL 0.37 (where they were just wrappers around each other).
5151
The separation of these functions was primarily implemented to avoid performing extra work where unneeded (e.g. to not calculate the log-likelihood when `PriorContext` was being used). This functionality has since been replaced with accumulators (see the 0.37 changelog for more details).
5252

53+
### Removal of `resume_from`
54+
55+
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
56+
`loadstate` is exported from DynamicPPL.
57+
5358
**Other changes**
5459

5560
### Reimplementation of functions using `InitContext`

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ export AbstractVarInfo,
129129
prefix,
130130
returned,
131131
to_submodel,
132+
# Chain save/resume
133+
loadstate,
132134
# Convenience macros
133135
@addlogprob!,
134136
value_iterator_from_chain,

src/sampler.jl

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,10 @@ function AbstractMCMC.sample(
5858
model::Model,
5959
sampler::Sampler,
6060
N::Integer;
61-
chain_type=default_chain_type(sampler),
62-
resume_from=nothing,
63-
initial_state=loadstate(resume_from),
61+
initial_state=nothing,
6462
kwargs...,
6563
)
66-
return AbstractMCMC.mcmcsample(
67-
rng, model, sampler, N; chain_type, initial_state, kwargs...
68-
)
64+
return AbstractMCMC.mcmcsample(rng, model, sampler, N; initial_state, kwargs...)
6965
end
7066

7167
function AbstractMCMC.sample(
@@ -75,13 +71,11 @@ function AbstractMCMC.sample(
7571
parallel::AbstractMCMC.AbstractMCMCEnsemble,
7672
N::Integer,
7773
nchains::Integer;
78-
chain_type=default_chain_type(sampler),
79-
resume_from=nothing,
80-
initial_state=loadstate(resume_from),
74+
initial_state=nothing,
8175
kwargs...,
8276
)
8377
return AbstractMCMC.mcmcsample(
84-
rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs...
78+
rng, model, sampler, parallel, N, nchains; initial_state, kwargs...
8579
)
8680
end
8781

@@ -107,20 +101,12 @@ function AbstractMCMC.step(
107101
end
108102

109103
"""
110-
loadstate(data)
111-
112-
Load sampler state from `data`.
113-
114-
By default, `data` is returned.
115-
"""
116-
loadstate(data) = data
117-
118-
"""
119-
default_chain_type(sampler)
104+
loadstate(chain::AbstractChains)
120105
121-
Default type of the chain of posterior samples from `sampler`.
106+
Load sampler state from an `AbstractChains` object. This function should be overloaded by a
107+
concrete Chains implementation.
122108
"""
123-
default_chain_type(::Sampler) = Any
109+
function loadstate end
124110

125111
"""
126112
initialstep(rng, model, sampler, varinfo; kwargs...)

test/sampler.jl

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
@test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any
1313
end
1414

15-
@testset "initial_state and resume_from kwargs" begin
15+
@testset "initial_state" begin
1616
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
1717
# overloaded method.
1818
@model f() = x ~ Normal()
@@ -58,19 +58,10 @@
5858
spl,
5959
N_iters;
6060
progress=false,
61-
initial_state=chn.info.samplerstate,
61+
initial_state=DynamicPPL.loadstate(chn),
6262
chain_type=MCMCChains.Chains,
6363
)
6464
@test all(chn2[:x] .== initial_value)
65-
# using `resume_from`
66-
chn3 = sample(
67-
model,
68-
spl,
69-
N_iters;
70-
progress=false,
71-
resume_from=chn,
72-
chain_type=MCMCChains.Chains,
73-
)
7465
@test all(chn3[:x] .== initial_value)
7566
end
7667

@@ -94,21 +85,10 @@
9485
N_iters,
9586
N_chains;
9687
progress=false,
97-
initial_state=chn.info.samplerstate,
88+
initial_state=DynamicPPL.loadstate(chn),
9889
chain_type=MCMCChains.Chains,
9990
)
10091
@test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters)
101-
# using `resume_from`
102-
chn3 = sample(
103-
model,
104-
spl,
105-
MCMCThreads(),
106-
N_iters,
107-
N_chains;
108-
progress=false,
109-
resume_from=chn,
110-
chain_type=MCMCChains.Chains,
111-
)
11292
@test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters)
11393
end
11494
end

0 commit comments

Comments
 (0)