1
1
using DynamicPPL: DynamicPPL, VarInfo
2
- using DynamicPPL. TestUtils. AD: run_ad, ADResult, ADIncorrectException
2
+ using DynamicPPL. TestUtils. AD: run_ad, ADResult, ADIncorrectException, WithBackend
3
3
using ADTypes
4
4
using Random: Xoshiro
5
5
@@ -12,13 +12,14 @@ import Zygote
12
12
13
13
# AD backends to test.
14
14
ADTYPES = Dict (
15
- " FiniteDifferences" => AutoFiniteDifferences (; fdm = central_fdm (5 , 1 )),
15
+ " FiniteDifferences" => AutoFiniteDifferences (; fdm= central_fdm (5 , 1 )),
16
16
" ForwardDiff" => AutoForwardDiff (),
17
- " ReverseDiff" => AutoReverseDiff (; compile = false ),
18
- " ReverseDiffCompiled" => AutoReverseDiff (; compile = true ),
19
- " Mooncake" => AutoMooncake (; config = nothing ),
20
- " EnzymeForward" => AutoEnzyme (; mode = set_runtime_activity (Forward, true )),
21
- " EnzymeReverse" => AutoEnzyme (; mode = set_runtime_activity (Reverse, true )),
17
+ " ReverseDiff" => AutoReverseDiff (; compile= false ),
18
+ " ReverseDiffCompiled" => AutoReverseDiff (; compile= true ),
19
+ " MooncakeReverse" => AutoMooncake (),
20
+ " MooncakeForward" => AutoMooncakeForward (),
21
+ " EnzymeForward" => AutoEnzyme (; mode= set_runtime_activity (Forward, true )),
22
+ " EnzymeReverse" => AutoEnzyme (; mode= set_runtime_activity (Reverse, true )),
22
23
" Zygote" => AutoZygote (),
23
24
)
24
25
@@ -132,21 +133,18 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
132
133
# https://github.com/TuringLang/ADTests/issues/4
133
134
vi = DynamicPPL. unflatten (VarInfo (model), [0.5 , - 0.5 ])
134
135
params = [- 0.5 , 0.5 ]
135
- result = run_ad (model, adtype; varinfo = vi, params = params, benchmark = true )
136
+ result = run_ad (model, adtype; varinfo= vi, params= params, test = WithBackend (ADTYPES[ " FiniteDifferences " ]), benchmark = true )
136
137
else
137
- vi = VarInfo (Xoshiro (468 ), model)
138
- linked_vi = DynamicPPL. link!! (vi, model)
139
- params = linked_vi[:]
140
138
result = run_ad (
141
139
model,
142
140
adtype;
143
- params = params ,
144
- reference_adtype = ADTYPES[" FiniteDifferences" ],
145
- benchmark = true ,
141
+ rng = Xoshiro ( 468 ) ,
142
+ test = WithBackend ( ADTYPES[" FiniteDifferences" ]) ,
143
+ benchmark= true ,
146
144
)
147
145
end
148
146
# If reached here - nothing went wrong
149
- println (result. time_vs_primal )
147
+ println (result. grad_time / result . primal_time )
150
148
catch e
151
149
@show e
152
150
if e isa ADIncorrectException
0 commit comments