Skip to content

Commit 7249158

Browse files
authored
Fix resume_from for parallel sampling (#1035)
* Fix `resume_from` for parallel sampling * Changelog * Fix MCMCChains >= 7.2.1 in CI
1 parent 7f802f3 commit 7249158

File tree

5 files changed

+126
-2
lines changed

5 files changed

+126
-2
lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# DynamicPPL Changelog
22

3+
## 0.37.2
4+
5+
Make the `resume_from` keyword work for multiple-chain (parallel) sampling as well.
6+
Prior to this version, it was silently ignored.
7+
Note that to get the correct behaviour you also need to have a recent version of MCMCChains (v7.2.1).
8+
39
## 0.37.1
410

511
Update DynamicPPLMooncakeExt to work with Mooncake 0.4.147.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.37.1"
3+
version = "0.37.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/sampler.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,23 @@ function AbstractMCMC.sample(
9595
)
9696
end
9797

98+
function AbstractMCMC.sample(
99+
rng::Random.AbstractRNG,
100+
model::Model,
101+
sampler::Sampler,
102+
parallel::AbstractMCMC.AbstractMCMCEnsemble,
103+
N::Integer,
104+
nchains::Integer;
105+
chain_type=default_chain_type(sampler),
106+
resume_from=nothing,
107+
initial_state=loadstate(resume_from),
108+
kwargs...,
109+
)
110+
return AbstractMCMC.mcmcsample(
111+
rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs...
112+
)
113+
end
114+
98115
# initial step: general interface for resuming and
99116
function AbstractMCMC.step(
100117
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ EnzymeCore = "0.6 - 0.8"
4343
ForwardDiff = "0.10.12, 1"
4444
JET = "0.9, 0.10"
4545
LogDensityProblems = "2"
46-
MCMCChains = "6.0.4, 7"
46+
MCMCChains = "7.2.1"
4747
MacroTools = "0.5.6"
4848
OrderedCollections = "1"
4949
ReverseDiff = "1"

test/sampler.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,105 @@
11
@testset "sampler.jl" begin
2+
@testset "initial_state and resume_from kwargs" begin
3+
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
4+
# overloaded method.
5+
@model f() = x ~ Normal()
6+
model = f()
7+
# This sampler just returns the state it was given as its 'sample'
8+
struct S <: AbstractMCMC.AbstractSampler end
9+
function AbstractMCMC.step(
10+
rng::Random.AbstractRNG,
11+
model::Model,
12+
sampler::Sampler{<:S},
13+
state=nothing;
14+
kwargs...,
15+
)
16+
if state === nothing
17+
s = rand()
18+
return s, s
19+
else
20+
return state, state
21+
end
22+
end
23+
spl = Sampler(S())
24+
25+
function AbstractMCMC.bundle_samples(
26+
samples::Vector{Float64},
27+
model::Model,
28+
sampler::Sampler{<:S},
29+
state,
30+
chain_type::Type{MCMCChains.Chains};
31+
kwargs...,
32+
)
33+
return MCMCChains.Chains(samples, [:x]; info=(samplerstate=state,))
34+
end
35+
36+
N_iters, N_chains = 10, 3
37+
38+
@testset "single-chain sampling" begin
39+
chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains)
40+
initial_value = chn[:x][1]
41+
@test all(chn[:x] .== initial_value) # sanity check
42+
# using `initial_state`
43+
chn2 = sample(
44+
model,
45+
spl,
46+
N_iters;
47+
progress=false,
48+
initial_state=chn.info.samplerstate,
49+
chain_type=MCMCChains.Chains,
50+
)
51+
@test all(chn2[:x] .== initial_value)
52+
# using `resume_from`
53+
chn3 = sample(
54+
model,
55+
spl,
56+
N_iters;
57+
progress=false,
58+
resume_from=chn,
59+
chain_type=MCMCChains.Chains,
60+
)
61+
@test all(chn3[:x] .== initial_value)
62+
end
63+
64+
@testset "multiple-chain sampling" begin
65+
chn = sample(
66+
model,
67+
spl,
68+
MCMCThreads(),
69+
N_iters,
70+
N_chains;
71+
progress=false,
72+
chain_type=MCMCChains.Chains,
73+
)
74+
initial_value = chn[:x][1, :]
75+
@test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check
76+
# using `initial_state`
77+
chn2 = sample(
78+
model,
79+
spl,
80+
MCMCThreads(),
81+
N_iters,
82+
N_chains;
83+
progress=false,
84+
initial_state=chn.info.samplerstate,
85+
chain_type=MCMCChains.Chains,
86+
)
87+
@test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters)
88+
# using `resume_from`
89+
chn3 = sample(
90+
model,
91+
spl,
92+
MCMCThreads(),
93+
N_iters,
94+
N_chains;
95+
progress=false,
96+
resume_from=chn,
97+
chain_type=MCMCChains.Chains,
98+
)
99+
@test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters)
100+
end
101+
end
102+
2103
@testset "SampleFromPrior and SampleUniform" begin
3104
@model function gdemo(x, y)
4105
s ~ InverseGamma(2, 3)

0 commit comments

Comments
 (0)