Skip to content

Commit 0e7da06

Browse files
committed
Add demo_assume_literal_observe + rename demo_assume_observe_literal -> demo_assume_multivariate_observe_literal
1 parent cbc5f7b commit 0e7da06

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

src/test_utils/models.jl

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -323,28 +323,30 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)})
323323
return [@varname(s), @varname(m)]
324324
end
325325

326-
@model function demo_assume_observe_literal()
327-
# `assume` and literal `observe`
326+
@model function demo_assume_multivariate_observe_literal()
327+
# multivariate `assume` and literal `observe`
328328
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
329329
m ~ MvNormal(zeros(2), Diagonal(s))
330330
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))
331331

332332
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
333333
end
334-
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
334+
function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
335335
s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
336336
m_dist = MvNormal(zeros(2), Diagonal(s))
337337
return logpdf(s_dist, s) + logpdf(m_dist, m)
338338
end
339-
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
339+
function loglikelihood_true(
340+
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
341+
)
340342
return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0])
341343
end
342344
function logprior_true_with_logabsdet_jacobian(
343-
model::Model{typeof(demo_assume_observe_literal)}, s, m
345+
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
344346
)
345347
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
346348
end
347-
function varnames(model::Model{typeof(demo_assume_observe_literal)})
349+
function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)})
348350
return [@varname(s), @varname(m)]
349351
end
350352

@@ -377,6 +379,30 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)})
377379
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
378380
end
379381

382+
@model function demo_assume_literal_observe()
383+
# univariate `assume` and literal `observe`
384+
s ~ InverseGamma(2, 3)
385+
m ~ Normal(0, sqrt(s))
386+
1.5 ~ Normal(m, sqrt(s))
387+
2.0 ~ Normal(m, sqrt(s))
388+
389+
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
390+
end
391+
function logprior_true(model::Model{typeof(demo_assume_literal_observe)}, s, m)
392+
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
393+
end
394+
function loglikelihood_true(model::Model{typeof(demo_assume_literal_observe)}, s, m)
395+
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
396+
end
397+
function logprior_true_with_logabsdet_jacobian(
398+
model::Model{typeof(demo_assume_literal_observe)}, s, m
399+
)
400+
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
401+
end
402+
function varnames(model::Model{typeof(demo_assume_literal_observe)})
403+
return [@varname(s), @varname(m)]
404+
end
405+
380406
@model function demo_assume_literal_dot_observe()
381407
# `assume` and literal `dot_observe`
382408
s ~ InverseGamma(2, 3)
@@ -575,7 +601,8 @@ const DemoModels = Union{
575601
Model{typeof(demo_dot_assume_observe_index)},
576602
Model{typeof(demo_assume_dot_observe)},
577603
Model{typeof(demo_assume_literal_dot_observe)},
578-
Model{typeof(demo_assume_observe_literal)},
604+
Model{typeof(demo_assume_literal_observe)},
605+
Model{typeof(demo_assume_multivariate_observe_literal)},
579606
Model{typeof(demo_dot_assume_observe_index_literal)},
580607
Model{typeof(demo_assume_submodel_observe_index_literal)},
581608
Model{typeof(demo_dot_assume_observe_submodel)},
@@ -585,7 +612,9 @@ const DemoModels = Union{
585612
}
586613

587614
const UnivariateAssumeDemoModels = Union{
588-
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
615+
Model{typeof(demo_assume_dot_observe)},
616+
Model{typeof(demo_assume_literal_dot_observe)}
617+
Model{typeof(demo_assume_literal_observe)}
589618
}
590619
function posterior_mean(model::UnivariateAssumeDemoModels)
591620
return (s=49 / 24, m=7 / 6)
@@ -609,7 +638,7 @@ const MultivariateAssumeDemoModels = Union{
609638
Model{typeof(demo_assume_index_observe)},
610639
Model{typeof(demo_assume_multivariate_observe)},
611640
Model{typeof(demo_dot_assume_observe_index)},
612-
Model{typeof(demo_assume_observe_literal)},
641+
Model{typeof(demo_assume_multivariate_observe_literal)},
613642
Model{typeof(demo_dot_assume_observe_index_literal)},
614643
Model{typeof(demo_assume_submodel_observe_index_literal)},
615644
Model{typeof(demo_dot_assume_observe_submodel)},
@@ -759,9 +788,10 @@ const DEMO_MODELS = (
759788
demo_assume_multivariate_observe(),
760789
demo_dot_assume_observe_index(),
761790
demo_assume_dot_observe(),
762-
demo_assume_observe_literal(),
791+
demo_assume_multivariate_observe_literal(),
763792
demo_dot_assume_observe_index_literal(),
764793
demo_assume_literal_dot_observe(),
794+
demo_assume_literal_observe(),
765795
demo_assume_submodel_observe_index_literal(),
766796
demo_dot_assume_observe_submodel(),
767797
demo_dot_assume_dot_observe_matrix(),

0 commit comments

Comments
 (0)