Skip to content

Commit 7abd5fb

Browse files
authored
Remove resume_from and default_chain_type (#1061)
* Remove resume_from * Format * Fix test
1 parent 08212a2 commit 7abd5fb

File tree

4 files changed

+18
-56
lines changed

4 files changed

+18
-56
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ The separation of these functions was primarily implemented to avoid performing
5454

5555
Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed.
5656

57+
### Removal of `resume_from`
58+
59+
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
60+
`loadstate` is exported from DynamicPPL.
61+
5762
**Other changes**
5863

5964
### `setleafcontext(model, context)`

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ export AbstractVarInfo,
130130
prefix,
131131
returned,
132132
to_submodel,
133+
# Chain save/resume
134+
loadstate,
133135
# Convenience macros
134136
@addlogprob!,
135137
value_iterator_from_chain,

src/sampler.jl

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,15 @@ function AbstractMCMC.sample(
5858
model::Model,
5959
sampler::Sampler,
6060
N::Integer;
61-
chain_type=default_chain_type(sampler),
62-
resume_from=nothing,
6361
initial_params=init_strategy(sampler),
64-
initial_state=loadstate(resume_from),
62+
initial_state=nothing,
6563
kwargs...,
6664
)
6765
if hasproperty(kwargs, :initial_parameters)
6866
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
6967
end
7068
return AbstractMCMC.mcmcsample(
71-
rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs...
69+
rng, model, sampler, N; initial_params, initial_state, kwargs...
7270
)
7371
end
7472

@@ -79,26 +77,15 @@ function AbstractMCMC.sample(
7977
parallel::AbstractMCMC.AbstractMCMCEnsemble,
8078
N::Integer,
8179
nchains::Integer;
82-
chain_type=default_chain_type(sampler),
8380
initial_params=fill(init_strategy(sampler), nchains),
84-
resume_from=nothing,
85-
initial_state=loadstate(resume_from),
81+
initial_state=nothing,
8682
kwargs...,
8783
)
8884
if hasproperty(kwargs, :initial_parameters)
8985
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
9086
end
9187
return AbstractMCMC.mcmcsample(
92-
rng,
93-
model,
94-
sampler,
95-
parallel,
96-
N,
97-
nchains;
98-
chain_type,
99-
initial_params,
100-
initial_state,
101-
kwargs...,
88+
rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs...
10289
)
10390
end
10491

@@ -124,20 +111,12 @@ function AbstractMCMC.step(
124111
end
125112

126113
"""
127-
loadstate(data)
114+
loadstate(chain::AbstractChains)
128115
129-
Load sampler state from `data`.
130-
131-
By default, `data` is returned.
132-
"""
133-
loadstate(data) = data
134-
135-
"""
136-
default_chain_type(sampler)
137-
138-
Default type of the chain of posterior samples from `sampler`.
116+
Load sampler state from an `AbstractChains` object. This function should be overloaded by a
117+
concrete Chains implementation.
139118
"""
140-
default_chain_type(::Sampler) = Any
119+
function loadstate end
141120

142121
"""
143122
initialstep(rng, model, sampler, varinfo; kwargs...)

test/sampler.jl

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
@test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any
1313
end
1414

15-
@testset "initial_state and resume_from kwargs" begin
15+
@testset "initial_state" begin
1616
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
1717
# overloaded method.
1818
@model f() = x ~ Normal()
@@ -52,26 +52,15 @@
5252
chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains)
5353
initial_value = chn[:x][1]
5454
@test all(chn[:x] .== initial_value) # sanity check
55-
# using `initial_state`
5655
chn2 = sample(
5756
model,
5857
spl,
5958
N_iters;
6059
progress=false,
61-
initial_state=chn.info.samplerstate,
60+
initial_state=DynamicPPL.loadstate(chn),
6261
chain_type=MCMCChains.Chains,
6362
)
6463
@test all(chn2[:x] .== initial_value)
65-
# using `resume_from`
66-
chn3 = sample(
67-
model,
68-
spl,
69-
N_iters;
70-
progress=false,
71-
resume_from=chn,
72-
chain_type=MCMCChains.Chains,
73-
)
74-
@test all(chn3[:x] .== initial_value)
7564
end
7665

7766
@testset "multiple-chain sampling" begin
@@ -86,30 +75,17 @@
8675
)
8776
initial_value = chn[:x][1, :]
8877
@test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check
89-
# using `initial_state`
9078
chn2 = sample(
9179
model,
9280
spl,
9381
MCMCThreads(),
9482
N_iters,
9583
N_chains;
9684
progress=false,
97-
initial_state=chn.info.samplerstate,
85+
initial_state=DynamicPPL.loadstate(chn),
9886
chain_type=MCMCChains.Chains,
9987
)
10088
@test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters)
101-
# using `resume_from`
102-
chn3 = sample(
103-
model,
104-
spl,
105-
MCMCThreads(),
106-
N_iters,
107-
N_chains;
108-
progress=false,
109-
resume_from=chn,
110-
chain_type=MCMCChains.Chains,
111-
)
112-
@test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters)
11389
end
11490
end
11591

0 commit comments

Comments
 (0)