|
1 |
| -@timed_testset "ad_backends" begin |
| 1 | +using ForwardDiff: ForwardDiff |
| 2 | +using ReverseDiff: ReverseDiff |
| 3 | +using Mooncake: Mooncake |
| 4 | + |
| 5 | +@timed_testset "ad_backends.jl" begin |
2 | 6 | DATA_DIR = joinpath("..", "data")
|
3 | 7 | cheese = CSV.read(joinpath(DATA_DIR, "cheese.csv"), DataFrame)
|
4 | 8 | f = @formula(y ~ (1 | cheese) + background)
|
5 | 9 | m = turing_model(f, cheese)
|
6 |
| - # only running 2 samples to test if the different ADs runs |
7 |
| - @timed_testset "ForwardDiff" begin |
8 |
| - chn = sample(m, NUTS(; adtype=AutoForwardDiff(; chunksize=8)), 2) |
9 |
| - @test chn isa Chains |
10 |
| - end |
11 |
| - # TODO: fix Tracker tests |
12 |
| - # @timed_testset "Tracker" begin |
13 |
| - # using Tracker |
14 |
| - # chn = sample(m, NUTS(; adtype=AutoTracker()), 2) |
15 |
| - # @test chn isa Chains |
16 |
| - # end |
17 |
| - # TODO: fix Zygote tests |
18 |
| - # @timed_testset "Zygote" begin |
19 |
| - # using Zygote |
20 |
| - # chn = sample(m, NUTS(; adtype=AutoZygote()), 2) |
21 |
| - # @test chn isa Chains |
22 |
| - # end |
23 |
| - @timed_testset "ReverseDiff" begin |
24 |
| - using ReverseDiff |
25 |
| - chn = sample(m, NUTS(; adtype=AutoReverseDiff(; compile=true)), 2) |
26 |
| - @test chn isa Chains |
| 10 | + |
| 11 | + ADTYPES = [ |
| 12 | + AutoForwardDiff(), |
| 13 | + AutoReverseDiff(; compile=false), |
| 14 | + AutoReverseDiff(; compile=true), |
| 15 | + AutoMooncake(; config=nothing), |
| 16 | + ] |
| 17 | + @testset "$adtype" for adtype in ADTYPES |
| 18 | + @test sample(m, NUTS(; adtype=adtype), 20) isa Chains |
27 | 19 | end
|
28 | 20 | end
|
0 commit comments