Skip to content

Commit ff8d01e

Browse files
authored
Do not take an initial step before starting the chain in HMC (#2674)
* Do not take an initial step before starting the chain in HMC * Fix some tests * [skip ci] update changelog
1 parent 03454c8 commit ff8d01e

File tree

6 files changed

+43
-31
lines changed

6 files changed

+43
-31
lines changed

HISTORY.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
# 0.41.0
2+
3+
HMC and NUTS samplers no longer take an extra single step before starting the chain.
4+
This means that if you do not discard any samples at the start, the first sample will consist of the initial parameters (which may be user-provided).
5+
6+
Note that if the initial sample is included, the corresponding sampler statistics will be `missing`.
7+
Due to a technical limitation of MCMCChains.jl, this causes all indexing into the resulting chain object to return `Union{Float64, Missing}` or similar.
8+
If you want the old behaviour, you can discard the first sample (e.g. using `discard_initial=1`).
9+
110
# 0.40.3
211

312
This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains.

src/mcmc/hmc.jl

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -216,32 +216,12 @@ function DynamicPPL.initialstep(
216216
else
217217
ϵ = spl.alg.ϵ
218218
end
219-
220-
# Generate a kernel.
219+
# Generate a kernel and adaptor.
221220
kernel = make_ahmc_kernel(spl.alg, ϵ)
222-
223-
# Create initial transition and state.
224-
# Already perform one step since otherwise we don't get any statistics.
225-
t = AHMC.transition(rng, hamiltonian, kernel, z)
226-
227-
# Adaptation
228221
adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
229-
if spl.alg isa AdaptiveHamiltonian
230-
hamiltonian, kernel, _ = AHMC.adapt!(
231-
hamiltonian, kernel, adaptor, 1, nadapts, t.z.θ, t.stat.acceptance_rate
232-
)
233-
end
234-
235-
# Update VarInfo parameters based on acceptance
236-
new_params = if t.stat.is_accept
237-
t.z.θ
238-
else
239-
theta
240-
end
241-
vi = DynamicPPL.unflatten(vi, new_params)
242222

243-
transition = Transition(model, vi, t)
244-
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
223+
transition = Transition(model, vi, NamedTuple())
224+
state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor)
245225

246226
return transition, state
247227
end

test/mcmc/Inference.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ using Turing
4141
Random.seed!(5)
4242
chain2 = sample(model, sampler, MCMCThreads(), 10, 4)
4343

44-
@test chain1.value == chain2.value
44+
# For HMC, the first step does not have stats, so we need to use isequal to
45+
# avoid comparing `missing`s
46+
@test isequal(chain1.value, chain2.value)
4547
end
4648

4749
# Should also be stable with an explicit RNG
@@ -54,7 +56,7 @@ using Turing
5456
Random.seed!(rng, local_seed)
5557
chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4)
5658

57-
@test chain1.value == chain2.value
59+
@test isequal(chain1.value, chain2.value)
5860
end
5961
end
6062

@@ -608,8 +610,8 @@ using Turing
608610

609611
@testset "names_values" begin
610612
ks, xs = Turing.Inference.names_values([(a=1,), (b=2,), (a=3, b=4)])
611-
@test all(xs[:, 1] .=== [1, missing, 3])
612-
@test all(xs[:, 2] .=== [missing, 2, 4])
613+
@test isequal(xs[:, 1], [1, missing, 3])
614+
@test isequal(xs[:, 2], [missing, 2, 4])
613615
end
614616

615617
@testset "check model" begin

test/mcmc/hmc.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,28 @@ using Turing
171171
@test Array(res1) == Array(res2) == Array(res3)
172172
end
173173

174+
@testset "initial params are respected" begin
175+
@model demo_norm() = x ~ Beta(2, 2)
176+
init_x = 0.5
177+
@testset "$spl_name" for (spl_name, spl) in
178+
(("HMC", HMC(0.1, 10)), ("NUTS", NUTS()))
179+
chain = sample(
180+
demo_norm(), spl, 5; discard_adapt=false, initial_params=(x=init_x,)
181+
)
182+
@test chain[:x][1] == init_x
183+
chain = sample(
184+
demo_norm(),
185+
spl,
186+
MCMCThreads(),
187+
5,
188+
5;
189+
discard_adapt=false,
190+
initial_params=(fill((x=init_x,), 5)),
191+
)
192+
@test all(chain[:x][1, :] .== init_x)
193+
end
194+
end
195+
174196
@testset "warning for difficult init params" begin
175197
attempt = 0
176198
@model function demo_warn_initial_params()

test/mcmc/repeat_sampler.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ using Turing
3535
num_chains;
3636
chain_type=Chains,
3737
)
38-
@test chn1.value == chn2.value
38+
# isequal to avoid comparing `missing`s in chain stats
39+
@test isequal(chn1.value, chn2.value)
3940
end
4041
end
4142

test/stdlib/distributions.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ using Turing
5151
end
5252

5353
@testset "single distribution correctness" begin
54-
rng = StableRNG(1)
55-
5654
n_samples = 10_000
5755
mean_tol = 0.1
5856
var_atol = 1.0
@@ -132,7 +130,7 @@ using Turing
132130

133131
@model m() = x ~ dist
134132

135-
chn = sample(rng, m(), HMC(0.05, 20), n_samples)
133+
chn = sample(StableRNG(468), m(), HMC(0.05, 20), n_samples)
136134

137135
# Numerical tests.
138136
check_dist_numerical(

0 commit comments

Comments
 (0)