Skip to content

Commit 96ba389

Browse files
committed
Add categories (?)
1 parent 6e88b49 commit 96ba389

File tree

4 files changed

+85
-72
lines changed

4 files changed

+85
-72
lines changed

ad.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ def run_ad(args):
8383
else:
8484
RUN_JULIA_COMMAND = JULIA_COMMAND
8585

86+
# Get category
87+
try:
88+
category = run_and_capture([*RUN_JULIA_COMMAND, "--get-category", model_key])
89+
except sp.CalledProcessError as e:
90+
print(f"Julia crashed when getting category for {model_key}.")
91+
print(f"To reproduce, run: `julia --project=. main.jl --get-category {model_key}`")
92+
category = "error"
93+
results["__category__"] = category
94+
8695
# Run tests
8796
for adtype in adtypes:
8897
print(f"Running {model_key} with {adtype}...")
@@ -127,12 +136,14 @@ def html(_args):
127136
# [
128137
# {"model_name": "model1",
129138
# "results": {
139+
# "__category_": "category1",
130140
# "AD1": "result1",
131141
# "AD2": "result2"
132142
# }
133143
# },
134144
# {"model_name": "model2",
135145
# "results": {
146+
# "__category_": "category2",
136147
# "AD1": "result3",
137148
# "AD2": "result4"
138149
# }

main.jl

Lines changed: 68 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import Test: @test, @testset
21
using DynamicPPL: DynamicPPL, VarInfo
32
using DynamicPPL.TestUtils.AD: run_ad, ADResult, ADIncorrectException
43
using ADTypes
5-
using Printf: @printf
64

75
import FiniteDifferences: central_fdm
86
import ForwardDiff
@@ -13,30 +11,26 @@ import Zygote
1311

1412
# AD backends to test.
1513
ADTYPES = Dict(
16-
"FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
14+
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
1715
"ForwardDiff" => AutoForwardDiff(),
18-
"ReverseDiff" => AutoReverseDiff(; compile = false),
19-
"ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
20-
"Mooncake" => AutoMooncake(; config = nothing),
21-
"EnzymeForward" => AutoEnzyme(; mode = set_runtime_activity(Forward, true)),
22-
"EnzymeReverse" => AutoEnzyme(; mode = set_runtime_activity(Reverse, true)),
16+
"ReverseDiff" => AutoReverseDiff(; compile=false),
17+
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
18+
"Mooncake" => AutoMooncake(; config=nothing),
19+
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
20+
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
2321
"Zygote" => AutoZygote(),
2422
)
2523

26-
# Models to test. The convention is that:
27-
# a, b, c, ... are assumed variables
28-
# x, y, z, ... are observed variables
29-
# although it's hardly a big deal.
30-
MODELS = Dict{String,DynamicPPL.Model}()
31-
macro register(model)
32-
:(MODELS[string($(esc(model)).f)] = $(esc(model)))
24+
MODELS = Dict{String,Tuple{String,DynamicPPL.Model}}()
25+
macro register(category, model)
26+
:(MODELS[string($(esc(model)).f)] = ($(esc(category)), $(esc(model))))
3327
end
3428

3529
"""
36-
include_model(model_name::AbstractString)
30+
include_model(category::AbstractString, model_name::AbstractString)
3731
3832
Add the model defined in `models/model_name.jl` to the full list of models
39-
tested in this script.
33+
tested in this script. The model is registered under the given `category`.
4034
4135
We want to isolate every model in its own module, so that we avoid e.g.
4236
variable clashes, and also so that we make imports explicit.
@@ -57,17 +51,17 @@ definition file:
5751
- Once defined, the model is created using `model = model_name(...)`. The
5852
`model` on the left-hand side is mandatory.
5953
"""
60-
macro include_model(model_name::AbstractString)
54+
macro include_model(category::AbstractString, model_name::AbstractString)
6155
MODELS_TO_LOAD = get(ENV, "ADTESTS_MODELS_TO_LOAD", "__all__")
6256
if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",")
6357
# Declare a module containing the model. In principle esc() shouldn't
6458
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
6559
Expr(:toplevel, esc(:(
6660
module $(gensym())
67-
using .Main: @register
68-
using Turing
69-
include("models/" * $(model_name) * ".jl")
70-
@register model
61+
using .Main: @register
62+
using Turing
63+
include("models/" * $(model_name) * ".jl")
64+
@register $(category) model
7165
end
7266
)))
7367
else
@@ -76,63 +70,70 @@ macro include_model(model_name::AbstractString)
7670
end
7771
end
7872

79-
@include_model "assume_beta"
80-
@include_model "assume_dirichlet"
81-
@include_model "assume_lkjcholu"
82-
@include_model "assume_mvnormal"
83-
@include_model "assume_normal"
84-
@include_model "assume_submodel"
85-
@include_model "assume_wishart"
86-
@include_model "broadcast_macro"
87-
@include_model "control_flow"
88-
@include_model "demo_assume_dot_observe"
89-
@include_model "demo_assume_dot_observe_literal"
90-
@include_model "demo_assume_index_observe"
91-
@include_model "demo_assume_matrix_observe_matrix_index"
92-
@include_model "demo_assume_multivariate_observe"
93-
@include_model "demo_assume_multivariate_observe_literal"
94-
@include_model "demo_assume_observe_literal"
95-
@include_model "demo_assume_submodel_observe_index_literal"
96-
@include_model "demo_dot_assume_observe"
97-
@include_model "demo_dot_assume_observe_index"
98-
@include_model "demo_dot_assume_observe_index_literal"
99-
@include_model "demo_dot_assume_observe_matrix_index"
100-
@include_model "demo_dot_assume_observe_submodel"
101-
@include_model "dot_assume"
102-
@include_model "dot_observe"
103-
@include_model "dynamic_constraint"
104-
@include_model "multiple_constraints_same_var"
105-
@include_model "multithreaded"
106-
@include_model "n010"
107-
@include_model "n050"
108-
@include_model "n100"
109-
@include_model "n500"
110-
@include_model "observe_bernoulli"
111-
@include_model "observe_categorical"
112-
@include_model "observe_index"
113-
@include_model "observe_literal"
114-
@include_model "observe_multivariate"
115-
@include_model "observe_submodel"
116-
@include_model "pdb_eight_schools_centered"
117-
@include_model "pdb_eight_schools_noncentered"
118-
@include_model "von_mises"
73+
# Models to test. The convention is that:
74+
# a, b, c, ... are assumed variables
75+
# x, y, z, ... are observed variables
76+
# although it's hardly a big deal.
77+
@include_model "Base Julia features" "control_flow"
78+
@include_model "Base Julia features" "multithreaded"
79+
@include_model "Core Turing syntax" "broadcast_macro"
80+
@include_model "Core Turing syntax" "dot_assume"
81+
@include_model "Core Turing syntax" "dot_observe"
82+
@include_model "Core Turing syntax" "dynamic_constraint"
83+
@include_model "Core Turing syntax" "multiple_constraints_same_var"
84+
@include_model "Core Turing syntax" "observe_index"
85+
@include_model "Core Turing syntax" "observe_literal"
86+
@include_model "Core Turing syntax" "observe_multivariate"
87+
@include_model "Core Turing syntax" "observe_submodel"
88+
@include_model "Distributions" "assume_beta"
89+
@include_model "Distributions" "assume_dirichlet"
90+
@include_model "Distributions" "assume_lkjcholu"
91+
@include_model "Distributions" "assume_mvnormal"
92+
@include_model "Distributions" "assume_normal"
93+
@include_model "Distributions" "assume_submodel"
94+
@include_model "Distributions" "assume_wishart"
95+
@include_model "Distributions" "observe_bernoulli"
96+
@include_model "Distributions" "observe_categorical"
97+
@include_model "Distributions" "observe_von_mises"
98+
@include_model "DynamicPPL demo models" "demo_assume_dot_observe"
99+
@include_model "DynamicPPL demo models" "demo_assume_dot_observe_literal"
100+
@include_model "DynamicPPL demo models" "demo_assume_index_observe"
101+
@include_model "DynamicPPL demo models" "demo_assume_matrix_observe_matrix_index"
102+
@include_model "DynamicPPL demo models" "demo_assume_multivariate_observe"
103+
@include_model "DynamicPPL demo models" "demo_assume_multivariate_observe_literal"
104+
@include_model "DynamicPPL demo models" "demo_assume_observe_literal"
105+
@include_model "DynamicPPL demo models" "demo_assume_submodel_observe_index_literal"
106+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe"
107+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_index"
108+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_index_literal"
109+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_matrix_index"
110+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_submodel"
111+
@include_model "Effect of model size" "n010"
112+
@include_model "Effect of model size" "n050"
113+
@include_model "Effect of model size" "n100"
114+
@include_model "Effect of model size" "n500"
115+
@include_model "PosteriorDB" "pdb_eight_schools_centered"
116+
@include_model "PosteriorDB" "pdb_eight_schools_noncentered"
119117

120118
# The entry point to this script itself begins here
121119
if ARGS == ["--list-model-keys"]
122120
foreach(println, sort(collect(keys(MODELS))))
123121
elseif ARGS == ["--list-adtype-keys"]
124122
foreach(println, sort(collect(keys(ADTYPES))))
123+
elseif length(ARGS) == 2 && ARGS[1] == "--get-category"
124+
println(MODELS[ARGS[2]][1])
125125
elseif length(ARGS) == 3 && ARGS[1] == "--run"
126-
model, adtype = MODELS[ARGS[2]], ADTYPES[ARGS[3]]
126+
model_name, adtype_name = ARGS[2], ARGS[3]
127+
model, adtype = MODELS[model_name][2], ADTYPES[adtype_name]
127128

128129
try
129-
if ARGS[2] == "control_flow"
130+
if model_name == "control_flow"
130131
# https://github.com/TuringLang/ADTests/issues/4
131132
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
132133
params = [-0.5, 0.5]
133-
result = run_ad(model, adtype; varinfo = vi, params = params, benchmark = true)
134+
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
134135
else
135-
result = run_ad(model, adtype; benchmark = true)
136+
result = run_ad(model, adtype; benchmark=true)
136137
end
137138
# If reached here - nothing went wrong
138139
println(result.time_vs_primal)
@@ -162,4 +163,5 @@ else
162163
println("Usage: julia main.jl --list-model-keys")
163164
println(" julia main.jl --list-adtype-keys")
164165
println(" julia main.jl --run <model> <adtype>")
166+
println(" julia main.jl --get-category <model>")
165167
end

models/observe_von_mises.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@model function observe_von_mises(x)
2+
a ~ InverseGamma(2, 3)
3+
x ~ VonMises(0, a)
4+
end
5+
6+
model = observe_von_mises(0.4)

models/von_mises.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)