Skip to content

Commit c315993

Browse files
committed
fix ESS reproducibility
1 parent b5d82c9 commit c315993

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

src/mcmc/ess.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
8282

8383
# Only define out-of-place sampling
8484
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
85-
_, vi = DynamicPPL.init!!(p.model, p.varinfo, DynamicPPL.InitFromPrior())
85+
_, vi = DynamicPPL.init!!(rng, p.model, p.varinfo, DynamicPPL.InitFromPrior())
8686
return vi[:]
8787
end
8888

test/mcmc/ess.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module ESSTests
22

33
using ..Models: MoGtest, MoGtest_default, gdemo, gdemo_default
44
using ..NumericalTests: check_MoGtest_default, check_numerical
5+
using ..SamplerTestUtils: check_rng_respected
56
using Distributions: Normal, sample
67
using DynamicPPL: DynamicPPL
78
using DynamicPPL: Sampler
@@ -38,6 +39,12 @@ using Turing
3839
c3 = sample(gdemo_default, s2, N)
3940
end
4041

42+
@testset "RNG is respected" begin
43+
check_rng_respected(ESS())
44+
check_rng_respected(Gibbs(:x => ESS(), :y => MH()))
45+
check_rng_respected(Gibbs(:x => ESS(), :y => ESS()))
46+
end
47+
4148
@testset "ESS inference" begin
4249
@info "Starting ESS inference tests"
4350
seed = 23

test/test_utils/sampler.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module SamplerTestUtils
22

3+
using Random
34
using Turing
45
using Test
56

@@ -24,4 +25,17 @@ function test_chain_logp_metadata(spl)
2425
@test chn[:lp] chn[:logprior] + chn[:loglikelihood]
2526
end
2627

28+
function test_rng_respected(spl)
29+
@model function f(z)
30+
# put at least two variables here so that we can meaningfully test Gibbs
31+
x ~ Normal()
32+
y ~ Normal()
33+
return z ~ Normal(x + y)
34+
end
35+
model = f(2.0)
36+
chn1 = sample(Xoshiro(468), model, spl, 100)
37+
chn2 = sample(Xoshiro(468), model, spl, 100)
38+
@test chn1 == chn2
39+
end
40+
2741
end

0 commit comments

Comments
 (0)