3 changes: 2 additions & 1 deletion test/Project.toml
@@ -1,4 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
@@ -52,7 +53,7 @@ Combinatorics = "1"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.36"
DynamicPPL = "0.36.6"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
133 changes: 116 additions & 17 deletions test/test_utils/ad_utils.jl → test/ad.jl
@@ -1,18 +1,14 @@
module ADUtils
module TuringADTests

using ForwardDiff: ForwardDiff
using Pkg: Pkg
using Turing
using DynamicPPL
using DynamicPPL.TestUtils: DEMO_MODELS
using DynamicPPL.TestUtils.AD: run_ad
using Random: Random
using ReverseDiff: ReverseDiff
using Mooncake: Mooncake
using Test: Test
using Turing: Turing
using Turing: DynamicPPL

export ADTypeCheckContext, adbackends

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Stuff for checking that the right AD backend is being used.
using StableRNGs: StableRNG
using Test
using ..Models: gdemo_default
import ForwardDiff, ReverseDiff, Mooncake

"""Element types that are always valid for a VarInfo regardless of ADType."""
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
@@ -178,16 +174,119 @@ function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
return logp, vi
end

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# List of AD backends to test.

"""
All the ADTypes on which we want to run the tests.
"""
adbackends = [
ADTYPES = [
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(; compile=false),
Turing.AutoMooncake(; config=nothing),
]

# Check that ADTypeCheckContext itself works as expected.
@testset "ADTypeCheckContext" begin
@model test_model() = x ~ Normal(0, 1)
tm = test_model()
adtypes = (
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(),
# Don't need to test Mooncake as it doesn't use tracer types
)
for actual_adtype in adtypes
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
for expected_adtype in adtypes
contextualised_tm = DynamicPPL.contextualize(
tm, ADTypeCheckContext(expected_adtype, tm.context)
)
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
if actual_adtype == expected_adtype
# Check that this does not throw an error.
Turing.sample(contextualised_tm, sampler, 2)
else
@test_throws AbstractWrongADBackendError Turing.sample(
contextualised_tm, sampler, 2
)
end
end
end
end
end

@testset verbose = true "AD / ADTypeCheckContext" begin
# This testset ensures that samplers or optimisers don't accidentally
# override the AD backend that was set in them.
@testset "adtype=$adtype" for adtype in ADTYPES
seed = 123
alg = HMC(0.1, 10; adtype=adtype)
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
sample(StableRNG(seed), m, alg, 10)
maximum_likelihood(m; adtype=adtype)
maximum_a_posteriori(m; adtype=adtype)
end
end

@testset verbose = true "AD / SamplingContext" begin
# AD tests for gradient-based samplers need to be run with SamplingContext
# because samplers can potentially use this to define custom behaviour in
# the tilde-pipeline and thus change the code executed during model
# evaluation.
@testset "adtype=$adtype" for adtype in ADTYPES
@testset "alg=$alg" for alg in [
HMC(0.1, 10; adtype=adtype),
HMCDA(0.8, 0.75; adtype=adtype),
NUTS(1000, 0.8; adtype=adtype),
SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
]
@info "Testing AD for $alg"

@testset "model=$(model.f)" for model in DEMO_MODELS
rng = StableRNG(123)
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
end
end
end
end

@testset verbose = true "AD / GibbsContext" begin
# Gibbs sampling also needs extra AD testing because the models are
# executed with GibbsContext and a subsetted varinfo. (see e.g.
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
# src/mcmc/gibbs.jl -- the code here mimics what happens in those
# functions)
@testset "adtype=$adtype" for adtype in ADTYPES
@testset "model=$(model.f)" for model in DEMO_MODELS
# All the demo models have variables `s` and `m`, so we'll pretend
# that we're using a Gibbs sampler where both of them are sampled
# with a gradient-based sampler (say HMC(0.1, 10)).
# This means we need to construct one model with only `s`, and one with
# only `m`.
global_vi = DynamicPPL.VarInfo(model)
@testset for varnames in ([@varname(s)], [@varname(m)])
@info "Testing Gibbs AD with model=$(model.f), varnames=$varnames"
conditioned_model = Turing.Inference.make_conditional(
model, varnames, deepcopy(global_vi)
)
rng = StableRNG(123)
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10)))
@test run_ad(conditioned_model, adtype; context=ctx, test=true, benchmark=false) isa Any
end
end
end
end

@testset verbose = true "AD / Gibbs sampling" begin
# Make sure that Gibbs sampling doesn't fall over when using AD.
spl = Gibbs(@varname(s) => HMC(0.1, 10), @varname(m) => HMC(0.1, 10))
@testset "adtype=$adtype" for adtype in ADTYPES
@testset "model=$(model.f)" for model in DEMO_MODELS
@test sample(model, spl, 2) isa Any
end
end
end

end # module
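The trick behind ADTypeCheckContext, visible in the hunks above, is that each AD backend evaluates the model with its own tracer element types: ForwardDiff uses Dual numbers and ReverseDiff uses TrackedReals, so an unexpected eltype during model evaluation betrays the wrong backend. A minimal standalone sketch of that check follows; the helpers valid_eltypes and check_adtype are illustrative stand-ins, not the actual internals of test/ad.jl.

using Turing
using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff

# As in the diff: element types that are valid regardless of ADType.
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)

# Assumed mapping from ADType to the tracer types it evaluates with.
valid_eltypes(::Turing.AutoForwardDiff) = (ForwardDiff.Dual,)
valid_eltypes(::Turing.AutoReverseDiff) = (ReverseDiff.TrackedReal,)

# Throw unless the eltype of `x` is either always valid or a tracer type
# belonging to the expected backend.
function check_adtype(expected, x)
    T = eltype(x)
    any(T <: V for V in always_valid_eltypes) && return nothing
    any(T <: V for V in valid_eltypes(expected)) && return nothing
    return error("eltype $T does not match expected AD backend $expected")
end

check_adtype(Turing.AutoForwardDiff(), [1.0, 2.0])  # passes: Float64 is always valid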
16 changes: 6 additions & 10 deletions test/mcmc/Inference.jl
@@ -2,7 +2,6 @@ module InferenceTests

using ..Models: gdemo_d, gdemo_default
using ..NumericalTests: check_gdemo, check_numerical
import ..ADUtils
using Distributions: Bernoulli, Beta, InverseGamma, Normal
using Distributions: sample
import DynamicPPL
@@ -17,9 +16,11 @@ import Mooncake
using Test: @test, @test_throws, @testset
using Turing

@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends
@info "Starting Inference.jl tests with $adbackend"
@testset verbose = true "Testing Inference.jl" begin
@info "Starting Inference.jl tests"

seed = 23
adbackend = Turing.DEFAULT_ADTYPE

@testset "threaded sampling" begin
# Test that chains with the same seed will sample identically.
@@ -44,7 +45,7 @@ using Turing
@test chain1.value == chain2.value
end

# Should also be stable with am explicit RNG
# Should also be stable with an explicit RNG
seed = 5
rng = Random.MersenneTwister(seed)
for sampler in samplers
@@ -273,17 +274,12 @@ using Turing

@testset "forbid global" begin
xs = [1.5 2.0]
# xx = 1

@model function fggibbstest(xs)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
# xx ~ Normal(m, sqrt(s)) # this is illegal

for i in 1:length(xs)
xs[i] ~ Normal(m, sqrt(s))
# for xx in xs
# xx ~ Normal(m, sqrt(s))
end
return s, m
end
@@ -353,7 +349,7 @@ using Turing
)
end

# TODO(mhauru) What is this testing? Why does it not use the looped-over adbackend?
# TODO(mhauru) What is this testing? Why does it use a different adbackend?
@testset "new interface" begin
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]

61 changes: 30 additions & 31 deletions test/mcmc/abstractmcmc.jl
@@ -1,6 +1,5 @@
module AbstractMCMCTests

import ..ADUtils
using AbstractMCMC: AbstractMCMC
using AdvancedMH: AdvancedMH
using Distributions: sample
@@ -110,38 +109,38 @@ function test_initial_params(
end
end

@testset "External samplers" begin
@testset verbose = true "External samplers" begin
@testset "AdvancedHMC.jl" begin
@testset "adtype=$adtype" for adtype in ADUtils.adbackends
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
# Need some functionality to initialize the sampler.
# TODO: Remove this once the constructors in the respective packages become "lazy".
sampler = initialize_nuts(model)
sampler_ext = DynamicPPL.Sampler(
externalsampler(sampler; adtype, unconstrained=true)
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
adtype = Turing.DEFAULT_ADTYPE

# Need some functionality to initialize the sampler.
# TODO: Remove this once the constructors in the respective packages become "lazy".
sampler = initialize_nuts(model)
sampler_ext = DynamicPPL.Sampler(
externalsampler(sampler; adtype, unconstrained=true)
)
# FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment.
# @testset "initial_params" begin
# test_initial_params(model, sampler_ext; n_adapts=0)
# end

sample_kwargs = (
n_adapts=1_000,
discard_initial=1_000,
# FIXME: Remove this once we can run `test_initial_params` above.
initial_params=DynamicPPL.VarInfo(model)[:],
)

@testset "inference" begin
DynamicPPL.TestUtils.test_sampler(
[model],
sampler_ext,
2_000;
rtol=0.2,
sampler_name="AdvancedHMC",
sample_kwargs...,
)
# FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment.
# @testset "initial_params" begin
# test_initial_params(model, sampler_ext; n_adapts=0)
# end

sample_kwargs = (
n_adapts=1_000,
discard_initial=1_000,
# FIXME: Remove this once we can run `test_initial_params` above.
initial_params=DynamicPPL.VarInfo(model)[:],
)

@testset "inference" begin
DynamicPPL.TestUtils.test_sampler(
[model],
sampler_ext,
2_000;
rtol=0.2,
sampler_name="AdvancedHMC",
sample_kwargs...,
)
end
end
end
end
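For reference, externalsampler is the wrapper being exercised here: it adapts a sampler that implements the AbstractMCMC interface so that Turing's sample can drive it against a Turing model. Below is a minimal usage sketch, under the assumption that AdvancedHMC's convenience NUTS(δ) constructor is available; the adtype and unconstrained keywords mirror the call in the diff, and the n_adapts keyword is assumed to be forwarded as in the testset above.

using Turing
using AdvancedHMC: AdvancedHMC

@model function gauss_demo(y)
    x ~ Normal(0, 1)
    return y ~ Normal(x, 1)
end

# Wrap an AdvancedHMC NUTS sampler for use with Turing's `sample`;
# `unconstrained=true` declares that it works in transformed space.
ext = externalsampler(
    AdvancedHMC.NUTS(0.8); adtype=AutoForwardDiff(), unconstrained=true
)
chain = sample(gauss_demo(1.0), ext, 200; n_adapts=100)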
48 changes: 11 additions & 37 deletions test/mcmc/gibbs.jl
@@ -7,7 +7,6 @@ using ..NumericalTests:
check_gdemo,
check_numerical,
two_sample_test
import ..ADUtils
import Combinatorics
using AbstractMCMC: AbstractMCMC
using Distributions: InverseGamma, Normal
@@ -34,19 +33,7 @@ function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames)
end
end

const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe_literal)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_observe_matrix_index)},
}
has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false
has_dot_assume(::DynamicPPL.Model) = true

@testset "GibbsContext" begin
@testset verbose = true "GibbsContext" begin
@testset "type stability" begin
struct Wrapper{T<:Real}
a::T
@@ -384,8 +371,9 @@ end
@test wuc.non_warmup_count == (num_samples - 1) * num_reps
end

@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
@info "Starting Gibbs tests with $adbackend"
@testset verbose = true "Testing gibbs.jl" begin
@info "Starting Gibbs tests"
adbackend = Turing.DEFAULT_ADTYPE

@testset "Gibbs constructors" begin
# Create Gibbs samplers with various configurations and ways of passing the
@@ -597,41 +585,27 @@ end
@model function dynamic_model_with_dot_tilde(
num_zs=10, ::Type{M}=Vector{Float64}
) where {M}
z = M(undef, num_zs)
z = Vector{Int}(undef, num_zs)
z .~ Poisson(1.0)
num_ms = sum(z)
m = M(undef, num_ms)
return m .~ Normal(1.0, 1.0)
end
model = dynamic_model_with_dot_tilde()
# TODO(mhauru) This is broken because of
# https://github.com/TuringLang/DynamicPPL.jl/issues/700.
@test_broken (
sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100);
true
)
sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100)
end

@testset "Demo models" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "Demo model" begin
@testset verbose = true "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
samplers = [
Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => NUTS()),
Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => HMC(0.01, 4)),
Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => ESS()),
Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()),
Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)),
]

if !has_dot_assume(model)
# Add in some MH samplers, which are not compatible with `.~`.
append!(
samplers,
[
Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()),
Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)),
],
)
end

Comment on lines +601 to -634 (Member Author):
I also removed this (outdated) check for dot assume, so now it checks with every sampler combination.
@testset "$sampler" for sampler in samplers
# Check that taking steps performs as expected.
rng = Random.default_rng()
@@ -846,7 +820,7 @@ end
check_MoGtest_default_z_vector(chain; atol=0.2)
end

@testset "externsalsampler" begin
@testset "externalsampler" begin
@model function demo_gibbs_external()
m1 ~ Normal()
m2 ~ Normal()
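The GibbsContext tests in test/ad.jl (second file above) lean on Turing.Inference.make_conditional to freeze every variable outside the block being sampled. A sketch of that pattern in isolation follows; make_conditional is an internal API whose call shape is taken from the diff, and the rest should be read as an assumption rather than a definitive recipe.

using Turing
using DynamicPPL
using DynamicPPL.TestUtils: DEMO_MODELS
using DynamicPPL.TestUtils.AD: run_ad
using StableRNGs: StableRNG

model = first(DEMO_MODELS)
global_vi = DynamicPPL.VarInfo(model)

# Expose only `m` for sampling, freezing all other variables (here `s`) at
# their values in `global_vi`, as one block of a Gibbs sweep would.
conditioned_model = Turing.Inference.make_conditional(
    model, [@varname(m)], deepcopy(global_vi)
)

# Check AD through the conditioned model in a sampling context, mirroring
# the "AD / GibbsContext" testset.
ctx = DynamicPPL.SamplingContext(StableRNG(123), DynamicPPL.Sampler(HMC(0.1, 10)))
run_ad(conditioned_model, Turing.AutoReverseDiff(); context=ctx, test=true, benchmark=false)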