Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# DynamicPPL Changelog

## 0.37.2

Make the `resume_from` keyword work for multiple-chain (parallel) sampling as well.
Prior to this version, it was silently ignored.
Note that to get the correct behaviour you also need to have a recent version of MCMCChains (v7.2.1).

## 0.37.1

Update DynamicPPLMooncakeExt to work with Mooncake 0.4.147.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.37.1"
version = "0.37.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
17 changes: 17 additions & 0 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ function AbstractMCMC.sample(
)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::Model,
sampler::Sampler,
parallel::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
initial_state=loadstate(resume_from),
Comment on lines +106 to +107
Copy link
Member Author

@penelopeysm penelopeysm Sep 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now if you call sample(..., MCMCThreads(), ...; resume_from=chn then initial_state will be correctly loaded before being sent to AbstractMCMC.

The only requirement for this to work is then that initial_state must be a vector containing exactly nchains final states. For MCMCChains this is not currently true, but it will be fixed by TuringLang/MCMCChains.jl#488, hence the need for that PR.

kwargs...,
)
return AbstractMCMC.mcmcsample(
rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs...
)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10.12, 1"
JET = "0.9, 0.10"
LogDensityProblems = "2"
MCMCChains = "6.0.4, 7"
MCMCChains = "7.2.1"
MacroTools = "0.5.6"
OrderedCollections = "1"
ReverseDiff = "1"
Expand Down
101 changes: 101 additions & 0 deletions test/sampler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,105 @@
@testset "sampler.jl" begin
@testset "initial_state and resume_from kwargs" begin
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
# overloaded method.
@model f() = x ~ Normal()
model = f()
# This sampler just returns the state it was given as its 'sample'
struct S <: AbstractMCMC.AbstractSampler end
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
sampler::Sampler{<:S},
state=nothing;
kwargs...,
)
if state === nothing
s = rand()
return s, s
else
return state, state
end
end
spl = Sampler(S())

function AbstractMCMC.bundle_samples(
samples::Vector{Float64},
model::Model,
sampler::Sampler{<:S},
state,
chain_type::Type{MCMCChains.Chains};
kwargs...,
)
return MCMCChains.Chains(samples, [:x]; info=(samplerstate=state,))
end

N_iters, N_chains = 10, 3

@testset "single-chain sampling" begin
chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains)
initial_value = chn[:x][1]
@test all(chn[:x] .== initial_value) # sanity check
# using `initial_state`
chn2 = sample(
model,
spl,
N_iters;
progress=false,
initial_state=chn.info.samplerstate,
chain_type=MCMCChains.Chains,
)
@test all(chn2[:x] .== initial_value)
# using `resume_from`
chn3 = sample(
model,
spl,
N_iters;
progress=false,
resume_from=chn,
chain_type=MCMCChains.Chains,
)
@test all(chn3[:x] .== initial_value)
end

@testset "multiple-chain sampling" begin
chn = sample(
model,
spl,
MCMCThreads(),
N_iters,
N_chains;
progress=false,
chain_type=MCMCChains.Chains,
)
initial_value = chn[:x][1, :]
@test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check
# using `initial_state`
chn2 = sample(
model,
spl,
MCMCThreads(),
N_iters,
N_chains;
progress=false,
initial_state=chn.info.samplerstate,
chain_type=MCMCChains.Chains,
)
@test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters)
# using `resume_from`
chn3 = sample(
model,
spl,
MCMCThreads(),
N_iters,
N_chains;
progress=false,
resume_from=chn,
chain_type=MCMCChains.Chains,
)
@test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters)
end
end

@testset "SampleFromPrior and SampleUniform" begin
@model function gdemo(x, y)
s ~ InverseGamma(2, 3)
Expand Down
Loading