Skip to content

Commit 1973b2f

Browse files
committed
isolate each model in its own module
1 parent 5ffd008 commit 1973b2f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+120
-84
lines changed

main.jl

Lines changed: 80 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,51 +32,87 @@ macro register(model)
3232
:(MODELS[string($(esc(model)).f)] = $(esc(model)))
3333
end
3434

35-
# These imports tend to get used a lot in models
36-
using DynamicPPL: @model, to_submodel
37-
using Distributions
38-
using LinearAlgebra
35+
"""
36+
include_model(model_name::AbstractString)
3937
40-
include("models/assume_beta.jl")
41-
include("models/assume_dirichlet.jl")
42-
include("models/assume_lkjcholu.jl")
43-
include("models/assume_mvnormal.jl")
44-
include("models/assume_normal.jl")
45-
include("models/assume_submodel.jl")
46-
include("models/assume_wishart.jl")
47-
include("models/broadcast_macro.jl")
48-
include("models/control_flow.jl")
49-
include("models/demo_assume_dot_observe.jl")
50-
include("models/demo_assume_dot_observe_literal.jl")
51-
include("models/demo_assume_index_observe.jl")
52-
include("models/demo_assume_matrix_observe_matrix_index.jl")
53-
include("models/demo_assume_multivariate_observe.jl")
54-
include("models/demo_assume_multivariate_observe_literal.jl")
55-
include("models/demo_assume_observe_literal.jl")
56-
include("models/demo_assume_submodel_observe_index_literal.jl")
57-
include("models/demo_dot_assume_observe.jl")
58-
include("models/demo_dot_assume_observe_index.jl")
59-
include("models/demo_dot_assume_observe_index_literal.jl")
60-
include("models/demo_dot_assume_observe_matrix_index.jl")
61-
include("models/demo_dot_assume_observe_submodel.jl")
62-
include("models/dot_assume.jl")
63-
include("models/dot_observe.jl")
64-
include("models/dynamic_constraint.jl")
65-
include("models/multiple_constraints_same_var.jl")
66-
include("models/multithreaded.jl")
67-
include("models/n010.jl")
68-
include("models/n050.jl")
69-
include("models/n100.jl")
70-
include("models/n500.jl")
71-
include("models/observe_bernoulli.jl")
72-
include("models/observe_categorical.jl")
73-
include("models/observe_index.jl")
74-
include("models/observe_literal.jl")
75-
include("models/observe_multivariate.jl")
76-
include("models/observe_submodel.jl")
77-
include("models/pdb_eight_schools_centered.jl")
78-
include("models/pdb_eight_schools_noncentered.jl")
79-
include("models/von_mises.jl")
38+
Add the model defined in `models/model_name.jl` to the full list of models
39+
tested in this script.
40+
41+
We want to isolate every model in its own module, so that we avoid e.g.
42+
variable clashes, and also so that we make imports explicit.
43+
44+
However, we don't want the _model files_ themselves to be cluttered with e.g.
45+
`module ... end` blocks as well as boring imports like @model,
46+
Distributions.jl, etc. which Turing re-exports by default anyway. This is
47+
because (a) it's boring and repetitive; and (b) the model definition shown on
48+
the website is exactly the contents of each file, so we would like to keep it
49+
as clean as possible.
50+
51+
To this end, instead of using `include(filename)` we write a small macro that
52+
does this for us. We require the following to be true for each model
53+
definition file:
54+
- The file is in the `models/` directory.
55+
- The file is named `model_name.jl`, where `model_name` is the name of the
56+
model, i.e. it's defined with `@model function model_name(...) ... end`.
57+
- Once defined, the model is created using `model = model_name(...)`. The
58+
`model` on the left-hand side is mandatory.
59+
"""
60+
function include_model(model_name::AbstractString)
61+
module_contents = quote
62+
using DynamicPPL: @model, to_submodel
63+
using Distributions
64+
using LinearAlgebra: I
65+
using .Main: @register
66+
include("models/" * $(model_name) * ".jl")
67+
@register model
68+
end
69+
module_name = Symbol("ADTests_", model_name)
70+
# Ideally, this would be a macro. However, defining a module in a macro is
71+
# either incredibly difficult or impossible. See
72+
# https://github.com/TuringLang/ADTests/issues/31
73+
eval(Expr(:module, true, module_name, module_contents))
74+
end
75+
76+
include_model("assume_beta")
77+
include_model("assume_dirichlet")
78+
include_model("assume_lkjcholu")
79+
include_model("assume_mvnormal")
80+
include_model("assume_normal")
81+
include_model("assume_submodel")
82+
include_model("assume_wishart")
83+
include_model("broadcast_macro")
84+
include_model("control_flow")
85+
include_model("demo_assume_dot_observe")
86+
include_model("demo_assume_dot_observe_literal")
87+
include_model("demo_assume_index_observe")
88+
include_model("demo_assume_matrix_observe_matrix_index")
89+
include_model("demo_assume_multivariate_observe")
90+
include_model("demo_assume_multivariate_observe_literal")
91+
include_model("demo_assume_observe_literal")
92+
include_model("demo_assume_submodel_observe_index_literal")
93+
include_model("demo_dot_assume_observe")
94+
include_model("demo_dot_assume_observe_index")
95+
include_model("demo_dot_assume_observe_index_literal")
96+
include_model("demo_dot_assume_observe_matrix_index")
97+
include_model("demo_dot_assume_observe_submodel")
98+
include_model("dot_assume")
99+
include_model("dot_observe")
100+
include_model("dynamic_constraint")
101+
include_model("multiple_constraints_same_var")
102+
include_model("multithreaded")
103+
include_model("n010")
104+
include_model("n050")
105+
include_model("n100")
106+
include_model("n500")
107+
include_model("observe_bernoulli")
108+
include_model("observe_categorical")
109+
include_model("observe_index")
110+
include_model("observe_literal")
111+
include_model("observe_multivariate")
112+
include_model("observe_submodel")
113+
include_model("pdb_eight_schools_centered")
114+
include_model("pdb_eight_schools_noncentered")
115+
include_model("von_mises")
80116

81117
# The entry point to this script itself begins here
82118
if ARGS == ["--list-model-keys"]

models/assume_beta.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ Beta(2, 2)
33
end
44

5-
@register assume_beta()
5+
model = assume_beta()

models/assume_dirichlet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ Dirichlet([1.0, 5.0])
33
end
44

5-
@register assume_dirichlet()
5+
model = assume_dirichlet()

models/assume_lkjcholu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ LKJCholesky(5, 1.0, 'U')
33
end
44

5-
@register assume_lkjcholu()
5+
model = assume_lkjcholu()

models/assume_mvnormal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0])
33
end
44

5-
@register assume_mvnormal()
5+
model = assume_mvnormal()

models/assume_normal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ Normal()
33
end
44

5-
@register assume_normal()
5+
model = assume_normal()

models/assume_submodel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ end
66
x ~ Normal(a)
77
end
88

9-
@register assume_submodel()
9+
model = assume_submodel()

models/assume_wishart.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ Wishart(7, [1.0 0.5; 0.5 1.0])
33
end
44

5-
@register assume_wishart()
5+
model = assume_wishart()

models/broadcast_macro.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
@. x ~ Normal(a, $(sqrt(b)))
88
end
99

10-
@register broadcast_macro()
10+
model = broadcast_macro()

models/control_flow.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ evaluated at a value of `a < 0`. See `main.jl` for more information.
1818
end
1919
end
2020

21-
@register control_flow()
21+
model = control_flow()

0 commit comments

Comments
 (0)