Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/generate_website.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ jobs:
run: uv run ad.py run --model ${{ matrix.model }}
env:
ADTYPE_KEYS: ${{ needs.setup-keys.outputs.adtype_keys }}
ADTESTS_MODELS_TO_LOAD: ${{ matrix.model }}

- name: Output matrix values
id: output-matrix
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Expand Down
127 changes: 83 additions & 44 deletions main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,51 +32,90 @@ macro register(model)
:(MODELS[string($(esc(model)).f)] = $(esc(model)))
end

# These imports tend to get used a lot in models
using DynamicPPL: @model, to_submodel
using Distributions
using LinearAlgebra
"""
include_model(model_name::AbstractString)

include("models/assume_beta.jl")
include("models/assume_dirichlet.jl")
include("models/assume_lkjcholu.jl")
include("models/assume_mvnormal.jl")
include("models/assume_normal.jl")
include("models/assume_submodel.jl")
include("models/assume_wishart.jl")
include("models/broadcast_macro.jl")
include("models/control_flow.jl")
include("models/demo_assume_dot_observe.jl")
include("models/demo_assume_dot_observe_literal.jl")
include("models/demo_assume_index_observe.jl")
include("models/demo_assume_matrix_observe_matrix_index.jl")
include("models/demo_assume_multivariate_observe.jl")
include("models/demo_assume_multivariate_observe_literal.jl")
include("models/demo_assume_observe_literal.jl")
include("models/demo_assume_submodel_observe_index_literal.jl")
include("models/demo_dot_assume_observe.jl")
include("models/demo_dot_assume_observe_index.jl")
include("models/demo_dot_assume_observe_index_literal.jl")
include("models/demo_dot_assume_observe_matrix_index.jl")
include("models/demo_dot_assume_observe_submodel.jl")
include("models/dot_assume.jl")
include("models/dot_observe.jl")
include("models/dynamic_constraint.jl")
include("models/multiple_constraints_same_var.jl")
include("models/multithreaded.jl")
include("models/n010.jl")
include("models/n050.jl")
include("models/n100.jl")
include("models/n500.jl")
include("models/observe_bernoulli.jl")
include("models/observe_categorical.jl")
include("models/observe_index.jl")
include("models/observe_literal.jl")
include("models/observe_multivariate.jl")
include("models/observe_submodel.jl")
include("models/pdb_eight_schools_centered.jl")
include("models/pdb_eight_schools_noncentered.jl")
include("models/von_mises.jl")
Add the model defined in `models/model_name.jl` to the full list of models
tested in this script.

We want to isolate every model in its own module, so that we avoid e.g.
variable clashes, and also so that we make imports explicit.

However, we don't want the _model files_ themselves to be cluttered with e.g.
`module ... end` blocks as well as boring imports like @model,
Distributions.jl, etc. which Turing re-exports by default anyway. This is
because (a) it's boring and repetitive; and (b) the model definition shown on
the website is exactly the contents of each file, so we would like to keep it
as clean as possible.

To this end, instead of using `include(filename)` we write a small macro that
does this for us. We require the following to be true for each model
definition file:
- The file is in the `models/` directory.
- The file is named `model_name.jl`, where `model_name` is the name of the
model, i.e. it's defined with `@model function model_name(...) ... end`.
- Once defined, the model is created using `model = model_name(...)`. The
`model` on the left-hand side is mandatory.
"""
macro include_model(model_name::AbstractString)
MODELS_TO_LOAD = get(ENV, "ADTESTS_MODELS_TO_LOAD", "__all__")
if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",")
# Declare a module containing the model. In principle esc() shouldn't
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
Expr(:toplevel, esc(:(
module $(gensym())
using .Main: @register
using Turing
include("models/" * $(model_name) * ".jl")
@register model
end
)))
else
# Empty expression
:()
end
end

@include_model "assume_beta"
@include_model "assume_dirichlet"
@include_model "assume_lkjcholu"
@include_model "assume_mvnormal"
@include_model "assume_normal"
@include_model "assume_submodel"
@include_model "assume_wishart"
@include_model "broadcast_macro"
@include_model "control_flow"
@include_model "demo_assume_dot_observe"
@include_model "demo_assume_dot_observe_literal"
@include_model "demo_assume_index_observe"
@include_model "demo_assume_matrix_observe_matrix_index"
@include_model "demo_assume_multivariate_observe"
@include_model "demo_assume_multivariate_observe_literal"
@include_model "demo_assume_observe_literal"
@include_model "demo_assume_submodel_observe_index_literal"
@include_model "demo_dot_assume_observe"
@include_model "demo_dot_assume_observe_index"
@include_model "demo_dot_assume_observe_index_literal"
@include_model "demo_dot_assume_observe_matrix_index"
@include_model "demo_dot_assume_observe_submodel"
@include_model "dot_assume"
@include_model "dot_observe"
@include_model "dynamic_constraint"
@include_model "multiple_constraints_same_var"
@include_model "multithreaded"
@include_model "n010"
@include_model "n050"
@include_model "n100"
@include_model "n500"
@include_model "observe_bernoulli"
@include_model "observe_categorical"
@include_model "observe_index"
@include_model "observe_literal"
@include_model "observe_multivariate"
@include_model "observe_submodel"
@include_model "pdb_eight_schools_centered"
@include_model "pdb_eight_schools_noncentered"
@include_model "von_mises"

# The entry point to this script itself begins here
if ARGS == ["--list-model-keys"]
Expand Down
2 changes: 1 addition & 1 deletion models/assume_beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
a ~ Beta(2, 2)
end

@register assume_beta()
model = assume_beta()
2 changes: 1 addition & 1 deletion models/assume_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
a ~ Dirichlet([1.0, 5.0])
end

@register assume_dirichlet()
model = assume_dirichlet()
2 changes: 1 addition & 1 deletion models/assume_lkjcholu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
a ~ LKJCholesky(5, 1.0, 'U')
end

@register assume_lkjcholu()
model = assume_lkjcholu()
2 changes: 1 addition & 1 deletion models/assume_mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
a ~ MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0])
end

@register assume_mvnormal()
model = assume_mvnormal()
2 changes: 1 addition & 1 deletion models/assume_normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
a ~ Normal()
end

@register assume_normal()
model = assume_normal()
2 changes: 1 addition & 1 deletion models/assume_submodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ end
x ~ Normal(a)
end

@register assume_submodel()
model = assume_submodel()
2 changes: 1 addition & 1 deletion models/assume_wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
a ~ Wishart(7, [1.0 0.5; 0.5 1.0])
end

@register assume_wishart()
model = assume_wishart()
2 changes: 1 addition & 1 deletion models/broadcast_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
@. x ~ Normal(a, $(sqrt(b)))
end

@register broadcast_macro()
model = broadcast_macro()
2 changes: 1 addition & 1 deletion models/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ evaluated at a value of `a < 0`. See `main.jl` for more information.
end
end

@register control_flow()
model = control_flow()
2 changes: 1 addition & 1 deletion models/demo_assume_dot_observe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
x .~ Normal(m, sqrt(s))
end

@register demo_assume_dot_observe()
model = demo_assume_dot_observe()
2 changes: 1 addition & 1 deletion models/demo_assume_dot_observe_literal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
[1.5, 2.0] .~ Normal(m, sqrt(s))
end

@register demo_assume_dot_observe_literal()
model = demo_assume_dot_observe_literal()
4 changes: 3 additions & 1 deletion models/demo_assume_index_observe.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using LinearAlgebra: Diagonal

@model function demo_assume_index_observe(
x = [1.5, 2.0],
::Type{TV} = Vector{Float64},
Expand All @@ -14,4 +16,4 @@
x ~ MvNormal(m, Diagonal(s))
end

@register demo_assume_index_observe()
model = demo_assume_index_observe()
5 changes: 3 additions & 2 deletions models/demo_assume_matrix_observe_matrix_index.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using LinearAlgebra: Diagonal

@model function demo_assume_matrix_observe_matrix_index(
x = transpose([1.5 2.0;]),
::Type{TV} = Array{Float64},
Expand All @@ -7,8 +9,7 @@
s ~ reshape(product_distribution(fill(InverseGamma(2, 3), n)), d, 2)
s_vec = vec(s)
m ~ MvNormal(zeros(n), Diagonal(s_vec))

x[:, 1] ~ MvNormal(m, Diagonal(s_vec))
end

@register demo_assume_matrix_observe_matrix_index()
model = demo_assume_matrix_observe_matrix_index()
5 changes: 4 additions & 1 deletion models/demo_assume_multivariate_observe.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using LinearAlgebra: Diagonal

@model function demo_assume_multivariate_observe(x = [1.5, 2.0])
# Multivariate `assume` and `observe`
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
m ~ MvNormal(zero(x), Diagonal(s))
x ~ MvNormal(m, Diagonal(s))
end
@register demo_assume_multivariate_observe()

model = demo_assume_multivariate_observe()
4 changes: 3 additions & 1 deletion models/demo_assume_multivariate_observe_literal.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using LinearAlgebra: Diagonal

@model function demo_assume_multivariate_observe_literal()
# multivariate `assume` and literal `observe`
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
m ~ MvNormal(zeros(2), Diagonal(s))
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))
end

@register demo_assume_multivariate_observe_literal()
model = demo_assume_multivariate_observe_literal()
2 changes: 1 addition & 1 deletion models/demo_assume_observe_literal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
2.0 ~ Normal(m, sqrt(s))
end

@register demo_assume_observe_literal()
model = demo_assume_observe_literal()
2 changes: 1 addition & 1 deletion models/demo_assume_submodel_observe_index_literal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ end
2.0 ~ Normal(m[2], sqrt(s[2]))
end

@register demo_assume_submodel_observe_index_literal()
model = demo_assume_submodel_observe_index_literal()
4 changes: 3 additions & 1 deletion models/demo_dot_assume_observe.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using LinearAlgebra: Diagonal

@model function demo_dot_assume_observe(
x = [1.5, 2.0],
::Type{TV} = Vector{Float64},
Expand All @@ -10,4 +12,4 @@
x ~ MvNormal(m, Diagonal(s))
end

@register demo_dot_assume_observe()
model = demo_dot_assume_observe()
2 changes: 1 addition & 1 deletion models/demo_dot_assume_observe_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
end
end

@register demo_dot_assume_observe_index()
model = demo_dot_assume_observe_index()
2 changes: 1 addition & 1 deletion models/demo_dot_assume_observe_index_literal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
2.0 ~ Normal(m[2], sqrt(s[2]))
end

@register demo_dot_assume_observe_index_literal()
model = demo_dot_assume_observe_index_literal()
4 changes: 3 additions & 1 deletion models/demo_dot_assume_observe_matrix_index.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using LinearAlgebra: Diagonal

@model function demo_dot_assume_observe_matrix_index(
x = transpose([1.5 2.0;]),
::Type{TV} = Vector{Float64},
Expand All @@ -9,4 +11,4 @@
x[:, 1] ~ MvNormal(m, Diagonal(s))
end

@register demo_dot_assume_observe_matrix_index()
model = demo_dot_assume_observe_matrix_index()
4 changes: 3 additions & 1 deletion models/demo_dot_assume_observe_submodel.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using LinearAlgebra: Diagonal

@model function _likelihood_multivariate_observe(s, m, x)
return x ~ MvNormal(m, Diagonal(s))
end
Expand All @@ -17,4 +19,4 @@ end
_ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x))
end

@register demo_dot_assume_observe_submodel()
model = demo_dot_assume_observe_submodel()
2 changes: 1 addition & 1 deletion models/dot_assume.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
a .~ Normal()
end

@register dot_assume()
model = dot_assume()
2 changes: 1 addition & 1 deletion models/dot_observe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
x .~ Normal(a)
end

@register dot_observe()
model = dot_observe()
2 changes: 1 addition & 1 deletion models/dynamic_constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
b ~ truncated(Normal(); lower = a)
end

@register dynamic_constraint()
model = dynamic_constraint()
2 changes: 1 addition & 1 deletion models/multiple_constraints_same_var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
x[4:5] ~ Dirichlet([1.0, 2.0])
end

@register multiple_constraints_same_var()
model = multiple_constraints_same_var()
2 changes: 1 addition & 1 deletion models/multithreaded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ statements. See `main.jl` for more information.
end
end

@register multithreaded([1.5, 2.0, 2.5, 1.5, 2.0, 2.5])
model = multithreaded([1.5, 2.0, 2.5, 1.5, 2.0, 2.5])
2 changes: 1 addition & 1 deletion models/n010.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
end
end

@register n010()
model = n010()
2 changes: 1 addition & 1 deletion models/n050.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
end
end

@register n050()
model = n050()
2 changes: 1 addition & 1 deletion models/n100.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
end
end

@register n100()
model = n100()
2 changes: 1 addition & 1 deletion models/n500.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
end
end

@register n500()
model = n500()
2 changes: 1 addition & 1 deletion models/observe_bernoulli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
end
end

@register observe_bernoulli()
model = observe_bernoulli()
2 changes: 1 addition & 1 deletion models/observe_categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
end
end

@register observe_categorical()
model = observe_categorical()
Loading
Loading