Skip to content

Commit f9353f0

Browse files
authored
Fix non-reproducible step sizes (#1924)
* Fix non-reproducible step sizes * Add test
1 parent 61b06f6 commit f9353f0

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.23.1"
3+
version = "0.23.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/inference/hmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ function DynamicPPL.initialstep(
186186

187187
# Find good eps if not provided one
188188
if iszero(spl.alg.ϵ)
189-
ϵ = AHMC.find_good_stepsize(hamiltonian, theta)
189+
ϵ = AHMC.find_good_stepsize(rng, hamiltonian, theta)
190190
@info "Found initial step size" ϵ
191191
else
192192
ϵ = spl.alg.ϵ
@@ -550,7 +550,7 @@ function HMCState(
550550

551551
# Find good eps if not provided one
552552
if iszero(spl.alg.ϵ)
553-
ϵ = AHMC.find_good_stepsize(h, θ_init)
553+
ϵ = AHMC.find_good_stepsize(rng, h, θ_init)
554554
@info "Found initial step size" ϵ
555555
else
556556
ϵ = spl.alg.ϵ

test/inference/hmc.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,4 +207,13 @@
207207
end
208208
@test sample(rng, mwe3(), HMC(0.2, 4), 1_000) isa Chains
209209
end
210+
211+
# issue #1923
212+
@turing_testset "reproducibility" begin
213+
alg = NUTS(1000, 0.8)
214+
res1 = sample(StableRNG(123), gdemo_default, alg, 1000)
215+
res2 = sample(StableRNG(123), gdemo_default, alg, 1000)
216+
res3 = sample(StableRNG(123), gdemo_default, alg, 1000)
217+
@test Array(res1) == Array(res2) == Array(res3)
218+
end
210219
end

0 commit comments

Comments
 (0)