Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# 0.40.3

This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains.

In the process this also fixes a method ambiguity caused by a bugfix in DynamicPPL 0.37.2.

This patch means that if you are using `RepeatSampler()` to sample from a model, and you want to obtain `MCMCChains.Chains` from it, you need to specify `sample(...; chain_type=MCMCChains.Chains)`.
This only applies if the sampler itself is a `RepeatSampler`; it doesn't apply if you are using `RepeatSampler` _within_ another sampler like Gibbs.

# 0.40.2

`sample(model, NUTS(), N; verbose=false)` now suppresses the 'initial step size' message.
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.40.2"
version = "0.40.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -64,7 +64,7 @@ Distributions = "0.25.77"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.37"
DynamicPPL = "0.37.2"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.3"
Expand Down
24 changes: 0 additions & 24 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,3 @@ function AbstractMCMC.sample(
check_model && _check_model(model, alg)
return AbstractMCMC.sample(rng, model, Sampler(alg), ensemble, N, n_chains; kwargs...)
end

function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
sampler::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
n_chains::Integer;
chain_type=MCMCChains.Chains,
progress=PROGRESS[],
kwargs...,
)
return AbstractMCMC.mcmcsample(
rng,
model,
sampler,
ensemble,
N,
n_chains;
chain_type=chain_type,
progress=progress,
kwargs...,
)
end
5 changes: 0 additions & 5 deletions src/mcmc/repeat_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int)
return RepeatSampler(Sampler(alg), num_repeat)
end

getADType(spl::RepeatSampler) = getADType(spl.sampler)
DynamicPPL.default_chain_type(sampler::RepeatSampler) = default_chain_type(sampler.sampler)
# TODO(mhauru) Remove the below once DynamicPPL has removed all its Selector stuff.
DynamicPPL.inspace(vn::VarName, spl::RepeatSampler) = inspace(vn, spl.sampler)

function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params)
return setparams_varinfo!!(model, sampler.sampler, state, params)
end
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ Combinatorics = "1"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.37"
DynamicPPL = "0.37.2"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
HypothesisTests = "0.11"
LinearAlgebra = "1"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
MCMCChains = "5, 6, 7"
MCMCChains = "7.3.0"
NamedArrays = "0.9.4, 0.10"
Optim = "1"
Optimization = "3, 4"
Expand Down
44 changes: 43 additions & 1 deletion test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ..Models: gdemo_d, gdemo_default
using ..NumericalTests: check_gdemo, check_numerical
using Distributions: Bernoulli, Beta, InverseGamma, Normal
using Distributions: sample
using AbstractMCMC: AbstractMCMC
import DynamicPPL
using DynamicPPL: Sampler
import ForwardDiff
Expand Down Expand Up @@ -72,7 +73,48 @@ using Turing
end
end

@testset "chain save/resume" begin
@testset "save/resume correctly reloads state" begin
struct StaticSampler <: Turing.Inference.InferenceAlgorithm end
function DynamicPPL.initialstep(
rng, model, ::DynamicPPL.Sampler{<:StaticSampler}, vi; kwargs...
)
return Turing.Inference.Transition(model, vi, nothing), vi
end
function AbstractMCMC.step(
rng,
model,
::DynamicPPL.Sampler{<:StaticSampler},
vi::DynamicPPL.AbstractVarInfo;
kwargs...,
)
return Turing.Inference.Transition(model, vi, nothing), vi
end

@model demo() = x ~ Normal()

@testset "single-chain" begin
chn1 = sample(demo(), StaticSampler(), 10; save_state=true)
@test chn1.info.samplerstate isa DynamicPPL.AbstractVarInfo
chn2 = sample(demo(), StaticSampler(), 10; resume_from=chn1)
xval = chn1[:x][1]
@test all(chn2[:x] .== xval)
end

@testset "multiple-chain" for nchains in [1, 3]
chn1 = sample(
demo(), StaticSampler(), MCMCThreads(), 10, nchains; save_state=true
)
@test chn1.info.samplerstate isa AbstractVector{<:DynamicPPL.AbstractVarInfo}
@test length(chn1.info.samplerstate) == nchains
chn2 = sample(
demo(), StaticSampler(), MCMCThreads(), 10, nchains; resume_from=chn1
)
xval = chn1[:x][1, :]
@test all(i -> chn2[:x][i, :] == xval, 1:10)
end
end

@testset "single-chain save/resume numerical accuracy" begin
alg1 = HMCDA(1000, 0.65, 0.15)
alg2 = PG(20)
alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4))
Expand Down
9 changes: 8 additions & 1 deletion test/mcmc/repeat_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module RepeatSamplerTests

using ..Models: gdemo_default
using DynamicPPL: Sampler
using MCMCChains: Chains
using StableRNGs: StableRNG
using Test: @test, @testset
using Turing
Expand All @@ -26,7 +27,13 @@ using Turing
)
repeat_sampler = RepeatSampler(sampler, num_repeats)
chn2 = sample(
copy(rng), gdemo_default, repeat_sampler, MCMCThreads(), num_samples, num_chains
copy(rng),
gdemo_default,
repeat_sampler,
MCMCThreads(),
num_samples,
num_chains;
chain_type=Chains,
)
@test chn1.value == chn2.value
end
Expand Down
Loading