|
1 | | -@testset "AD: ForwardDiff and ReverseDiff" begin |
| 1 | +@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin |
2 | 2 | @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS |
3 | 3 | f = DynamicPPL.LogDensityFunction(m) |
4 | 4 | rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) |
|
17 | 17 | θ = convert(Vector{Float64}, varinfo[:]) |
18 | 18 | logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) |
19 | 19 |
|
20 | | - @testset "ReverseDiff with compile=$compile" for compile in (false, true) |
21 | | - adtype = ADTypes.AutoReverseDiff(; compile=compile) |
22 | | - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) |
23 | | - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) |
24 | | - @test grad ≈ ref_grad |
| 20 | + @testset "$adtype" for adtype in [ |
| 21 | + ADTypes.AutoReverseDiff(; compile=false), |
| 22 | + ADTypes.AutoReverseDiff(; compile=true), |
| 23 | + ADTypes.AutoMooncake(; config=nothing), |
| 24 | + ] |
| 25 | + # Mooncake can't currently handle something that is going on in |
| 26 | + # SimpleVarInfo{<:VarNamedVector}. Disable tests for now. |
| 27 | + if adtype isa ADTypes.AutoMooncake && |
| 28 | + varinfo isa DynamicPPL.SimpleVarInfo{<:DynamicPPL.VarNamedVector} |
| 29 | + @test_broken 1 == 0 |
| 30 | + else |
| 31 | + ad_f = LogDensityProblemsAD.ADgradient(adtype, f) |
| 32 | + _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) |
| 33 | + @test grad ≈ ref_grad |
| 34 | + end |
25 | 35 | end |
26 | 36 | end |
27 | 37 | end |
|
0 commit comments