diff --git a/HISTORY.md b/HISTORY.md index d7334c60b..2cdd2a644 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,12 @@ +# 0.40.3 + +This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains. + +In the process this also fixes a method ambiguity caused by a bugfix in DynamicPPL 0.37.2. + +This patch means that if you are using `RepeatSampler()` to sample from a model, and you want to obtain `MCMCChains.Chains` from it, you need to specify `sample(...; chain_type=MCMCChains.Chains)`. +This only applies if the sampler itself is a `RepeatSampler`; it doesn't apply if you are using `RepeatSampler` _within_ another sampler like Gibbs. + # 0.40.2 `sample(model, NUTS(), N; verbose=false)` now suppresses the 'initial step size' message. diff --git a/Project.toml b/Project.toml index 422284095..e047098f9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.40.2" +version = "0.40.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -64,7 +64,7 @@ Distributions = "0.25.77" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.37" +DynamicPPL = "0.37.2" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" Libtask = "0.9.3" diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index 4522875b4..edd563885 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -57,27 +57,3 @@ function AbstractMCMC.sample( check_model && _check_model(model, alg) return AbstractMCMC.sample(rng, model, Sampler(alg), ensemble, N, n_chains; kwargs...) end - -function AbstractMCMC.sample( - rng::AbstractRNG, - model::AbstractModel, - sampler::Union{Sampler{<:InferenceAlgorithm},RepeatSampler}, - ensemble::AbstractMCMC.AbstractMCMCEnsemble, - N::Integer, - n_chains::Integer; - chain_type=MCMCChains.Chains, - progress=PROGRESS[], - kwargs..., -) - return AbstractMCMC.mcmcsample( - rng, - model, - sampler, - ensemble, - N, - n_chains; - chain_type=chain_type, - progress=progress, - kwargs..., - ) -end diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl index 3145e6c99..fa2eca96d 100644 --- a/src/mcmc/repeat_sampler.jl +++ b/src/mcmc/repeat_sampler.jl @@ -28,11 +28,6 @@ function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int) return RepeatSampler(Sampler(alg), num_repeat) end -getADType(spl::RepeatSampler) = getADType(spl.sampler) -DynamicPPL.default_chain_type(sampler::RepeatSampler) = default_chain_type(sampler.sampler) -# TODO(mhauru) Remove the below once DynamicPPL has removed all its Selector stuff. -DynamicPPL.inspace(vn::VarName, spl::RepeatSampler) = inspace(vn, spl.sampler) - function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params) return setparams_varinfo!!(model, sampler.sampler, state, params) end diff --git a/test/Project.toml b/test/Project.toml index b10be0140..ba7a83be1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -53,14 +53,14 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.37" +DynamicPPL = "0.37.2" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" LinearAlgebra = "1" LogDensityProblems = "2" LogDensityProblemsAD = "1.4" -MCMCChains = "5, 6, 7" +MCMCChains = "7.3.0" NamedArrays = "0.9.4, 0.10" Optim = "1" Optimization = "3, 4" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 2cc7e4bc0..9f69a2de5 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -4,6 +4,7 @@ using ..Models: gdemo_d, gdemo_default using ..NumericalTests: check_gdemo, check_numerical using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample +using AbstractMCMC: AbstractMCMC import DynamicPPL using DynamicPPL: Sampler import ForwardDiff @@ -72,7 +73,48 @@ using Turing end end - @testset "chain save/resume" begin + @testset "save/resume correctly reloads state" begin + struct StaticSampler <: Turing.Inference.InferenceAlgorithm end + function DynamicPPL.initialstep( + rng, model, ::DynamicPPL.Sampler{<:StaticSampler}, vi; kwargs... + ) + return Turing.Inference.Transition(model, vi, nothing), vi + end + function AbstractMCMC.step( + rng, + model, + ::DynamicPPL.Sampler{<:StaticSampler}, + vi::DynamicPPL.AbstractVarInfo; + kwargs..., + ) + return Turing.Inference.Transition(model, vi, nothing), vi + end + + @model demo() = x ~ Normal() + + @testset "single-chain" begin + chn1 = sample(demo(), StaticSampler(), 10; save_state=true) + @test chn1.info.samplerstate isa DynamicPPL.AbstractVarInfo + chn2 = sample(demo(), StaticSampler(), 10; resume_from=chn1) + xval = chn1[:x][1] + @test all(chn2[:x] .== xval) + end + + @testset "multiple-chain" for nchains in [1, 3] + chn1 = sample( + demo(), StaticSampler(), MCMCThreads(), 10, nchains; save_state=true + ) + @test chn1.info.samplerstate isa AbstractVector{<:DynamicPPL.AbstractVarInfo} + @test length(chn1.info.samplerstate) == nchains + chn2 = sample( + demo(), StaticSampler(), MCMCThreads(), 10, nchains; resume_from=chn1 + ) + xval = chn1[:x][1, :] + @test all(i -> chn2[:x][i, :] == xval, 1:10) + end + end + + @testset "single-chain save/resume numerical accuracy" begin alg1 = HMCDA(1000, 0.65, 0.15) alg2 = PG(20) alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4)) diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl index 7328d1168..d2ca427df 100644 --- a/test/mcmc/repeat_sampler.jl +++ b/test/mcmc/repeat_sampler.jl @@ -2,6 +2,7 @@ module RepeatSamplerTests using ..Models: gdemo_default using DynamicPPL: Sampler +using MCMCChains: Chains using StableRNGs: StableRNG using Test: @test, @testset using Turing @@ -26,7 +27,13 @@ using Turing ) repeat_sampler = RepeatSampler(sampler, num_repeats) chn2 = sample( - copy(rng), gdemo_default, repeat_sampler, MCMCThreads(), num_samples, num_chains + copy(rng), + gdemo_default, + repeat_sampler, + MCMCThreads(), + num_samples, + num_chains; + chain_type=Chains, ) @test chn1.value == chn2.value end