|
5 | 5 | m = turing_model(f, cheese)
|
6 | 6 | # only running 2 samples to test if the different ADs runs
|
7 | 7 | @timed_testset "ForwardDiff" begin
|
8 |
| - Turing.setadbackend(:forwarddiff) |
9 |
| - chn = sample(m, NUTS(), 2) |
| 8 | + chn = sample(m, NUTS(; adtype=AutoForwardDiff(; chunksize=8)), 2) |
10 | 9 | @test chn isa Chains
|
11 | 10 | end
|
12 | 11 | # TODO: fix Tracker tests
|
13 | 12 | # @timed_testset "Tracker" begin
|
14 | 13 | # using Tracker
|
15 |
| - # Turing.setadbackend(:tracker) |
16 |
| - # chn = sample(m, NUTS(), 2) |
| 14 | + # chn = sample(m, NUTS(; adtype=AutoTracker()), 2) |
17 | 15 | # @test chn isa Chains
|
18 | 16 | # end
|
19 | 17 | # TODO: fix Zygote tests
|
20 | 18 | # @timed_testset "Zygote" begin
|
21 | 19 | # using Zygote
|
22 |
| - # Turing.setadbackend(:zygote) |
23 |
| - # chn = sample(m, NUTS(), 2) |
| 20 | + # chn = sample(m, NUTS(; adtype=AutoZygote()), 2) |
24 | 21 | # @test chn isa Chains
|
25 | 22 | # end
|
26 | 23 | @timed_testset "ReverseDiff" begin
|
27 | 24 | using ReverseDiff
|
28 |
| - Turing.setadbackend(:reversediff) |
29 |
| - chn = sample(m, NUTS(), 2) |
| 25 | + chn = sample(m, NUTS(; adtype=AutoReverseDiff(; compile=true)), 2) |
30 | 26 | @test chn isa Chains
|
31 | 27 | end
|
32 |
| - # go back to defaults |
33 |
| - Turing.setadbackend(:forwarddiff) |
34 | 28 | end
|
0 commit comments