Skip to content

Commit 46bbf06

Browse files
committed
Locate tests better
1 parent 3877cf0 commit 46bbf06

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

test/ad.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "AD: ForwardDiff and ReverseDiff" begin
1+
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
22
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
33
f = DynamicPPL.LogDensityFunction(m)
44
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
@@ -17,11 +17,21 @@
1717
θ = convert(Vector{Float64}, varinfo[:])
1818
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
1919

20-
@testset "ReverseDiff with compile=$compile" for compile in (false, true)
21-
adtype = ADTypes.AutoReverseDiff(; compile=compile)
22-
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
23-
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
24-
@test grad ref_grad
20+
@testset "$adtype" for adtype in [
21+
ADTypes.AutoReverseDiff(; compile=false),
22+
ADTypes.AutoReverseDiff(; compile=true),
23+
ADTypes.AutoMooncake(; config=nothing),
24+
]
25+
# Mooncake can't currently handle something that is going on in
26+
# SimpleVarInfo{<:VarNamedVector}. Disable tests for now.
27+
if adtype isa ADTypes.AutoMooncake &&
28+
varinfo isa DynamicPPL.SimpleVarInfo{<:DynamicPPL.VarNamedVector}
29+
@test_broken 1 == 0
30+
else
31+
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
32+
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
33+
@test grad ref_grad
34+
end
2535
end
2636
end
2737
end

test/logdensityfunction.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,6 @@ end
3131
θ = varinfo[:]
3232
@test LogDensityProblems.logdensity(logdensity, θ) logjoint(model, varinfo)
3333
@test LogDensityProblems.dimension(logdensity) == length(θ)
34-
35-
# Test a single backend on the generic
36-
# ADgradient(::AbstractADType, ::LogDensityFunction) method. This really just
37-
# checks that it runs at all.
38-
if varinfo isa DynamicPPL.TypedVarInfo
39-
ad = ADTypes.AutoMooncake(; config=nothing)
40-
∇ℓ = LogDensityProblemsAD.ADgradient(ad, logdensity)
41-
@test isa(
42-
LogDensityProblems.logdensity_and_gradient(∇ℓ, θ),
43-
Tuple{Float64, Vector{Float64}},
44-
)
45-
end
4634
end
4735
end
4836
end

0 commit comments

Comments
 (0)