@@ -13,13 +13,13 @@ import Zygote
13
13
14
14
# AD backends to test.
15
15
ADTYPES = Dict (
16
- " FiniteDifferences" => AutoFiniteDifferences (; fdm= central_fdm (5 , 1 )),
16
+ " FiniteDifferences" => AutoFiniteDifferences (; fdm = central_fdm (5 , 1 )),
17
17
" ForwardDiff" => AutoForwardDiff (),
18
- " ReverseDiff" => AutoReverseDiff (; compile= false ),
19
- " ReverseDiffCompiled" => AutoReverseDiff (; compile= true ),
20
- " Mooncake" => AutoMooncake (; config= nothing ),
21
- " EnzymeForward" => AutoEnzyme (; mode= set_runtime_activity (Forward, true )),
22
- " EnzymeReverse" => AutoEnzyme (; mode= set_runtime_activity (Reverse, true )),
18
+ " ReverseDiff" => AutoReverseDiff (; compile = false ),
19
+ " ReverseDiffCompiled" => AutoReverseDiff (; compile = true ),
20
+ " Mooncake" => AutoMooncake (; config = nothing ),
21
+ " EnzymeForward" => AutoEnzyme (; mode = set_runtime_activity (Forward, true )),
22
+ " EnzymeReverse" => AutoEnzyme (; mode = set_runtime_activity (Reverse, true )),
23
23
" Zygote" => AutoZygote (),
24
24
)
25
25
35
35
# These imports tend to get used a lot in models
36
36
using DynamicPPL: @model , to_submodel
37
37
using Distributions
38
- using LinearAlgebra: I
38
+ using LinearAlgebra
39
39
40
40
include (" models/assume_dirichlet.jl" )
41
41
include (" models/assume_lkjcholu.jl" )
@@ -44,8 +44,21 @@ include("models/assume_normal.jl")
44
44
include (" models/assume_submodel.jl" )
45
45
include (" models/assume_wishart.jl" )
46
46
include (" models/control_flow.jl" )
47
- include (" models/dot_assume_observe_index.jl" )
47
+ include (" models/demo_assume_dot_observe_literal.jl" )
48
+ include (" models/demo_assume_dot_observe.jl" )
49
+ include (" models/demo_assume_index_observe.jl" )
50
+ include (" models/demo_assume_matrix_observe_matrix_index.jl" )
51
+ include (" models/demo_assume_multivariate_observe_literal.jl" )
52
+ include (" models/demo_assume_multivariate_observe.jl" )
53
+ include (" models/demo_assume_observe_literal.jl" )
54
+ include (" models/demo_assume_submodel_observe_index_literal.jl" )
55
+ include (" models/demo_dot_assume_observe_index_literal.jl" )
56
+ include (" models/demo_dot_assume_observe_index.jl" )
57
+ include (" models/demo_dot_assume_observe_matrix_index.jl" )
58
+ include (" models/demo_dot_assume_observe_submodel.jl" )
59
+ include (" models/demo_dot_assume_observe.jl" )
48
60
include (" models/dot_assume.jl" )
61
+ include (" models/demo_dot_assume_observe.jl" )
49
62
include (" models/dot_observe.jl" )
50
63
include (" models/dynamic_constraint.jl" )
51
64
include (" models/multiple_constraints_same_var.jl" )
@@ -74,9 +87,9 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
74
87
# https://github.com/TuringLang/ADTests/issues/4
75
88
vi = DynamicPPL. unflatten (VarInfo (model), [0.5 , - 0.5 ])
76
89
params = [- 0.5 , 0.5 ]
77
- result = run_ad (model, adtype; varinfo= vi, params= params, benchmark= true )
90
+ result = run_ad (model, adtype; varinfo = vi, params = params, benchmark = true )
78
91
else
79
- result = run_ad (model, adtype; benchmark= true )
92
+ result = run_ad (model, adtype; benchmark = true )
80
93
end
81
94
# If reached here - nothing went wrong
82
95
@printf (" %.3f" , result. time_vs_primal)
0 commit comments