Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 135 additions & 103 deletions test/mcmc/Inference.jl

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions test/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ end
DynamicPPL.TestUtils.test_sampler(
[model],
sampler_ext,
5_000;
2_000;
rtol=0.2,
sampler_name="AdvancedHMC",
sample_kwargs...,
Expand Down Expand Up @@ -187,7 +187,7 @@ end
DynamicPPL.TestUtils.test_sampler(
[model],
sampler_ext,
10_000;
2_000;
discard_initial=1_000,
thinning=10,
rtol=0.2,
Expand Down
80 changes: 45 additions & 35 deletions test/mcmc/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ using Distributions: Normal, sample
using DynamicPPL: DynamicPPL
using DynamicPPL: Sampler
using Random: Random
using StableRNGs: StableRNG
using Test: @test, @testset
using Turing

@testset "ESS" begin
@info "Starting ESS tests"

@model function demo(x)
m ~ Normal()
return x ~ Normal(m, 0.5)
Expand All @@ -24,8 +27,7 @@ using Turing
demodot_default = demodot(1.0)

@testset "ESS constructor" begin
Random.seed!(0)
N = 500
N = 10

s1 = ESS()
s2 = ESS(:m)
Expand All @@ -43,41 +45,49 @@ using Turing
end

@testset "ESS inference" begin
Random.seed!(1)
chain = sample(demo_default, ESS(), 5_000)
check_numerical(chain, [:m], [0.8]; atol=0.1)

Random.seed!(1)
chain = sample(demodot_default, ESS(), 5_000)
check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8]; atol=0.1)

Random.seed!(100)
alg = Gibbs(CSMC(15, :s), ESS(:m))
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)

# MoGtest
Random.seed!(125)
alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, alg, 6000)
check_MoGtest_default(chain; atol=0.1)

# Different "equivalent" models.
# NOTE: Because `ESS` only supports "single" variables with
# Gaussian priors, we restrict ourselves to this subspace by conditioning
# on the non-Gaussian variables in `DEMO_MODELS`.
models_conditioned = map(DynamicPPL.TestUtils.DEMO_MODELS) do model
# Condition on the non-Gaussian random variables.
model | (s=DynamicPPL.TestUtils.posterior_mean(model).s,)
@info "Starting ESS inference tests"
rng = StableRNG(23)

@testset "demo_default" begin
chain = sample(copy(rng), demo_default, ESS(), 5_000)
check_numerical(chain, [:m], [0.8]; atol=0.1)
end

@testset "demodot_default" begin
chain = sample(copy(rng), demodot_default, ESS(), 5_000)
check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8]; atol=0.1)
end

@testset "gdemo with CSMC + ESS" begin
alg = Gibbs(CSMC(15, :s), ESS(:m))
chain = sample(copy(rng), gdemo(1.5, 2.0), alg, 2000)
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)
end

@testset "MoGtest_default with CSMC + ESS" begin
alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
chain = sample(copy(rng), MoGtest_default, alg, 2000)
check_MoGtest_default(chain; atol=0.1)
end

DynamicPPL.TestUtils.test_sampler(
models_conditioned,
DynamicPPL.Sampler(ESS()),
10_000;
# Filter out the varnames we've conditioned on.
varnames_filter=vn -> DynamicPPL.getsym(vn) != :s,
)
@testset "TestModels" begin
# Different "equivalent" models.
# NOTE: Because `ESS` only supports "single" variables with
# Gaussian priors, we restrict ourselves to this subspace by conditioning
# on the non-Gaussian variables in `DEMO_MODELS`.
models_conditioned = map(DynamicPPL.TestUtils.DEMO_MODELS) do model
# Condition on the non-Gaussian random variables.
model | (s=DynamicPPL.TestUtils.posterior_mean(model).s,)
end

DynamicPPL.TestUtils.test_sampler(
models_conditioned,
DynamicPPL.Sampler(ESS()),
2000;
# Filter out the varnames we've conditioned on.
varnames_filter=vn -> DynamicPPL.getsym(vn) != :s,
)
end
end
end

Expand Down
82 changes: 40 additions & 42 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ using Test: @test, @test_logs, @testset, @test_throws
using Turing

@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends
# Set a seed
@info "Starting HMC tests with $adbackend"
rng = StableRNG(123)

@testset "constrained bounded" begin
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]

Expand All @@ -33,14 +34,15 @@ using Turing
end

chain = sample(
rng,
copy(rng),
constrained_test(obs),
HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5)
1000,
1_000,
)

check_numerical(chain, [:p], [10 / 14]; atol=0.1)
end

@testset "constrained simplex" begin
obs12 = [1, 2, 1, 2, 2, 2, 2, 2, 2, 2]

Expand All @@ -54,16 +56,18 @@ using Turing
end

chain = sample(
rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000
copy(rng), constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000
)

check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015)
end

@testset "hmc reverse diff" begin
alg = HMC(0.1, 10; adtype=adbackend)
res = sample(rng, gdemo_default, alg, 4000)
res = sample(copy(rng), gdemo_default, alg, 4_000)
check_gdemo(res; rtol=0.1)
end

@testset "matrix support" begin
@model function hmcmatrixsup()
return v ~ Wishart(7, [1 0.5; 0.5 1])
Expand All @@ -72,13 +76,15 @@ using Turing
model_f = hmcmatrixsup()
n_samples = 1_000
vs = map(1:3) do _
chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples)
chain = sample(copy(rng), model_f, HMC(0.15, 7; adtype=adbackend), n_samples)
r = reshape(Array(group(chain, :v)), n_samples, 2, 2)
reshape(mean(r; dims=1), 2, 2)
end

# TODO(mhauru) This test needs a comment explaining what is being tested.
@test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5
end

@testset "multivariate support" begin
# Define NN flow
function nn(x, b1, w11, w12, w13, bo, wo)
Expand Down Expand Up @@ -124,58 +130,48 @@ using Turing
end

# Sampling
chain = sample(rng, bnn(ts), HMC(0.1, 5; adtype=adbackend), 10)
chain = sample(copy(rng), bnn(ts), HMC(0.1, 5; adtype=adbackend), 10)
end

@testset "hmcda inference" begin
alg1 = HMCDA(500, 0.8, 0.015; adtype=adbackend)
# alg2 = Gibbs(HMCDA(200, 0.8, 0.35, :m; adtype=adbackend), HMC(0.25, 3, :s; adtype=adbackend))

# alg3 = Gibbs(HMC(0.25, 3, :m; adtype=adbackend), PG(30, 3, :s))
# alg3 = PG(50, 2000)

res1 = sample(rng, gdemo_default, alg1, 3000)
res1 = sample(copy(rng), gdemo_default, alg1, 3_000)
check_gdemo(res1)

# res2 = sample(gdemo([1.5, 2.0]), alg2)
#
# @test mean(res2[:s]) ≈ 49/24 atol=0.2
# @test mean(res2[:m]) ≈ 7/6 atol=0.2
end

# TODO(mhauru) The below one is a) slow, b) flaky, in that changing the seed can
# easily make it fail, despite many more samples than taken by most other tests. Hence
# explicitly specifying the seeds here.
@testset "hmcda+gibbs inference" begin
rng = StableRNG(123)
Random.seed!(12345) # particle samplers do not support user-provided `rng` yet
alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend))

res3 = sample(rng, gdemo_default, alg3, 3000; discard_initial=1000)
check_gdemo(res3)
Random.seed!(12345)
alg = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend))
res = sample(StableRNG(123), gdemo_default, alg, 3000; discard_initial=1000)
check_gdemo(res)
end

@testset "hmcda constructor" begin
alg = HMCDA(0.8, 0.75; adtype=adbackend)
println(alg)
sampler = Sampler(alg, gdemo_default)
@test DynamicPPL.alg_str(sampler) == "HMCDA"

alg = HMCDA(200, 0.8, 0.75; adtype=adbackend)
println(alg)
sampler = Sampler(alg, gdemo_default)
@test DynamicPPL.alg_str(sampler) == "HMCDA"

alg = HMCDA(200, 0.8, 0.75, :s; adtype=adbackend)
println(alg)
sampler = Sampler(alg, gdemo_default)
@test DynamicPPL.alg_str(sampler) == "HMCDA"

@test isa(alg, HMCDA)
@test isa(sampler, Sampler{<:Turing.Hamiltonian})
end

@testset "nuts inference" begin
alg = NUTS(1000, 0.8; adtype=adbackend)
res = sample(rng, gdemo_default, alg, 6000)
res = sample(copy(rng), gdemo_default, alg, 500)
check_gdemo(res)
end

@testset "nuts constructor" begin
alg = NUTS(200, 0.65; adtype=adbackend)
sampler = Sampler(alg, gdemo_default)
Expand All @@ -189,22 +185,24 @@ using Turing
sampler = Sampler(alg, gdemo_default)
@test DynamicPPL.alg_str(sampler) == "NUTS"
end

@testset "check discard" begin
alg = NUTS(100, 0.8; adtype=adbackend)

c1 = sample(rng, gdemo_default, alg, 500; discard_adapt=true)
c2 = sample(rng, gdemo_default, alg, 500; discard_adapt=false)
c1 = sample(copy(rng), gdemo_default, alg, 500; discard_adapt=true)
c2 = sample(copy(rng), gdemo_default, alg, 500; discard_adapt=false)

@test size(c1, 1) == 500
@test size(c2, 1) == 500
end

@testset "AHMC resize" begin
alg1 = Gibbs(PG(10, :m), NUTS(100, 0.65, :s; adtype=adbackend))
alg2 = Gibbs(PG(10, :m), HMC(0.1, 3, :s; adtype=adbackend))
alg3 = Gibbs(PG(10, :m), HMCDA(100, 0.65, 0.3, :s; adtype=adbackend))
@test sample(rng, gdemo_default, alg1, 300) isa Chains
@test sample(rng, gdemo_default, alg2, 300) isa Chains
@test sample(rng, gdemo_default, alg3, 300) isa Chains
@test sample(copy(rng), gdemo_default, alg1, 10) isa Chains
@test sample(copy(rng), gdemo_default, alg2, 10) isa Chains
@test sample(copy(rng), gdemo_default, alg3, 10) isa Chains
end

@testset "Regression tests" begin
Expand All @@ -213,28 +211,28 @@ using Turing
m = Matrix{T}(undef, 2, 3)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
@test sample(copy(rng), mwe1(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains

@model function mwe2(::Type{T}=Matrix{Float64}) where {T}
m = T(undef, 2, 3)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
@test sample(copy(rng), mwe2(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains

# https://github.com/TuringLang/Turing.jl/issues/1308
@model function mwe3(::Type{T}=Array{Float64}) where {T}
m = T(undef, 2, 3)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
@test sample(copy(rng), mwe3(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains
end

# issue #1923
@testset "reproducibility" begin
alg = NUTS(1000, 0.8; adtype=adbackend)
res1 = sample(StableRNG(123), gdemo_default, alg, 1000)
res2 = sample(StableRNG(123), gdemo_default, alg, 1000)
res3 = sample(StableRNG(123), gdemo_default, alg, 1000)
res1 = sample(copy(rng), gdemo_default, alg, 10)
res2 = sample(copy(rng), gdemo_default, alg, 10)
res3 = sample(copy(rng), gdemo_default, alg, 10)
@test Array(res1) == Array(res2) == Array(res3)
end

Expand All @@ -249,7 +247,7 @@ using Turing
gdemo_default_prior = DynamicPPL.contextualize(
demo_hmc_prior(), DynamicPPL.PriorContext()
)
chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0])
chain = sample(gdemo_default_prior, alg, 500; initial_params=[3.0, 0.0])
check_numerical(
chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2
)
Expand Down Expand Up @@ -288,7 +286,7 @@ using Turing
return xs[2] ~ Dirichlet(ones(5))
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
chain = sample(model, NUTS(), 1_000)
@test mean(Array(chain)) ≈ 0.2
end

Expand Down Expand Up @@ -335,7 +333,7 @@ using Turing
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
sample(rng, m, alg, 10)
sample(copy(rng), m, alg, 10)
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/is.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ using Turing
return x
end

chains = sample(test(), IS(), 10000)
chains = sample(test(), IS(), 1_000)

@test all(isone, chains[:x])
@test chains.logevidence ≈ -2 * log(2)
Expand Down
Loading
Loading