Skip to content

Commit 44590ab

Browse files
committed
Separate models into one-per-file
1 parent d26838c commit 44590ab

28 files changed

+248
-234
lines changed

README.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,11 @@ You can modify the list of AD types in `main.jl`.
1212

1313
## I want to add more models!
1414

15-
You can modify the list of models in `models.jl`.
15+
You can modify the list of models by adding a new file inside the `models` directory.
16+
This file should contain the model definition, and call the `@register` macro to register the model with the `ADTests` package.
17+
See the existing files in that directory for examples.
1618

17-
Note that if you want the model definition to be shown on the website, your model definitions must be of the 'standard' form `@model function name() ... end`.
18-
This means that:
19-
- One-liner function definitions like `@model f(x) = ...`will not work.
20-
- Fancy metaprogramming tricks to generate a family of models at one go (like [this old code](https://github.com/TuringLang/ADTests/blob/266d7ab85fea2e01e7e05af6cee179d7f6200b0f/models.jl#L108-L129)) will not work.
21-
22-
To understand why, see the `get_model_definition` function in `ad.py`.
19+
To make sure that the definition is included in the final website, you will have to make sure that the filename is the same as the model name (i.e. a model `@model function f()` is in `models/f.jl`).
2320

2421
## I want to edit the website!
2522

ad.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -102,26 +102,8 @@ def run_ad(args):
102102

103103

104104
def get_model_definition(model_key):
105-
"""Get the model definition from the Julia script."""
106-
lines = []
107-
submodels = []
108-
record = False
109-
with open("models.jl", "r") as file:
110-
for line in file:
111-
line = line.rstrip()
112-
if line.startswith(f"@model function {model_key}"):
113-
record = True
114-
if record:
115-
lines.append(line)
116-
117-
if "to_submodel" in line:
118-
submodel_name = line.split("to_submodel(")[1].split("(")[0]
119-
submodels.append(submodel_name)
120-
if line == "end":
121-
break
122-
for submodel in submodels:
123-
lines = [get_model_definition(submodel), *lines]
124-
return "\n".join(lines)
105+
"""Get the model definition from the file that contains it."""
106+
return Path(f"models/{model_key}.jl").read_text().strip()
125107

126108

127109
def html(_args):

main.jl

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,44 @@ ADTYPES = Dict(
2323
"Zygote" => AutoZygote(),
2424
)
2525

26-
# Models to test.
27-
include("models.jl")
28-
using .Models: MODELS
26+
# Models to test. The convention is that:
27+
# a, b, c, ... are assumed variables
28+
# x, y, z, ... are observed variables
29+
# although it's hardly a big deal.
30+
MODELS = Dict{String,DynamicPPL.Model}()
31+
macro register(model)
32+
:(MODELS[string($(esc(model)).f)] = $(esc(model)))
33+
end
34+
35+
# These imports tend to get used a lot in models
36+
using DynamicPPL: @model
37+
using Distributions
38+
using LinearAlgebra: I
39+
40+
include("models/assume_dirichlet.jl")
41+
include("models/assume_lkjcholu.jl")
42+
include("models/assume_mvnormal.jl")
43+
include("models/assume_normal.jl")
44+
include("models/assume_submodel.jl")
45+
include("models/assume_wishart.jl")
46+
include("models/control_flow.jl")
47+
include("models/dot_assume_observe_index.jl")
48+
include("models/dot_assume.jl")
49+
include("models/dot_observe.jl")
50+
include("models/dynamic_constraint.jl")
51+
include("models/models.jl")
52+
include("models/multiple_constraints_same_var.jl")
53+
include("models/multithreaded.jl")
54+
include("models/n010.jl")
55+
include("models/n050.jl")
56+
include("models/n100.jl")
57+
include("models/n500.jl")
58+
include("models/observe_index.jl")
59+
include("models/observe_literal.jl")
60+
include("models/observe_multivariate.jl")
61+
include("models/observe_submodel.jl")
62+
include("models/pdb_eight_schools_centered.jl")
63+
include("models/pdb_eight_schools_noncentered.jl")
2964

3065
# The entry point to this script itself begins here
3166
if ARGS == ["--list-model-keys"]

models.jl

Lines changed: 0 additions & 204 deletions
This file was deleted.

models/assume_beta.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@model function assume_beta()
2+
a ~ Beta(2, 2)
3+
end
4+
5+
@register assume_beta()

models/assume_dirichlet.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@model function assume_dirichlet()
2+
a ~ Dirichlet([1.0, 5.0])
3+
end
4+
5+
@register assume_dirichlet()

models/assume_lkjcholu.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@model function assume_lkjcholu()
2+
a ~ LKJCholesky(5, 1.0, 'U')
3+
end
4+
5+
@register assume_lkjcholu()

models/assume_mvnormal.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@model function assume_mvnormal()
2+
a ~ MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0])
3+
end
4+
5+
@register assume_mvnormal()

models/assume_normal.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@model function assume_normal()
2+
a ~ Normal()
3+
end
4+
5+
@register assume_normal()

models/assume_submodel.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@model function inner1()
2+
return a ~ Normal()
3+
end
4+
@model function assume_submodel()
5+
a ~ to_submodel(inner1())
6+
x ~ Normal(a)
7+
end
8+
9+
@register assume_submodel()

0 commit comments

Comments
 (0)