Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ ADTYPES = [
adtypes = (
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(),
# TODO: Mooncake
# Turing.AutoMooncake(config=nothing),
# Don't need to test Mooncake as it doesn't use tracer types
)
for actual_adtype in adtypes
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
Expand Down Expand Up @@ -280,4 +279,14 @@ end
end
end

@testset verbose = true "AD / Gibbs sampling" begin
# Make sure that Gibbs sampling doesn't fall over when using AD.
spl = Gibbs(@varname(s) => HMC(0.1, 10), @varname(m) => HMC(0.1, 10))
@testset "adtype=$adtype" for adtype in ADTYPES
@testset "model=$(model.f)" for model in DEMO_MODELS
@test sample(model, spl, 2) isa Any
end
end
end

end # module
25 changes: 2 additions & 23 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,6 @@ function check_transition_varnames(transition::Turing.Inference.Transition, pare
end
end

const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe_literal)},
DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_observe_matrix_index)},
}
has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false
has_dot_assume(::DynamicPPL.Model) = true

@testset verbose = true "GibbsContext" begin
@testset "type stability" begin
struct Wrapper{T<:Real}
Expand Down Expand Up @@ -614,19 +602,10 @@ end
Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => NUTS()),
Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => HMC(0.01, 4)),
Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => ESS()),
Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()),
Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)),
]

if !has_dot_assume(model)
# Add in some MH samplers, which are not compatible with `.~`.
append!(
samplers,
[
Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()),
Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)),
],
)
end

Comment on lines +601 to -634
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also removed this (outdated) check for dot assume so now it checks with every sampler combination

@testset "$sampler" for sampler in samplers
# Check that taking steps performs as expected.
rng = Random.default_rng()
Expand Down
Loading