diff --git a/HISTORY.md b/HISTORY.md index 2cdd2a644..23a686a73 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,12 @@ +# 0.41.0 + +HMC and NUTS samplers no longer take an extra single step before starting the chain. +This means that if you do not discard any samples at the start, the first sample will be the initial parameters (which may be user-provided). + +Note that if the initial sample is included, the corresponding sampler statistics will be `missing`. +Due to a technical limitation of MCMCChains.jl, this causes all indexing into MCMCChains to return `Union{Float64, Missing}` or similar. +If you want the old behaviour, you can discard the first sample (e.g. using `discard_initial=1`). + # 0.40.3 This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains. diff --git a/Project.toml b/Project.toml index e047098f9..e679949d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.40.3" +version = "0.41.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 6ff975d4d..363508e70 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -216,32 +216,12 @@ function DynamicPPL.initialstep( else ϵ = spl.alg.ϵ end - - # Generate a kernel. + # Generate a kernel and adaptor. kernel = make_ahmc_kernel(spl.alg, ϵ) - - # Create initial transition and state. - # Already perform one step since otherwise we don't get any statistics. - t = AHMC.transition(rng, hamiltonian, kernel, z) - - # Adaptation adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ) - if spl.alg isa AdaptiveHamiltonian - hamiltonian, kernel, _ = AHMC.adapt!( - hamiltonian, kernel, adaptor, 1, nadapts, t.z.θ, t.stat.acceptance_rate - ) - end - - # Update VarInfo parameters based on acceptance - new_params = if t.stat.is_accept - t.z.θ - else - theta - end - vi = DynamicPPL.unflatten(vi, new_params) - transition = Transition(model, vi, t) - state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) + transition = Transition(model, vi, NamedTuple()) + state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor) return transition, state end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 9f69a2de5..0bffda17e 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -41,7 +41,9 @@ using Turing Random.seed!(5) chain2 = sample(model, sampler, MCMCThreads(), 10, 4) - @test chain1.value == chain2.value + # For HMC, the first step does not have stats, so we need to use isequal to + # avoid comparing `missing`s + @test isequal(chain1.value, chain2.value) end # Should also be stable with an explicit RNG @@ -54,7 +56,7 @@ using Turing Random.seed!(rng, local_seed) chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4) - @test chain1.value == chain2.value + @test isequal(chain1.value, chain2.value) end end @@ -608,8 +610,8 @@ using Turing @testset "names_values" begin ks, xs = Turing.Inference.names_values([(a=1,), (b=2,), (a=3, b=4)]) - @test all(xs[:, 1] .=== [1, missing, 3]) - @test all(xs[:, 2] .=== [missing, 2, 4]) + @test isequal(xs[:, 1], [1, missing, 3]) + @test isequal(xs[:, 2], [missing, 2, 4]) end @testset "check model" begin diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 428c193ca..3328838a9 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -171,6 +171,28 @@ using Turing @test Array(res1) == Array(res2) == Array(res3) end + @testset "initial params are respected" begin + @model demo_norm() = x ~ Beta(2, 2) + init_x = 0.5 + @testset "$spl_name" for (spl_name, spl) in + (("HMC", HMC(0.1, 10)), ("NUTS", NUTS())) + chain = sample( + demo_norm(), spl, 5; discard_adapt=false, initial_params=(x=init_x,) + ) + @test chain[:x][1] == init_x + chain = sample( + demo_norm(), + spl, + MCMCThreads(), + 5, + 5; + discard_adapt=false, + initial_params=(fill((x=init_x,), 5)), + ) + @test all(chain[:x][1, :] .== init_x) + end + end + @testset "warning for difficult init params" begin attempt = 0 @model function demo_warn_initial_params() diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl index d2ca427df..d848627d7 100644 --- a/test/mcmc/repeat_sampler.jl +++ b/test/mcmc/repeat_sampler.jl @@ -35,7 +35,8 @@ using Turing num_chains; chain_type=Chains, ) - @test chn1.value == chn2.value + # isequal to avoid comparing `missing`s in chain stats + @test isequal(chn1.value, chn2.value) end end diff --git a/test/stdlib/distributions.jl b/test/stdlib/distributions.jl index 0f8a7c718..e6ce5794d 100644 --- a/test/stdlib/distributions.jl +++ b/test/stdlib/distributions.jl @@ -51,8 +51,6 @@ using Turing end @testset "single distribution correctness" begin - rng = StableRNG(1) - n_samples = 10_000 mean_tol = 0.1 var_atol = 1.0 @@ -132,7 +130,7 @@ using Turing @model m() = x ~ dist - chn = sample(rng, m(), HMC(0.05, 20), n_samples) + chn = sample(StableRNG(468), m(), HMC(0.05, 20), n_samples) # Numerical tests. check_dist_numerical(