diff --git a/HISTORY.md b/HISTORY.md index c0db1cd5d..0f22721d9 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,11 @@ # DynamicPPL Changelog +## 0.37.2 + +Make the `resume_from` keyword work for multiple-chain (parallel) sampling as well. +Prior to this version, it was silently ignored. +Note that to get the correct behaviour you also need to have a recent version of MCMCChains (v7.2.1). + ## 0.37.1 Update DynamicPPLMooncakeExt to work with Mooncake 0.4.147. diff --git a/Project.toml b/Project.toml index b24c88523..5f11cba3f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.37.1" +version = "0.37.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/sampler.jl b/src/sampler.jl index 673b5128f..27b990336 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -95,6 +95,23 @@ function AbstractMCMC.sample( ) end +function AbstractMCMC.sample( + rng::Random.AbstractRNG, + model::Model, + sampler::Sampler, + parallel::AbstractMCMC.AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + chain_type=default_chain_type(sampler), + resume_from=nothing, + initial_state=loadstate(resume_from), + kwargs..., +) + return AbstractMCMC.mcmcsample( + rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs... + ) +end + # initial step: general interface for resuming and function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... diff --git a/test/Project.toml b/test/Project.toml index 6da3786f5..91a885e96 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" JET = "0.9, 0.10" LogDensityProblems = "2" -MCMCChains = "6.0.4, 7" +MCMCChains = "7.2.1" MacroTools = "0.5.6" OrderedCollections = "1" ReverseDiff = "1" diff --git a/test/sampler.jl b/test/sampler.jl index fe9fd331a..5eb0da057 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -1,4 +1,105 @@ @testset "sampler.jl" begin + @testset "initial_state and resume_from kwargs" begin + # Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our + # overloaded method. + @model f() = x ~ Normal() + model = f() + # This sampler just returns the state it was given as its 'sample' + struct S <: AbstractMCMC.AbstractSampler end + function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::Model, + sampler::Sampler{<:S}, + state=nothing; + kwargs..., + ) + if state === nothing + s = rand() + return s, s + else + return state, state + end + end + spl = Sampler(S()) + + function AbstractMCMC.bundle_samples( + samples::Vector{Float64}, + model::Model, + sampler::Sampler{<:S}, + state, + chain_type::Type{MCMCChains.Chains}; + kwargs..., + ) + return MCMCChains.Chains(samples, [:x]; info=(samplerstate=state,)) + end + + N_iters, N_chains = 10, 3 + + @testset "single-chain sampling" begin + chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains) + initial_value = chn[:x][1] + @test all(chn[:x] .== initial_value) # sanity check + # using `initial_state` + chn2 = sample( + model, + spl, + N_iters; + progress=false, + initial_state=chn.info.samplerstate, + chain_type=MCMCChains.Chains, + ) + @test all(chn2[:x] .== initial_value) + # using `resume_from` + chn3 = sample( + model, + spl, + N_iters; + progress=false, + resume_from=chn, + chain_type=MCMCChains.Chains, + ) + @test all(chn3[:x] .== initial_value) + end + + @testset "multiple-chain sampling" begin + chn = sample( + model, + spl, + MCMCThreads(), + N_iters, + N_chains; + progress=false, + chain_type=MCMCChains.Chains, + ) + initial_value = chn[:x][1, :] + @test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check + # using `initial_state` + chn2 = sample( + model, + spl, + MCMCThreads(), + N_iters, + N_chains; + progress=false, + initial_state=chn.info.samplerstate, + chain_type=MCMCChains.Chains, + ) + @test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters) + # using `resume_from` + chn3 = sample( + model, + spl, + MCMCThreads(), + N_iters, + N_chains; + progress=false, + resume_from=chn, + chain_type=MCMCChains.Chains, + ) + @test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters) + end + end + @testset "SampleFromPrior and SampleUniform" begin @model function gdemo(x, y) s ~ InverseGamma(2, 3)