Skip to content

Commit aa81d07

Browse files
committed
Add full DynamicPPL demo models
1 parent 7378112 commit aa81d07

28 files changed

+206
-28
lines changed

main.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ include("models/assume_normal.jl")
4444
include("models/assume_submodel.jl")
4545
include("models/assume_wishart.jl")
4646
include("models/control_flow.jl")
47-
include("models/dot_assume_observe_index.jl")
47+
include("models/demo_dot_assume_observe_index.jl")
4848
include("models/dot_assume.jl")
49+
include("models/demo_dot_assume_observe.jl")
4950
include("models/dot_observe.jl")
5051
include("models/dynamic_constraint.jl")
5152
include("models/multiple_constraints_same_var.jl")

models/demo_assume_dot_observe.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@model function demo_assume_dot_observe(x=[1.5, 2.0])
2+
# `assume` and `dot_observe`
3+
s ~ InverseGamma(2, 3)
4+
m ~ Normal(0, sqrt(s))
5+
x .~ Normal(m, sqrt(s))
6+
7+
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
8+
end
9+
10+
@register demo_assume_dot_observe()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@model function demo_assume_dot_observe_literal()
2+
# `assume` and literal `dot_observe`
3+
s ~ InverseGamma(2, 3)
4+
m ~ Normal(0, sqrt(s))
5+
[1.5, 2.0] .~ Normal(m, sqrt(s))
6+
7+
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
8+
end
9+
10+
@model demo_assume_dot_observe_literal()
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
@model function demo_assume_index_observe(
2+
x = [1.5, 2.0],
3+
::Type{TV} = Vector{Float64},
4+
) where {TV}
5+
# `assume` with indexing and `observe`
6+
s = TV(undef, length(x))
7+
for i in eachindex(s)
8+
s[i] ~ InverseGamma(2, 3)
9+
end
10+
m = TV(undef, length(x))
11+
for i in eachindex(m)
12+
m[i] ~ Normal(0, sqrt(s[i]))
13+
end
14+
x ~ MvNormal(m, Diagonal(s))
15+
16+
return (; s = s, m = m, x = x, logp = getlogp(__varinfo__))
17+
end
18+
19+
@register demo_assume_index_observe()
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
@model function demo_assume_matrix_observe_matrix_index(
2+
x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64}
3+
) where {TV}
4+
n = length(x)
5+
d = n ÷ 2
6+
s ~ reshape(product_distribution(fill(InverseGamma(2, 3), n)), d, 2)
7+
s_vec = vec(s)
8+
m ~ MvNormal(zeros(n), Diagonal(s_vec))
9+
10+
x[:, 1] ~ MvNormal(m, Diagonal(s_vec))
11+
12+
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
13+
end
14+
15+
@register demo_assume_matrix_observe_matrix_index()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@model function demo_assume_multivariate_observe(x=[1.5, 2.0])
2+
# Multivariate `assume` and `observe`
3+
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
4+
m ~ MvNormal(zero(x), Diagonal(s))
5+
x ~ MvNormal(m, Diagonal(s))
6+
7+
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
8+
end
9+
10+
@register demo_assume_multivariate_observe()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@model function demo_assume_multivariate_observe_literal()
2+
# multivariate `assume` and literal `observe`
3+
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
4+
m ~ MvNormal(zeros(2), Diagonal(s))
5+
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))
6+
7+
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
8+
end
9+
10+
@register demo_assume_multivariate_observe_literal()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@model function demo_assume_observe_literal()
2+
# univariate `assume` and literal `observe`
3+
s ~ InverseGamma(2, 3)
4+
m ~ Normal(0, sqrt(s))
5+
1.5 ~ Normal(m, sqrt(s))
6+
2.0 ~ Normal(m, sqrt(s))
7+
8+
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
9+
end
10+
11+
@register demo_assume_observe_literal()
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
2+
s = TV(undef, 2)
3+
s .~ InverseGamma(2, 3)
4+
m = TV(undef, 2)
5+
m ~ product_distribution(Normal.(0, sqrt.(s)))
6+
return s, m
7+
end
8+
9+
@model function demo_assume_submodel_observe_index_literal()
10+
# Submodel prior
11+
priors ~ to_submodel(_prior_dot_assume(), false)
12+
s, m = priors
13+
1.5 ~ Normal(m[1], sqrt(s[1]))
14+
2.0 ~ Normal(m[2], sqrt(s[2]))
15+
16+
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
17+
end
18+
19+
@register demo_assume_submodel_observe_index_literal()

models/demo_dot_assume_observe.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
@model function demo_dot_assume_observe(
2+
x = [1.5, 2.0],
3+
::Type{TV} = Vector{Float64},
4+
) where {TV}
5+
# `dot_assume` and `observe`
6+
s = TV(undef, length(x))
7+
m = TV(undef, length(x))
8+
s .~ InverseGamma(2, 3)
9+
m ~ product_distribution(Normal.(0, sqrt.(s)))
10+
11+
x ~ MvNormal(m, Diagonal(s))
12+
return (; s = s, m = m, x = x, logp = getlogp(__varinfo__))
13+
end
14+
15+
@register demo_dot_assume_observe()

0 commit comments

Comments
 (0)