Skip to content

Commit 6e88b49

Browse files
authored
Merge pull request #32 from TuringLang/py/modularise
isolate each model in its own module
2 parents 5ffd008 + 336fa78 commit 6e88b49

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+140
-85
lines changed

.github/workflows/generate_website.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ jobs:
8787
run: uv run ad.py run --model ${{ matrix.model }}
8888
env:
8989
ADTYPE_KEYS: ${{ needs.setup-keys.outputs.adtype_keys }}
90+
ADTESTS_MODELS_TO_LOAD: ${{ matrix.model }}
9091

9192
- name: Output matrix values
9293
id: output-matrix

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1515
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
16+
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[compat]

main.jl

Lines changed: 83 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,51 +32,90 @@ macro register(model)
3232
:(MODELS[string($(esc(model)).f)] = $(esc(model)))
3333
end
3434

35-
# These imports tend to get used a lot in models
36-
using DynamicPPL: @model, to_submodel
37-
using Distributions
38-
using LinearAlgebra
35+
"""
36+
include_model(model_name::AbstractString)
3937
40-
include("models/assume_beta.jl")
41-
include("models/assume_dirichlet.jl")
42-
include("models/assume_lkjcholu.jl")
43-
include("models/assume_mvnormal.jl")
44-
include("models/assume_normal.jl")
45-
include("models/assume_submodel.jl")
46-
include("models/assume_wishart.jl")
47-
include("models/broadcast_macro.jl")
48-
include("models/control_flow.jl")
49-
include("models/demo_assume_dot_observe.jl")
50-
include("models/demo_assume_dot_observe_literal.jl")
51-
include("models/demo_assume_index_observe.jl")
52-
include("models/demo_assume_matrix_observe_matrix_index.jl")
53-
include("models/demo_assume_multivariate_observe.jl")
54-
include("models/demo_assume_multivariate_observe_literal.jl")
55-
include("models/demo_assume_observe_literal.jl")
56-
include("models/demo_assume_submodel_observe_index_literal.jl")
57-
include("models/demo_dot_assume_observe.jl")
58-
include("models/demo_dot_assume_observe_index.jl")
59-
include("models/demo_dot_assume_observe_index_literal.jl")
60-
include("models/demo_dot_assume_observe_matrix_index.jl")
61-
include("models/demo_dot_assume_observe_submodel.jl")
62-
include("models/dot_assume.jl")
63-
include("models/dot_observe.jl")
64-
include("models/dynamic_constraint.jl")
65-
include("models/multiple_constraints_same_var.jl")
66-
include("models/multithreaded.jl")
67-
include("models/n010.jl")
68-
include("models/n050.jl")
69-
include("models/n100.jl")
70-
include("models/n500.jl")
71-
include("models/observe_bernoulli.jl")
72-
include("models/observe_categorical.jl")
73-
include("models/observe_index.jl")
74-
include("models/observe_literal.jl")
75-
include("models/observe_multivariate.jl")
76-
include("models/observe_submodel.jl")
77-
include("models/pdb_eight_schools_centered.jl")
78-
include("models/pdb_eight_schools_noncentered.jl")
79-
include("models/von_mises.jl")
38+
Add the model defined in `models/model_name.jl` to the full list of models
39+
tested in this script.
40+
41+
We want to isolate every model in its own module, so that we avoid e.g.
42+
variable clashes, and also so that we make imports explicit.
43+
44+
However, we don't want the _model files_ themselves to be cluttered with e.g.
45+
`module ... end` blocks as well as boring imports like @model,
46+
Distributions.jl, etc. which Turing re-exports by default anyway. This is
47+
because (a) it's boring and repetitive; and (b) the model definition shown on
48+
the website is exactly the contents of each file, so we would like to keep it
49+
as clean as possible.
50+
51+
To this end, instead of using `include(filename)` we write a small macro that
52+
does this for us. We require the following to be true for each model
53+
definition file:
54+
- The file is in the `models/` directory.
55+
- The file is named `model_name.jl`, where `model_name` is the name of the
56+
model, i.e. it's defined with `@model function model_name(...) ... end`.
57+
- Once defined, the model is created using `model = model_name(...)`. The
58+
`model` on the left-hand side is mandatory.
59+
"""
60+
macro include_model(model_name::AbstractString)
61+
MODELS_TO_LOAD = get(ENV, "ADTESTS_MODELS_TO_LOAD", "__all__")
62+
if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",")
63+
# Declare a module containing the model. In principle esc() shouldn't
64+
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
65+
Expr(:toplevel, esc(:(
66+
module $(gensym())
67+
using .Main: @register
68+
using Turing
69+
include("models/" * $(model_name) * ".jl")
70+
@register model
71+
end
72+
)))
73+
else
74+
# Empty expression
75+
:()
76+
end
77+
end
78+
79+
@include_model "assume_beta"
80+
@include_model "assume_dirichlet"
81+
@include_model "assume_lkjcholu"
82+
@include_model "assume_mvnormal"
83+
@include_model "assume_normal"
84+
@include_model "assume_submodel"
85+
@include_model "assume_wishart"
86+
@include_model "broadcast_macro"
87+
@include_model "control_flow"
88+
@include_model "demo_assume_dot_observe"
89+
@include_model "demo_assume_dot_observe_literal"
90+
@include_model "demo_assume_index_observe"
91+
@include_model "demo_assume_matrix_observe_matrix_index"
92+
@include_model "demo_assume_multivariate_observe"
93+
@include_model "demo_assume_multivariate_observe_literal"
94+
@include_model "demo_assume_observe_literal"
95+
@include_model "demo_assume_submodel_observe_index_literal"
96+
@include_model "demo_dot_assume_observe"
97+
@include_model "demo_dot_assume_observe_index"
98+
@include_model "demo_dot_assume_observe_index_literal"
99+
@include_model "demo_dot_assume_observe_matrix_index"
100+
@include_model "demo_dot_assume_observe_submodel"
101+
@include_model "dot_assume"
102+
@include_model "dot_observe"
103+
@include_model "dynamic_constraint"
104+
@include_model "multiple_constraints_same_var"
105+
@include_model "multithreaded"
106+
@include_model "n010"
107+
@include_model "n050"
108+
@include_model "n100"
109+
@include_model "n500"
110+
@include_model "observe_bernoulli"
111+
@include_model "observe_categorical"
112+
@include_model "observe_index"
113+
@include_model "observe_literal"
114+
@include_model "observe_multivariate"
115+
@include_model "observe_submodel"
116+
@include_model "pdb_eight_schools_centered"
117+
@include_model "pdb_eight_schools_noncentered"
118+
@include_model "von_mises"
80119

81120
# The entry point to this script itself begins here
82121
if ARGS == ["--list-model-keys"]

models/assume_beta.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ Beta(2, 2)
33
end
44

5-
@register assume_beta()
5+
model = assume_beta()

models/assume_dirichlet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ Dirichlet([1.0, 5.0])
33
end
44

5-
@register assume_dirichlet()
5+
model = assume_dirichlet()

models/assume_lkjcholu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ LKJCholesky(5, 1.0, 'U')
33
end
44

5-
@register assume_lkjcholu()
5+
model = assume_lkjcholu()

models/assume_mvnormal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0])
33
end
44

5-
@register assume_mvnormal()
5+
model = assume_mvnormal()

models/assume_normal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ Normal()
33
end
44

5-
@register assume_normal()
5+
model = assume_normal()

models/assume_submodel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ end
66
x ~ Normal(a)
77
end
88

9-
@register assume_submodel()
9+
model = assume_submodel()

models/assume_wishart.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
a ~ Wishart(7, [1.0 0.5; 0.5 1.0])
33
end
44

5-
@register assume_wishart()
5+
model = assume_wishart()

0 commit comments

Comments
 (0)