Skip to content

Commit 887b82a

Browse files
committed
Use FiniteDifferences as comparison; fix parameter generation
1 parent c049fd8 commit 887b82a

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

main.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using DynamicPPL: DynamicPPL, VarInfo
22
using DynamicPPL.TestUtils.AD: run_ad, ADResult, ADIncorrectException
33
using ADTypes
4+
using Random: Xoshiro
45

56
import FiniteDifferences: central_fdm
67
import ForwardDiff
@@ -11,13 +12,13 @@ import Zygote
1112

1213
# AD backends to test.
1314
ADTYPES = Dict(
14-
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
15+
"FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
1516
"ForwardDiff" => AutoForwardDiff(),
16-
"ReverseDiff" => AutoReverseDiff(; compile=false),
17-
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
18-
"Mooncake" => AutoMooncake(; config=nothing),
19-
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
20-
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
17+
"ReverseDiff" => AutoReverseDiff(; compile = false),
18+
"ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
19+
"Mooncake" => AutoMooncake(; config = nothing),
20+
"EnzymeForward" => AutoEnzyme(; mode = set_runtime_activity(Forward, true)),
21+
"EnzymeReverse" => AutoEnzyme(; mode = set_runtime_activity(Reverse, true)),
2122
"Zygote" => AutoZygote(),
2223
)
2324

@@ -56,14 +57,12 @@ macro include_model(category::AbstractString, model_name::AbstractString)
5657
if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",")
5758
# Declare a module containing the model. In principle esc() shouldn't
5859
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
59-
Expr(:toplevel, esc(:(
60-
module $(gensym())
61-
using .Main: @register
62-
using Turing
63-
include("models/" * $(model_name) * ".jl")
64-
@register $(category) model
65-
end
66-
)))
60+
Expr(:toplevel, esc(:(module $(gensym())
61+
using .Main: @register
62+
using Turing
63+
include("models/" * $(model_name) * ".jl")
64+
@register $(category) model
65+
end)))
6766
else
6867
# Empty expression
6968
:()
@@ -133,9 +132,18 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
133132
# https://github.com/TuringLang/ADTests/issues/4
134133
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
135134
params = [-0.5, 0.5]
136-
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
135+
result = run_ad(model, adtype; varinfo = vi, params = params, benchmark = true)
137136
else
138-
result = run_ad(model, adtype; benchmark=true)
137+
vi = VarInfo(Xoshiro(468), model)
138+
linked_vi = DynamicPPL.link!!(vi, model)
139+
params = linked_vi[:]
140+
result = run_ad(
141+
model,
142+
adtype;
143+
params = params,
144+
reference_adtype = ADTYPES["FiniteDifferences"],
145+
benchmark = true,
146+
)
139147
end
140148
# If reached here - nothing went wrong
141149
println(result.time_vs_primal)

0 commit comments

Comments
 (0)