|
1 |
| -@testset "AD: ForwardDiff and ReverseDiff" begin |
| 1 | +@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin |
2 | 2 | @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
|
3 | 3 | f = DynamicPPL.LogDensityFunction(m)
|
4 | 4 | rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
|
|
17 | 17 | θ = convert(Vector{Float64}, varinfo[:])
|
18 | 18 | logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
|
19 | 19 |
|
20 |
| - @testset "ReverseDiff with compile=$compile" for compile in (false, true) |
21 |
| - adtype = ADTypes.AutoReverseDiff(; compile=compile) |
22 |
| - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) |
23 |
| - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) |
24 |
| - @test grad ≈ ref_grad |
| 20 | + @testset "$adtype" for adtype in [ |
| 21 | + ADTypes.AutoReverseDiff(; compile=false), |
| 22 | + ADTypes.AutoReverseDiff(; compile=true), |
| 23 | + ADTypes.AutoMooncake(; config=nothing), |
| 24 | + ] |
| 25 | + # Mooncake can't currently handle something that is going on in |
| 26 | + # SimpleVarInfo{<:VarNamedVector}. Disable tests for now. |
| 27 | + if adtype isa ADTypes.AutoMooncake && |
| 28 | + varinfo isa DynamicPPL.SimpleVarInfo{<:DynamicPPL.VarNamedVector} |
| 29 | + @test_broken 1 == 0 |
| 30 | + else |
| 31 | + ad_f = LogDensityProblemsAD.ADgradient(adtype, f) |
| 32 | + _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) |
| 33 | + @test grad ≈ ref_grad |
| 34 | + end |
25 | 35 | end
|
26 | 36 | end
|
27 | 37 | end
|
|
0 commit comments