1
- import Test: @test , @testset
2
1
using DynamicPPL: DynamicPPL, VarInfo
3
2
using DynamicPPL. TestUtils. AD: run_ad, ADResult, ADIncorrectException
4
3
using ADTypes
5
- using Printf: @printf
6
4
7
5
import FiniteDifferences: central_fdm
8
6
import ForwardDiff
@@ -13,30 +11,26 @@ import Zygote
13
11
14
12
# AD backends to test.
15
13
ADTYPES = Dict (
16
- " FiniteDifferences" => AutoFiniteDifferences (; fdm = central_fdm (5 , 1 )),
14
+ " FiniteDifferences" => AutoFiniteDifferences (; fdm= central_fdm (5 , 1 )),
17
15
" ForwardDiff" => AutoForwardDiff (),
18
- " ReverseDiff" => AutoReverseDiff (; compile = false ),
19
- " ReverseDiffCompiled" => AutoReverseDiff (; compile = true ),
20
- " Mooncake" => AutoMooncake (),
21
- " EnzymeForward" => AutoEnzyme (; mode = set_runtime_activity (Forward, true )),
22
- " EnzymeReverse" => AutoEnzyme (; mode = set_runtime_activity (Reverse, true )),
16
+ " ReverseDiff" => AutoReverseDiff (; compile= false ),
17
+ " ReverseDiffCompiled" => AutoReverseDiff (; compile= true ),
18
+ " Mooncake" => AutoMooncake (; config = nothing ),
19
+ " EnzymeForward" => AutoEnzyme (; mode= set_runtime_activity (Forward, true )),
20
+ " EnzymeReverse" => AutoEnzyme (; mode= set_runtime_activity (Reverse, true )),
23
21
" Zygote" => AutoZygote (),
24
22
)
25
23
26
- # Models to test. The convention is that:
27
- # a, b, c, ... are assumed variables
28
- # x, y, z, ... are observed variables
29
- # although it's hardly a big deal.
30
- MODELS = Dict {String,DynamicPPL.Model} ()
31
- macro register (model)
32
- :(MODELS[string ($ (esc (model)). f)] = $ (esc (model)))
24
+ MODELS = Dict {String,Tuple{String,DynamicPPL.Model}} ()
25
+ macro register (category, model)
26
+ :(MODELS[string ($ (esc (model)). f)] = ($ (esc (category)), $ (esc (model))))
33
27
end
34
28
35
29
"""
36
- include_model(model_name::AbstractString)
30
+ include_model(category::AbstractString, model_name::AbstractString)
37
31
38
32
Add the model defined in `models/model_name.jl` to the full list of models
39
- tested in this script.
33
+ tested in this script. The model is registered under the given `category`.
40
34
41
35
We want to isolate every model in its own module, so that we avoid e.g.
42
36
variable clashes, and also so that we make imports explicit.
@@ -57,17 +51,17 @@ definition file:
57
51
- Once defined, the model is created using `model = model_name(...)`. The
58
52
`model` on the left-hand side is mandatory.
59
53
"""
60
- macro include_model (model_name:: AbstractString )
54
+ macro include_model (category :: AbstractString , model_name:: AbstractString )
61
55
MODELS_TO_LOAD = get (ENV , " ADTESTS_MODELS_TO_LOAD" , " __all__" )
62
56
if MODELS_TO_LOAD == " __all__" || model_name in split (MODELS_TO_LOAD, " ," )
63
57
# Declare a module containing the model. In principle esc() shouldn't
64
58
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
65
59
Expr (:toplevel , esc (:(
66
60
module $ (gensym ())
67
- using . Main: @register
68
- using Turing
69
- include (" models/" * $ (model_name) * " .jl" )
70
- @register model
61
+ using . Main: @register
62
+ using Turing
63
+ include (" models/" * $ (model_name) * " .jl" )
64
+ @register $ (category) model
71
65
end
72
66
)))
73
67
else
@@ -76,71 +70,78 @@ macro include_model(model_name::AbstractString)
76
70
end
77
71
end
78
72
79
- @include_model " assume_beta"
80
- @include_model " assume_dirichlet"
81
- @include_model " assume_lkjcholu"
82
- @include_model " assume_mvnormal"
83
- @include_model " assume_normal"
84
- @include_model " assume_submodel"
85
- @include_model " assume_wishart"
86
- @include_model " broadcast_macro"
87
- @include_model " control_flow"
88
- @include_model " demo_assume_dot_observe"
89
- @include_model " demo_assume_dot_observe_literal"
90
- @include_model " demo_assume_index_observe"
91
- @include_model " demo_assume_matrix_observe_matrix_index"
92
- @include_model " demo_assume_multivariate_observe"
93
- @include_model " demo_assume_multivariate_observe_literal"
94
- @include_model " demo_assume_observe_literal"
95
- @include_model " demo_assume_submodel_observe_index_literal"
96
- @include_model " demo_dot_assume_observe"
97
- @include_model " demo_dot_assume_observe_index"
98
- @include_model " demo_dot_assume_observe_index_literal"
99
- @include_model " demo_dot_assume_observe_matrix_index"
100
- @include_model " demo_dot_assume_observe_submodel"
101
- @include_model " dot_assume"
102
- @include_model " dot_observe"
103
- @include_model " dppl_gauss_unknown"
104
- @include_model " dppl_hier_poisson"
105
- @include_model " dppl_high_dim_gauss"
106
- @include_model " dppl_hmm_semisup"
107
- @include_model " dppl_lda"
108
- @include_model " dppl_logistic_regression"
109
- @include_model " dppl_naive_bayes"
110
- @include_model " dppl_sto_volatility"
111
- @include_model " dynamic_constraint"
112
- @include_model " multiple_constraints_same_var"
113
- @include_model " multithreaded"
114
- @include_model " n010"
115
- @include_model " n050"
116
- @include_model " n100"
117
- @include_model " n500"
118
- @include_model " observe_bernoulli"
119
- @include_model " observe_categorical"
120
- @include_model " observe_index"
121
- @include_model " observe_literal"
122
- @include_model " observe_multivariate"
123
- @include_model " observe_submodel"
124
- @include_model " pdb_eight_schools_centered"
125
- @include_model " pdb_eight_schools_noncentered"
126
- @include_model " von_mises"
73
+ # Models to test. The convention is that:
74
+ # a, b, c, ... are assumed variables
75
+ # x, y, z, ... are observed variables
76
+ # although it's hardly a big deal.
77
+ @include_model " Base Julia features" " control_flow"
78
+ @include_model " Base Julia features" " multithreaded"
79
+ @include_model " Core Turing syntax" " broadcast_macro"
80
+ @include_model " Core Turing syntax" " dot_assume"
81
+ @include_model " Core Turing syntax" " dot_observe"
82
+ @include_model " Core Turing syntax" " dynamic_constraint"
83
+ @include_model " Core Turing syntax" " multiple_constraints_same_var"
84
+ @include_model " Core Turing syntax" " observe_index"
85
+ @include_model " Core Turing syntax" " observe_literal"
86
+ @include_model " Core Turing syntax" " observe_multivariate"
87
+ @include_model " Core Turing syntax" " observe_submodel"
88
+ @include_model " Distributions" " assume_beta"
89
+ @include_model " Distributions" " assume_dirichlet"
90
+ @include_model " Distributions" " assume_lkjcholu"
91
+ @include_model " Distributions" " assume_mvnormal"
92
+ @include_model " Distributions" " assume_normal"
93
+ @include_model " Distributions" " assume_submodel"
94
+ @include_model " Distributions" " assume_wishart"
95
+ @include_model " Distributions" " observe_bernoulli"
96
+ @include_model " Distributions" " observe_categorical"
97
+ @include_model " Distributions" " observe_von_mises"
98
+ @include_model " DynamicPPL arXiV paper" " dppl_gauss_unknown"
99
+ @include_model " DynamicPPL arXiV paper" " dppl_hier_poisson"
100
+ @include_model " DynamicPPL arXiV paper" " dppl_high_dim_gauss"
101
+ @include_model " DynamicPPL arXiV paper" " dppl_hmm_semisup"
102
+ @include_model " DynamicPPL arXiV paper" " dppl_lda"
103
+ @include_model " DynamicPPL arXiV paper" " dppl_logistic_regression"
104
+ @include_model " DynamicPPL arXiV paper" " dppl_naive_bayes"
105
+ @include_model " DynamicPPL arXiV paper" " dppl_sto_volatility"
106
+ @include_model " DynamicPPL demo models" " demo_assume_dot_observe"
107
+ @include_model " DynamicPPL demo models" " demo_assume_dot_observe_literal"
108
+ @include_model " DynamicPPL demo models" " demo_assume_index_observe"
109
+ @include_model " DynamicPPL demo models" " demo_assume_matrix_observe_matrix_index"
110
+ @include_model " DynamicPPL demo models" " demo_assume_multivariate_observe"
111
+ @include_model " DynamicPPL demo models" " demo_assume_multivariate_observe_literal"
112
+ @include_model " DynamicPPL demo models" " demo_assume_observe_literal"
113
+ @include_model " DynamicPPL demo models" " demo_assume_submodel_observe_index_literal"
114
+ @include_model " DynamicPPL demo models" " demo_dot_assume_observe"
115
+ @include_model " DynamicPPL demo models" " demo_dot_assume_observe_index"
116
+ @include_model " DynamicPPL demo models" " demo_dot_assume_observe_index_literal"
117
+ @include_model " DynamicPPL demo models" " demo_dot_assume_observe_matrix_index"
118
+ @include_model " DynamicPPL demo models" " demo_dot_assume_observe_submodel"
119
+ @include_model " Effect of model size" " n010"
120
+ @include_model " Effect of model size" " n050"
121
+ @include_model " Effect of model size" " n100"
122
+ @include_model " Effect of model size" " n500"
123
+ @include_model " PosteriorDB" " pdb_eight_schools_centered"
124
+ @include_model " PosteriorDB" " pdb_eight_schools_noncentered"
127
125
128
126
# The entry point to this script itself begins here
129
127
if ARGS == [" --list-model-keys" ]
130
128
foreach (println, sort (collect (keys (MODELS))))
131
129
elseif ARGS == [" --list-adtype-keys" ]
132
130
foreach (println, sort (collect (keys (ADTYPES))))
131
+ elseif length (ARGS ) == 2 && ARGS [1 ] == " --get-category"
132
+ println (MODELS[ARGS [2 ]][1 ])
133
133
elseif length (ARGS ) == 3 && ARGS [1 ] == " --run"
134
- model, adtype = MODELS[ARGS [2 ]], ADTYPES[ARGS [3 ]]
134
+ model_name, adtype_name = ARGS [2 ], ARGS [3 ]
135
+ model, adtype = MODELS[model_name][2 ], ADTYPES[adtype_name]
135
136
136
137
try
137
- if ARGS [ 2 ] == " control_flow"
138
+ if model_name == " control_flow"
138
139
# https://github.com/TuringLang/ADTests/issues/4
139
140
vi = DynamicPPL. unflatten (VarInfo (model), [0.5 , - 0.5 ])
140
141
params = [- 0.5 , 0.5 ]
141
- result = run_ad (model, adtype; varinfo = vi, params = params, benchmark = true )
142
+ result = run_ad (model, adtype; varinfo= vi, params= params, benchmark= true )
142
143
else
143
- result = run_ad (model, adtype; benchmark = true )
144
+ result = run_ad (model, adtype; benchmark= true )
144
145
end
145
146
# If reached here - nothing went wrong
146
147
println (result. time_vs_primal)
170
171
println (" Usage: julia main.jl --list-model-keys" )
171
172
println (" julia main.jl --list-adtype-keys" )
172
173
println (" julia main.jl --run <model> <adtype>" )
174
+ println (" julia main.jl --get-category <model>" )
173
175
end
0 commit comments