|
1 | | -@timed_testset "ad_backends" begin |
| 1 | +using ForwardDiff: ForwardDiff |
| 2 | +using ReverseDiff: ReverseDiff |
| 3 | +using Mooncake: Mooncake |
| 4 | + |
| 5 | +@timed_testset "ad_backends.jl" begin |
2 | 6 | DATA_DIR = joinpath("..", "data") |
3 | 7 | cheese = CSV.read(joinpath(DATA_DIR, "cheese.csv"), DataFrame) |
4 | 8 | f = @formula(y ~ (1 | cheese) + background) |
5 | 9 | m = turing_model(f, cheese) |
6 | | - # only running 2 samples to test if the different ADs runs |
7 | | - @timed_testset "ForwardDiff" begin |
8 | | - chn = sample(m, NUTS(; adtype=AutoForwardDiff(; chunksize=8)), 2) |
9 | | - @test chn isa Chains |
10 | | - end |
11 | | - # TODO: fix Tracker tests |
12 | | - # @timed_testset "Tracker" begin |
13 | | - # using Tracker |
14 | | - # chn = sample(m, NUTS(; adtype=AutoTracker()), 2) |
15 | | - # @test chn isa Chains |
16 | | - # end |
17 | | - # TODO: fix Zygote tests |
18 | | - # @timed_testset "Zygote" begin |
19 | | - # using Zygote |
20 | | - # chn = sample(m, NUTS(; adtype=AutoZygote()), 2) |
21 | | - # @test chn isa Chains |
22 | | - # end |
23 | | - @timed_testset "ReverseDiff" begin |
24 | | - using ReverseDiff |
25 | | - chn = sample(m, NUTS(; adtype=AutoReverseDiff(; compile=true)), 2) |
26 | | - @test chn isa Chains |
| 10 | + |
| 11 | + ADTYPES = [ |
| 12 | + AutoForwardDiff(), |
| 13 | + AutoReverseDiff(; compile=false), |
| 14 | + AutoReverseDiff(; compile=true), |
| 15 | + AutoMooncake(; config=nothing), |
| 16 | + ] |
| 17 | + @testset "$adtype" for adtype in ADTYPES |
| 18 | + @test sample(m, NUTS(; adtype=adtype), 20) isa Chains |
27 | 19 | end |
28 | 20 | end |
0 commit comments