Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# DynamicPPL Changelog

## 0.37.2

Make the `resume_from` keyword work for multiple-chain (parallel) sampling as well.
Prior to this version, it was silently ignored.
Note that to get the correct behaviour you also need to have a recent version of MCMCChains (v7.2.1).

## 0.37.1

Update DynamicPPLMooncakeExt to work with Mooncake 0.4.147.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.37.1"
version = "0.37.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
17 changes: 17 additions & 0 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ function AbstractMCMC.sample(
)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::Model,
sampler::Sampler,
parallel::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
initial_state=loadstate(resume_from),
Comment on lines +106 to +107
Copy link
Member Author

@penelopeysm penelopeysm Sep 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now if you call sample(..., MCMCThreads(), ...; resume_from=chn then initial_state will be correctly loaded before being sent to AbstractMCMC.

The only requirement for this to work is then that initial_state must be a vector containing exactly nchains final states. For MCMCChains this is not currently true, but it will be fixed by TuringLang/MCMCChains.jl#488, hence the need for that PR.

kwargs...,
)
return AbstractMCMC.mcmcsample(
rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs...
)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10.12, 1"
JET = "0.9, 0.10"
LogDensityProblems = "2"
MCMCChains = "6.0.4, 7"
MCMCChains = "7.2.1"
MacroTools = "0.5.6"
OrderedCollections = "1"
ReverseDiff = "1"
Expand Down
101 changes: 101 additions & 0 deletions test/sampler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,105 @@
@testset "sampler.jl" begin
@testset "initial_state and resume_from kwargs" begin
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
# overloaded method.
@model f() = x ~ Normal()
model = f()
# This sampler just returns the state it was given as its 'sample'
struct S <: AbstractMCMC.AbstractSampler end
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
sampler::Sampler{<:S},
state=nothing;
kwargs...,
)
if state === nothing
s = rand()
return s, s
else
return state, state
end
end
spl = Sampler(S())

function AbstractMCMC.bundle_samples(
samples::Vector{Float64},
model::Model,
sampler::Sampler{<:S},
state,
chain_type::Type{MCMCChains.Chains};
kwargs...,
)
return MCMCChains.Chains(samples, [:x]; info=(samplerstate=state,))
end

N_iters, N_chains = 10, 3

@testset "single-chain sampling" begin
chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains)
initial_value = chn[:x][1]
@test all(chn[:x] .== initial_value) # sanity check
# using `initial_state`
chn2 = sample(
model,
spl,
N_iters;
progress=false,
initial_state=chn.info.samplerstate,
chain_type=MCMCChains.Chains,
)
@test all(chn2[:x] .== initial_value)
# using `resume_from`
chn3 = sample(
model,
spl,
N_iters;
progress=false,
resume_from=chn,
chain_type=MCMCChains.Chains,
)
@test all(chn3[:x] .== initial_value)
end

@testset "multiple-chain sampling" begin
chn = sample(
model,
spl,
MCMCThreads(),
N_iters,
N_chains;
progress=false,
chain_type=MCMCChains.Chains,
)
initial_value = chn[:x][1, :]
@test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check
# using `initial_state`
chn2 = sample(
model,
spl,
MCMCThreads(),
N_iters,
N_chains;
progress=false,
initial_state=chn.info.samplerstate,
chain_type=MCMCChains.Chains,
)
@test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters)
# using `resume_from`
chn3 = sample(
model,
spl,
MCMCThreads(),
N_iters,
N_chains;
progress=false,
resume_from=chn,
chain_type=MCMCChains.Chains,
)
@test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters)
end
end

@testset "SampleFromPrior and SampleUniform" begin
@model function gdemo(x, y)
s ~ InverseGamma(2, 3)
Expand Down
Loading