Skip to content

Commit 25c6513

Browse files
committed
Fix everything (I _think_)
1 parent 3afd807 commit 25c6513

File tree

4 files changed

+41
-48
lines changed

4 files changed

+41
-48
lines changed

test/mcmc/ess.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module ESSTests
22

33
using ..Models: MoGtest, MoGtest_default, gdemo, gdemo_default
44
using ..NumericalTests: check_MoGtest_default, check_numerical
5-
using ..SamplerTestUtils: check_rng_respected
5+
using ..SamplerTestUtils: test_rng_respected
66
using Distributions: Normal, sample
77
using DynamicPPL: DynamicPPL
88
using DynamicPPL: Sampler
@@ -40,9 +40,9 @@ using Turing
4040
end
4141

4242
@testset "RNG is respected" begin
43-
check_rng_respected(ESS())
44-
check_rng_respected(Gibbs(:x => ESS(), :y => MH()))
45-
check_rng_respected(Gibbs(:x => ESS(), :y => ESS()))
43+
test_rng_respected(ESS())
44+
test_rng_respected(Gibbs(:x => ESS(), :y => MH()))
45+
test_rng_respected(Gibbs(:x => ESS(), :y => ESS()))
4646
end
4747

4848
@testset "ESS inference" begin

test/mcmc/external_sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ end
174174

175175
function test_initial_params(model, sampler; kwargs...)
176176
# Generate some parameters.
177-
dict = DynamicPPL.values_as(VarInfo(model), Dict)
177+
dict = DynamicPPL.values_as(DynamicPPL.VarInfo(model), Dict)
178178
init_strategy = DynamicPPL.InitFromParams(dict)
179179

180180
# Execute the transition with two different RNGs and check that the resulting

test/mcmc/is.jl

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,52 @@
11
module ISTests
22

3-
using Distributions: Normal, sample
43
using DynamicPPL: logpdf
54
using Random: Random
5+
using StableRNGs: StableRNG
66
using StatsFuns: logsumexp
77
using Test: @test, @testset
88
using Turing
99

1010
@testset "is.jl" begin
11-
function reference(n)
12-
as = Vector{Float64}(undef, n)
13-
bs = Vector{Float64}(undef, n)
14-
logps = Vector{Float64}(undef, n)
11+
@testset "numerical accuracy" begin
12+
function reference(n)
13+
rng = StableRNG(468)
14+
as = Vector{Float64}(undef, n)
15+
bs = Vector{Float64}(undef, n)
1516

16-
for i in 1:n
17-
as[i], bs[i], logps[i] = reference()
17+
for i in 1:n
18+
as[i] = rand(rng, Normal(4, 5))
19+
bs[i] = rand(rng, Normal(as[i], 1))
20+
end
21+
# logevidence = logsumexp(logps) - log(n)
22+
return (as=as, bs=bs)
1823
end
19-
logevidence = logsumexp(logps) - log(n)
2024

21-
return (as=as, bs=bs, logps=logps, logevidence=logevidence)
22-
end
23-
24-
function reference()
25-
x = rand(Normal(4, 5))
26-
y = rand(Normal(x, 1))
27-
loglik = logpdf(Normal(x, 2), 3) + logpdf(Normal(y, 2), 1.5)
28-
return x, y, loglik
29-
end
30-
31-
@model function normal()
32-
a ~ Normal(4, 5)
33-
3 ~ Normal(a, 2)
34-
b ~ Normal(a, 1)
35-
1.5 ~ Normal(b, 2)
36-
return a, b
37-
end
38-
39-
alg = IS()
40-
seed = 0
41-
n = 10
25+
@model function normal()
26+
a ~ Normal(4, 5)
27+
3 ~ Normal(a, 2)
28+
b ~ Normal(a, 1)
29+
1.5 ~ Normal(b, 2)
30+
return a, b
31+
end
4232

43-
model = normal()
44-
for i in 1:100
45-
Random.seed!(seed)
46-
ref = reference(n)
33+
function expected_loglikelihoods(as, bs)
34+
return logpdf.(Normal.(as, 2), 3) .+ logpdf.(Normal.(bs, 2), 1.5)
35+
end
4736

48-
Random.seed!(seed)
49-
chain = sample(model, alg, n; check_model=false)
50-
sampled = get(chain, [:a, :b, :loglikelihood])
37+
alg = IS()
38+
N = 1000
39+
model = normal()
40+
chain = sample(StableRNG(468), model, alg, N)
41+
ref = reference(N)
5142

52-
@test vec(sampled.a) == ref.as
53-
@test vec(sampled.b) == ref.bs
54-
@test vec(sampled.loglikelihood) == ref.logps
55-
@test chain.logevidence == ref.logevidence
43+
@test isapprox(mean(chain[:a]), mean(ref.as); atol=0.1)
44+
@test isapprox(mean(chain[:b]), mean(ref.bs); atol=0.1)
45+
@test isapprox(chain[:loglikelihood], expected_loglikelihoods(chain[:a], chain[:b]))
46+
@test isapprox(chain.logevidence, logsumexp(chain[:loglikelihood]) - log(N))
5647
end
5748

5849
@testset "logevidence" begin
59-
Random.seed!(100)
60-
6150
@model function test()
6251
a ~ Normal(0, 1)
6352
x ~ Bernoulli(1)

test/test_utils/sampler.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ function test_chain_logp_metadata(spl)
2525
@test chn[:lp] chn[:logprior] + chn[:loglikelihood]
2626
end
2727

28+
"""
29+
Check that sampling is deterministic when using the same RNG seed.
30+
"""
2831
function test_rng_respected(spl)
2932
@model function f(z)
3033
# put at least two variables here so that we can meaningfully test Gibbs
@@ -35,7 +38,8 @@ function test_rng_respected(spl)
3538
model = f(2.0)
3639
chn1 = sample(Xoshiro(468), model, spl, 100)
3740
chn2 = sample(Xoshiro(468), model, spl, 100)
38-
@test chn1 == chn2
41+
@test isapprox(chn1[:x], chn2[:x])
42+
@test isapprox(chn1[:y], chn2[:y])
3943
end
4044

4145
end

0 commit comments

Comments
 (0)