@@ -13,13 +13,13 @@ import Zygote
13
13
14
14
# AD backends to test.
15
15
ADTYPES = Dict (
16
- " FiniteDifferences" => AutoFiniteDifferences (; fdm = central_fdm (5 , 1 )),
16
+ " FiniteDifferences" => AutoFiniteDifferences (; fdm= central_fdm (5 , 1 )),
17
17
" ForwardDiff" => AutoForwardDiff (),
18
- " ReverseDiff" => AutoReverseDiff (; compile = false ),
19
- " ReverseDiffCompiled" => AutoReverseDiff (; compile = true ),
20
- " Mooncake" => AutoMooncake (; config = nothing ),
21
- " EnzymeForward" => AutoEnzyme (; mode = set_runtime_activity (Forward, true )),
22
- " EnzymeReverse" => AutoEnzyme (; mode = set_runtime_activity (Reverse, true )),
18
+ " ReverseDiff" => AutoReverseDiff (; compile= false ),
19
+ " ReverseDiffCompiled" => AutoReverseDiff (; compile= true ),
20
+ " Mooncake" => AutoMooncake (; config= nothing ),
21
+ " EnzymeForward" => AutoEnzyme (; mode= set_runtime_activity (Forward, true )),
22
+ " EnzymeReverse" => AutoEnzyme (; mode= set_runtime_activity (Reverse, true )),
23
23
" Zygote" => AutoZygote (),
24
24
)
25
25
35
35
# These imports tend to get used a lot in models
36
36
using DynamicPPL: @model , to_submodel
37
37
using Distributions
38
- using LinearAlgebra
38
+ using LinearAlgebra: I
39
39
40
40
include (" models/assume_dirichlet.jl" )
41
41
include (" models/assume_lkjcholu.jl" )
@@ -44,21 +44,8 @@ include("models/assume_normal.jl")
44
44
include (" models/assume_submodel.jl" )
45
45
include (" models/assume_wishart.jl" )
46
46
include (" models/control_flow.jl" )
47
- include (" models/demo_assume_dot_observe_literal.jl" )
48
- include (" models/demo_assume_dot_observe.jl" )
49
- include (" models/demo_assume_index_observe.jl" )
50
- include (" models/demo_assume_matrix_observe_matrix_index.jl" )
51
- include (" models/demo_assume_multivariate_observe_literal.jl" )
52
- include (" models/demo_assume_multivariate_observe.jl" )
53
- include (" models/demo_assume_observe_literal.jl" )
54
- include (" models/demo_assume_submodel_observe_index_literal.jl" )
55
- include (" models/demo_dot_assume_observe_index_literal.jl" )
56
- include (" models/demo_dot_assume_observe_index.jl" )
57
- include (" models/demo_dot_assume_observe_matrix_index.jl" )
58
- include (" models/demo_dot_assume_observe_submodel.jl" )
59
- include (" models/demo_dot_assume_observe.jl" )
47
+ include (" models/dot_assume_observe_index.jl" )
60
48
include (" models/dot_assume.jl" )
61
- include (" models/demo_dot_assume_observe.jl" )
62
49
include (" models/dot_observe.jl" )
63
50
include (" models/dynamic_constraint.jl" )
64
51
include (" models/multiple_constraints_same_var.jl" )
@@ -87,9 +74,9 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
87
74
# https://github.com/TuringLang/ADTests/issues/4
88
75
vi = DynamicPPL. unflatten (VarInfo (model), [0.5 , - 0.5 ])
89
76
params = [- 0.5 , 0.5 ]
90
- result = run_ad (model, adtype; varinfo = vi, params = params, benchmark = true )
77
+ result = run_ad (model, adtype; varinfo= vi, params= params, benchmark= true )
91
78
else
92
- result = run_ad (model, adtype; benchmark = true )
79
+ result = run_ad (model, adtype; benchmark= true )
93
80
end
94
81
# If reached here - nothing went wrong
95
82
@printf (" %.3f" , result. time_vs_primal)
0 commit comments