diff --git a/.github/workflows/generate_website.yml b/.github/workflows/generate_website.yml index 94e51e1..2218856 100644 --- a/.github/workflows/generate_website.yml +++ b/.github/workflows/generate_website.yml @@ -87,6 +87,7 @@ jobs: run: uv run ad.py run --model ${{ matrix.model }} env: ADTYPE_KEYS: ${{ needs.setup-keys.outputs.adtype_keys }} + ADTESTS_MODELS_TO_LOAD: ${{ matrix.model }} - name: Output matrix values id: output-matrix diff --git a/Project.toml b/Project.toml index c4b461f..af96777 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] diff --git a/main.jl b/main.jl index eca664e..7faffc9 100644 --- a/main.jl +++ b/main.jl @@ -32,51 +32,90 @@ macro register(model) :(MODELS[string($(esc(model)).f)] = $(esc(model))) end -# These imports tend to get used a lot in models -using DynamicPPL: @model, to_submodel -using Distributions -using LinearAlgebra +""" + include_model(model_name::AbstractString) -include("models/assume_beta.jl") -include("models/assume_dirichlet.jl") -include("models/assume_lkjcholu.jl") -include("models/assume_mvnormal.jl") -include("models/assume_normal.jl") -include("models/assume_submodel.jl") -include("models/assume_wishart.jl") -include("models/broadcast_macro.jl") -include("models/control_flow.jl") -include("models/demo_assume_dot_observe.jl") -include("models/demo_assume_dot_observe_literal.jl") -include("models/demo_assume_index_observe.jl") -include("models/demo_assume_matrix_observe_matrix_index.jl") -include("models/demo_assume_multivariate_observe.jl") -include("models/demo_assume_multivariate_observe_literal.jl") -include("models/demo_assume_observe_literal.jl") -include("models/demo_assume_submodel_observe_index_literal.jl") -include("models/demo_dot_assume_observe.jl") -include("models/demo_dot_assume_observe_index.jl") -include("models/demo_dot_assume_observe_index_literal.jl") -include("models/demo_dot_assume_observe_matrix_index.jl") -include("models/demo_dot_assume_observe_submodel.jl") -include("models/dot_assume.jl") -include("models/dot_observe.jl") -include("models/dynamic_constraint.jl") -include("models/multiple_constraints_same_var.jl") -include("models/multithreaded.jl") -include("models/n010.jl") -include("models/n050.jl") -include("models/n100.jl") -include("models/n500.jl") -include("models/observe_bernoulli.jl") -include("models/observe_categorical.jl") -include("models/observe_index.jl") -include("models/observe_literal.jl") -include("models/observe_multivariate.jl") -include("models/observe_submodel.jl") -include("models/pdb_eight_schools_centered.jl") -include("models/pdb_eight_schools_noncentered.jl") -include("models/von_mises.jl") +Add the model defined in `models/model_name.jl` to the full list of models +tested in this script. + +We want to isolate every model in its own module, so that we avoid e.g. +variable clashes, and also so that we make imports explicit. + +However, we don't want the _model files_ themselves to be cluttered with e.g. +`module ... end` blocks as well as boring imports like @model, +Distributions.jl, etc. which Turing re-exports by default anyway. This is +because (a) it's boring and repetitive; and (b) the model definition shown on +the website is exactly the contents of each file, so we would like to keep it +as clean as possible. + +To this end, instead of using `include(filename)` we write a small macro that +does this for us. We require the following to be true for each model +definition file: + - The file is in the `models/` directory. + - The file is named `model_name.jl`, where `model_name` is the name of the + model, i.e. it's defined with `@model function model_name(...) ... end`. + - Once defined, the model is created using `model = model_name(...)`. The + `model` on the left-hand side is mandatory. +""" +macro include_model(model_name::AbstractString) + MODELS_TO_LOAD = get(ENV, "ADTESTS_MODELS_TO_LOAD", "__all__") + if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",") + # Declare a module containing the model. In principle esc() shouldn't + # be needed, but see https://github.com/JuliaLang/julia/issues/55677 + Expr(:toplevel, esc(:( + module $(gensym()) + using .Main: @register + using Turing + include("models/" * $(model_name) * ".jl") + @register model + end + ))) + else + # Empty expression + :() + end +end + +@include_model "assume_beta" +@include_model "assume_dirichlet" +@include_model "assume_lkjcholu" +@include_model "assume_mvnormal" +@include_model "assume_normal" +@include_model "assume_submodel" +@include_model "assume_wishart" +@include_model "broadcast_macro" +@include_model "control_flow" +@include_model "demo_assume_dot_observe" +@include_model "demo_assume_dot_observe_literal" +@include_model "demo_assume_index_observe" +@include_model "demo_assume_matrix_observe_matrix_index" +@include_model "demo_assume_multivariate_observe" +@include_model "demo_assume_multivariate_observe_literal" +@include_model "demo_assume_observe_literal" +@include_model "demo_assume_submodel_observe_index_literal" +@include_model "demo_dot_assume_observe" +@include_model "demo_dot_assume_observe_index" +@include_model "demo_dot_assume_observe_index_literal" +@include_model "demo_dot_assume_observe_matrix_index" +@include_model "demo_dot_assume_observe_submodel" +@include_model "dot_assume" +@include_model "dot_observe" +@include_model "dynamic_constraint" +@include_model "multiple_constraints_same_var" +@include_model "multithreaded" +@include_model "n010" +@include_model "n050" +@include_model "n100" +@include_model "n500" +@include_model "observe_bernoulli" +@include_model "observe_categorical" +@include_model "observe_index" +@include_model "observe_literal" +@include_model "observe_multivariate" +@include_model "observe_submodel" +@include_model "pdb_eight_schools_centered" +@include_model "pdb_eight_schools_noncentered" +@include_model "von_mises" # The entry point to this script itself begins here if ARGS == ["--list-model-keys"] diff --git a/models/assume_beta.jl b/models/assume_beta.jl index becbf50..760ac89 100644 --- a/models/assume_beta.jl +++ b/models/assume_beta.jl @@ -2,4 +2,4 @@ a ~ Beta(2, 2) end -@register assume_beta() +model = assume_beta() diff --git a/models/assume_dirichlet.jl b/models/assume_dirichlet.jl index e859024..6727381 100644 --- a/models/assume_dirichlet.jl +++ b/models/assume_dirichlet.jl @@ -2,4 +2,4 @@ a ~ Dirichlet([1.0, 5.0]) end -@register assume_dirichlet() +model = assume_dirichlet() diff --git a/models/assume_lkjcholu.jl b/models/assume_lkjcholu.jl index 9e33b35..c0949ac 100644 --- a/models/assume_lkjcholu.jl +++ b/models/assume_lkjcholu.jl @@ -2,4 +2,4 @@ a ~ LKJCholesky(5, 1.0, 'U') end -@register assume_lkjcholu() +model = assume_lkjcholu() diff --git a/models/assume_mvnormal.jl b/models/assume_mvnormal.jl index 13b895a..fa2b7a4 100644 --- a/models/assume_mvnormal.jl +++ b/models/assume_mvnormal.jl @@ -2,4 +2,4 @@ a ~ MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]) end -@register assume_mvnormal() +model = assume_mvnormal() diff --git a/models/assume_normal.jl b/models/assume_normal.jl index 809ff2f..4c9b8d6 100644 --- a/models/assume_normal.jl +++ b/models/assume_normal.jl @@ -2,4 +2,4 @@ a ~ Normal() end -@register assume_normal() +model = assume_normal() diff --git a/models/assume_submodel.jl b/models/assume_submodel.jl index b8440ab..eb1a1b7 100644 --- a/models/assume_submodel.jl +++ b/models/assume_submodel.jl @@ -6,4 +6,4 @@ end x ~ Normal(a) end -@register assume_submodel() +model = assume_submodel() diff --git a/models/assume_wishart.jl b/models/assume_wishart.jl index d8373b7..8e1d722 100644 --- a/models/assume_wishart.jl +++ b/models/assume_wishart.jl @@ -2,4 +2,4 @@ a ~ Wishart(7, [1.0 0.5; 0.5 1.0]) end -@register assume_wishart() +model = assume_wishart() diff --git a/models/broadcast_macro.jl b/models/broadcast_macro.jl index 5b104ed..e1e8b12 100644 --- a/models/broadcast_macro.jl +++ b/models/broadcast_macro.jl @@ -7,4 +7,4 @@ @. x ~ Normal(a, $(sqrt(b))) end -@register broadcast_macro() +model = broadcast_macro() diff --git a/models/control_flow.jl b/models/control_flow.jl index 55f5329..2415215 100644 --- a/models/control_flow.jl +++ b/models/control_flow.jl @@ -18,4 +18,4 @@ evaluated at a value of `a < 0`. See `main.jl` for more information. end end -@register control_flow() +model = control_flow() diff --git a/models/demo_assume_dot_observe.jl b/models/demo_assume_dot_observe.jl index 0b83afa..51f2787 100644 --- a/models/demo_assume_dot_observe.jl +++ b/models/demo_assume_dot_observe.jl @@ -5,4 +5,4 @@ x .~ Normal(m, sqrt(s)) end -@register demo_assume_dot_observe() +model = demo_assume_dot_observe() diff --git a/models/demo_assume_dot_observe_literal.jl b/models/demo_assume_dot_observe_literal.jl index 304fc13..ef6553e 100644 --- a/models/demo_assume_dot_observe_literal.jl +++ b/models/demo_assume_dot_observe_literal.jl @@ -5,4 +5,4 @@ [1.5, 2.0] .~ Normal(m, sqrt(s)) end -@register demo_assume_dot_observe_literal() +model = demo_assume_dot_observe_literal() diff --git a/models/demo_assume_index_observe.jl b/models/demo_assume_index_observe.jl index 66d2570..bba7213 100644 --- a/models/demo_assume_index_observe.jl +++ b/models/demo_assume_index_observe.jl @@ -1,3 +1,5 @@ +using LinearAlgebra: Diagonal + @model function demo_assume_index_observe( x = [1.5, 2.0], ::Type{TV} = Vector{Float64}, @@ -14,4 +16,4 @@ x ~ MvNormal(m, Diagonal(s)) end -@register demo_assume_index_observe() +model = demo_assume_index_observe() diff --git a/models/demo_assume_matrix_observe_matrix_index.jl b/models/demo_assume_matrix_observe_matrix_index.jl index 0b6c58e..2900429 100644 --- a/models/demo_assume_matrix_observe_matrix_index.jl +++ b/models/demo_assume_matrix_observe_matrix_index.jl @@ -1,3 +1,5 @@ +using LinearAlgebra: Diagonal + @model function demo_assume_matrix_observe_matrix_index( x = transpose([1.5 2.0;]), ::Type{TV} = Array{Float64}, @@ -7,8 +9,7 @@ s ~ reshape(product_distribution(fill(InverseGamma(2, 3), n)), d, 2) s_vec = vec(s) m ~ MvNormal(zeros(n), Diagonal(s_vec)) - x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) end -@register demo_assume_matrix_observe_matrix_index() +model = demo_assume_matrix_observe_matrix_index() diff --git a/models/demo_assume_multivariate_observe.jl b/models/demo_assume_multivariate_observe.jl index 05248c3..9ad7adc 100644 --- a/models/demo_assume_multivariate_observe.jl +++ b/models/demo_assume_multivariate_observe.jl @@ -1,7 +1,10 @@ +using LinearAlgebra: Diagonal + @model function demo_assume_multivariate_observe(x = [1.5, 2.0]) # Multivariate `assume` and `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zero(x), Diagonal(s)) x ~ MvNormal(m, Diagonal(s)) end -@register demo_assume_multivariate_observe() + +model = demo_assume_multivariate_observe() diff --git a/models/demo_assume_multivariate_observe_literal.jl b/models/demo_assume_multivariate_observe_literal.jl index cbebdd8..adfa394 100644 --- a/models/demo_assume_multivariate_observe_literal.jl +++ b/models/demo_assume_multivariate_observe_literal.jl @@ -1,3 +1,5 @@ +using LinearAlgebra: Diagonal + @model function demo_assume_multivariate_observe_literal() # multivariate `assume` and literal `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -5,4 +7,4 @@ [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) end -@register demo_assume_multivariate_observe_literal() +model = demo_assume_multivariate_observe_literal() diff --git a/models/demo_assume_observe_literal.jl b/models/demo_assume_observe_literal.jl index 262ae49..fcca461 100644 --- a/models/demo_assume_observe_literal.jl +++ b/models/demo_assume_observe_literal.jl @@ -6,4 +6,4 @@ 2.0 ~ Normal(m, sqrt(s)) end -@register demo_assume_observe_literal() +model = demo_assume_observe_literal() diff --git a/models/demo_assume_submodel_observe_index_literal.jl b/models/demo_assume_submodel_observe_index_literal.jl index d277ea4..f6f11b0 100644 --- a/models/demo_assume_submodel_observe_index_literal.jl +++ b/models/demo_assume_submodel_observe_index_literal.jl @@ -14,4 +14,4 @@ end 2.0 ~ Normal(m[2], sqrt(s[2])) end -@register demo_assume_submodel_observe_index_literal() +model = demo_assume_submodel_observe_index_literal() diff --git a/models/demo_dot_assume_observe.jl b/models/demo_dot_assume_observe.jl index 3225cf9..1ba46b0 100644 --- a/models/demo_dot_assume_observe.jl +++ b/models/demo_dot_assume_observe.jl @@ -1,3 +1,5 @@ +using LinearAlgebra: Diagonal + @model function demo_dot_assume_observe( x = [1.5, 2.0], ::Type{TV} = Vector{Float64}, @@ -10,4 +12,4 @@ x ~ MvNormal(m, Diagonal(s)) end -@register demo_dot_assume_observe() +model = demo_dot_assume_observe() diff --git a/models/demo_dot_assume_observe_index.jl b/models/demo_dot_assume_observe_index.jl index 12e5cbe..9c4d122 100644 --- a/models/demo_dot_assume_observe_index.jl +++ b/models/demo_dot_assume_observe_index.jl @@ -12,4 +12,4 @@ end end -@register demo_dot_assume_observe_index() +model = demo_dot_assume_observe_index() diff --git a/models/demo_dot_assume_observe_index_literal.jl b/models/demo_dot_assume_observe_index_literal.jl index 7eb93dc..713e3ff 100644 --- a/models/demo_dot_assume_observe_index_literal.jl +++ b/models/demo_dot_assume_observe_index_literal.jl @@ -11,4 +11,4 @@ 2.0 ~ Normal(m[2], sqrt(s[2])) end -@register demo_dot_assume_observe_index_literal() +model = demo_dot_assume_observe_index_literal() diff --git a/models/demo_dot_assume_observe_matrix_index.jl b/models/demo_dot_assume_observe_matrix_index.jl index 70c6995..6774da3 100644 --- a/models/demo_dot_assume_observe_matrix_index.jl +++ b/models/demo_dot_assume_observe_matrix_index.jl @@ -1,3 +1,5 @@ +using LinearAlgebra: Diagonal + @model function demo_dot_assume_observe_matrix_index( x = transpose([1.5 2.0;]), ::Type{TV} = Vector{Float64}, @@ -9,4 +11,4 @@ x[:, 1] ~ MvNormal(m, Diagonal(s)) end -@register demo_dot_assume_observe_matrix_index() +model = demo_dot_assume_observe_matrix_index() diff --git a/models/demo_dot_assume_observe_submodel.jl b/models/demo_dot_assume_observe_submodel.jl index f48d1e4..2a81a4d 100644 --- a/models/demo_dot_assume_observe_submodel.jl +++ b/models/demo_dot_assume_observe_submodel.jl @@ -1,3 +1,5 @@ +using LinearAlgebra: Diagonal + @model function _likelihood_multivariate_observe(s, m, x) return x ~ MvNormal(m, Diagonal(s)) end @@ -17,4 +19,4 @@ end _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x)) end -@register demo_dot_assume_observe_submodel() +model = demo_dot_assume_observe_submodel() diff --git a/models/dot_assume.jl b/models/dot_assume.jl index 5638855..2819919 100644 --- a/models/dot_assume.jl +++ b/models/dot_assume.jl @@ -3,4 +3,4 @@ a .~ Normal() end -@register dot_assume() +model = dot_assume() diff --git a/models/dot_observe.jl b/models/dot_observe.jl index 8136ca5..acf255d 100644 --- a/models/dot_observe.jl +++ b/models/dot_observe.jl @@ -3,4 +3,4 @@ x .~ Normal(a) end -@register dot_observe() +model = dot_observe() diff --git a/models/dynamic_constraint.jl b/models/dynamic_constraint.jl index 3640367..8e831ab 100644 --- a/models/dynamic_constraint.jl +++ b/models/dynamic_constraint.jl @@ -3,4 +3,4 @@ b ~ truncated(Normal(); lower = a) end -@register dynamic_constraint() +model = dynamic_constraint() diff --git a/models/multiple_constraints_same_var.jl b/models/multiple_constraints_same_var.jl index 15fefdd..ae09bc0 100644 --- a/models/multiple_constraints_same_var.jl +++ b/models/multiple_constraints_same_var.jl @@ -6,4 +6,4 @@ x[4:5] ~ Dirichlet([1.0, 2.0]) end -@register multiple_constraints_same_var() +model = multiple_constraints_same_var() diff --git a/models/multithreaded.jl b/models/multithreaded.jl index c31b7a6..e54ea4c 100644 --- a/models/multithreaded.jl +++ b/models/multithreaded.jl @@ -11,4 +11,4 @@ statements. See `main.jl` for more information. end end -@register multithreaded([1.5, 2.0, 2.5, 1.5, 2.0, 2.5]) +model = multithreaded([1.5, 2.0, 2.5, 1.5, 2.0, 2.5]) diff --git a/models/n010.jl b/models/n010.jl index 98c20de..998a6ab 100644 --- a/models/n010.jl +++ b/models/n010.jl @@ -5,4 +5,4 @@ end end -@register n010() +model = n010() diff --git a/models/n050.jl b/models/n050.jl index 6979cae..2bdffe2 100644 --- a/models/n050.jl +++ b/models/n050.jl @@ -5,4 +5,4 @@ end end -@register n050() +model = n050() diff --git a/models/n100.jl b/models/n100.jl index ab17df1..f47587d 100644 --- a/models/n100.jl +++ b/models/n100.jl @@ -5,4 +5,4 @@ end end -@register n100() +model = n100() diff --git a/models/n500.jl b/models/n500.jl index 339b75b..5194bce 100644 --- a/models/n500.jl +++ b/models/n500.jl @@ -5,4 +5,4 @@ end end -@register n500() +model = n500() diff --git a/models/observe_bernoulli.jl b/models/observe_bernoulli.jl index 4071612..5abd93c 100644 --- a/models/observe_bernoulli.jl +++ b/models/observe_bernoulli.jl @@ -5,4 +5,4 @@ end end -@register observe_bernoulli() +model = observe_bernoulli() diff --git a/models/observe_categorical.jl b/models/observe_categorical.jl index ac3c402..e594c62 100644 --- a/models/observe_categorical.jl +++ b/models/observe_categorical.jl @@ -5,4 +5,4 @@ end end -@register observe_categorical() +model = observe_categorical() diff --git a/models/observe_index.jl b/models/observe_index.jl index f58d980..d7168e6 100644 --- a/models/observe_index.jl +++ b/models/observe_index.jl @@ -5,4 +5,4 @@ end end -@register observe_index() +model = observe_index() diff --git a/models/observe_literal.jl b/models/observe_literal.jl index e208cae..0fa1403 100644 --- a/models/observe_literal.jl +++ b/models/observe_literal.jl @@ -3,4 +3,4 @@ 1.5 ~ Normal(a) end -@register observe_literal() +model = observe_literal() diff --git a/models/observe_multivariate.jl b/models/observe_multivariate.jl index c064486..a1bf8e0 100644 --- a/models/observe_multivariate.jl +++ b/models/observe_multivariate.jl @@ -7,4 +7,4 @@ x ~ MvNormal(a, I) end -@register observe_multivariate() +model = observe_multivariate() diff --git a/models/observe_submodel.jl b/models/observe_submodel.jl index 1cf02cd..f8bfcb9 100644 --- a/models/observe_submodel.jl +++ b/models/observe_submodel.jl @@ -6,4 +6,4 @@ end _ignore ~ to_submodel(inner2(x, a)) end -@register observe_submodel() +model = observe_submodel() diff --git a/models/pdb_eight_schools_centered.jl b/models/pdb_eight_schools_centered.jl index c265429..17339eb 100644 --- a/models/pdb_eight_schools_centered.jl +++ b/models/pdb_eight_schools_centered.jl @@ -12,4 +12,4 @@ sigma = [15, 10, 16, 11, 9, 11, 10, 18] end end -@register pdb_eight_schools_centered(J, y, sigma) +model = pdb_eight_schools_centered(J, y, sigma) diff --git a/models/pdb_eight_schools_noncentered.jl b/models/pdb_eight_schools_noncentered.jl index 1a36d0c..3662ec7 100644 --- a/models/pdb_eight_schools_noncentered.jl +++ b/models/pdb_eight_schools_noncentered.jl @@ -13,4 +13,4 @@ sigma = [15, 10, 16, 11, 9, 11, 10, 18] end end -@register pdb_eight_schools_noncentered(J, y, sigma) +model = pdb_eight_schools_noncentered(J, y, sigma) diff --git a/models/von_mises.jl b/models/von_mises.jl index 3e21d03..e3b997f 100644 --- a/models/von_mises.jl +++ b/models/von_mises.jl @@ -3,4 +3,4 @@ x ~ VonMises(0, a) end -@register von_mises(0.4) +model = von_mises(0.4)