|
1 | 1 | const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) |
2 | 2 |
|
3 | | -struct SampleableModelWrapper{M} |
4 | | - model::M |
5 | | -end |
6 | | - |
7 | | -""" |
8 | | - to_sampleable(model::Model) |
9 | | -
|
10 | | -Return a wrapper around `model` which indicates that this model can only be sampled from. |
11 | | -
|
12 | | -This is mainly meant to be used on the right-hand side of a `~` operator to indicate that |
13 | | -the model can be sampled from but not necessarily evaluated for its log density. |
14 | | -
|
15 | | -!!! warning |
16 | | - Note that other operations that one typically associate with expressions of the form `left ~ right` |
17 | | - such as [`condition`](@ref) or [`fix`](@ref), will also not work with `to_sampleable`. |
18 | | -
|
19 | | -!!! warning |
20 | | - It's generally recommended to use [`prefix(::Model, input)`](@ref) when working with submodels |
21 | | - to ensure that the variables in `model` are unique and do not clash with other variables in the |
22 | | - parent model or in other submodels. |
23 | | -
|
24 | | -# Examples |
25 | | -
|
26 | | -## Simple example |
27 | | -```jldoctest submodel-to-sampleable; setup=:(using Distributions) |
28 | | -julia> @model function demo1(x) |
29 | | - x ~ Normal() |
30 | | - return 1 + abs(x) |
31 | | - end; |
32 | | -
|
33 | | -julia> @model function demo2(x, y) |
34 | | - a ~ to_sampleable(demo1(x)) |
35 | | - return y ~ Uniform(0, a) |
36 | | - end; |
37 | | -``` |
38 | | -
|
39 | | -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: |
40 | | -```jldoctest submodel-to-sampleable |
41 | | -julia> vi = VarInfo(demo2(missing, 0.4)); |
42 | | -
|
43 | | -julia> @varname(x) in keys(vi) |
44 | | -true |
45 | | -``` |
46 | | -
|
47 | | -Variable `a` is not tracked since it can be computed from the random variable `x` that was |
48 | | -tracked when running `demo1`: |
49 | | -```jldoctest submodel-to-sampleable |
50 | | -julia> @varname(a) in keys(vi) |
51 | | -false |
52 | | -``` |
53 | | -
|
54 | | -We can check that the log joint probability of the model accumulated in `vi` is correct: |
55 | | -
|
56 | | -```jldoctest submodel-to-sampleable |
57 | | -julia> x = vi[@varname(x)]; |
58 | | -
|
59 | | -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) |
60 | | -true |
61 | | -``` |
62 | | -
|
63 | | -## With prefixing |
64 | | -```jldoctest submodel-to-sampleable-prefix; setup=:(using Distributions) |
65 | | -julia> @model function demo1(x) |
66 | | - x ~ Normal() |
67 | | - return 1 + abs(x) |
68 | | - end; |
69 | | -
|
70 | | -julia> @model function demo2(x, y, z) |
71 | | - a ~ to_sampleable(prefix(demo1(x), :sub1)) |
72 | | - b ~ to_sampleable(prefix(demo1(y), :sub2)) |
73 | | - return z ~ Uniform(-a, b) |
74 | | - end; |
75 | | -``` |
76 | | -
|
77 | | -When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and |
78 | | -`sub2.x` will be sampled: |
79 | | -```jldoctest submodel-to-sampleable-prefix |
80 | | -julia> vi = VarInfo(demo2(missing, missing, 0.4)); |
81 | | -
|
82 | | -julia> @varname(var"sub1.x") in keys(vi) |
83 | | -true |
84 | | -
|
85 | | -julia> @varname(var"sub2.x") in keys(vi) |
86 | | -true |
87 | | -``` |
88 | | -
|
89 | | -Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and |
90 | | -`sub2.x` that were tracked when running `demo1`: |
91 | | -```jldoctest submodel-to-sampleable-prefix |
92 | | -julia> @varname(a) in keys(vi) |
93 | | -false |
94 | | -
|
95 | | -julia> @varname(b) in keys(vi) |
96 | | -false |
97 | | -``` |
98 | | -
|
99 | | -We can check that the log joint probability of the model accumulated in `vi` is correct: |
100 | | -
|
101 | | -```jldoctest submodel-to-sampleable-prefix |
102 | | -julia> sub1_x = vi[@varname(var"sub1.x")]; |
103 | | -
|
104 | | -julia> sub2_x = vi[@varname(var"sub2.x")]; |
105 | | -
|
106 | | -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); |
107 | | -
|
108 | | -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); |
109 | | -
|
110 | | -julia> getlogp(vi) ≈ logprior + loglikelihood |
111 | | -true |
112 | | -``` |
113 | | -
|
114 | | -## Different ways of setting the prefix |
115 | | -```jldoctest submodel-to-sampleable-prefix-alts; setup=:(using DynamicPPL, Distributions) |
116 | | -julia> @model inner() = x ~ Normal() |
117 | | -inner (generic function with 2 methods) |
118 | | -
|
119 | | -julia> # When `prefix` is unspecified, no prefix is used. |
120 | | - @model submodel_noprefix() = a ~ to_sampleable(inner()) |
121 | | -submodel_noprefix (generic function with 2 methods) |
122 | | -
|
123 | | -julia> @varname(x) in keys(VarInfo(submodel_noprefix())) |
124 | | -true |
125 | | -
|
126 | | -julia> # Using a static string. |
127 | | - @model submodel_prefix_string() = a ~ to_sampleable(prefix(inner(), "my prefix")) |
128 | | -submodel_prefix_string (generic function with 2 methods) |
129 | | -
|
130 | | -julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) |
131 | | -true |
132 | | -
|
133 | | -julia> # Using string interpolation. |
134 | | - @model submodel_prefix_interpolation() = a = to_sampleable(prefix(inner(), "\$(nameof(inner()))")) |
135 | | -submodel_prefix_interpolation (generic function with 2 methods) |
136 | | -
|
137 | | -julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) |
138 | | -true |
139 | | -
|
140 | | -julia> # Or using some arbitrary expression. |
141 | | - @model submodel_prefix_expr() = a ~ to_sampleable(prefix(inner(), 1 + 2)) |
142 | | -submodel_prefix_expr (generic function with 2 methods) |
143 | | -
|
144 | | -julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) |
145 | | -true |
146 | | -``` |
147 | | -""" |
148 | | -to_sampleable(model::Model) = SampleableModelWrapper(model) |
149 | | - |
150 | 3 | """ |
151 | 4 | need_concretize(expr) |
152 | 5 |
|
@@ -325,6 +178,7 @@ function check_tilde_rhs(@nospecialize(x)) |
325 | 178 | end |
326 | 179 | check_tilde_rhs(x::Distribution) = x |
327 | 180 | check_tilde_rhs(x::AbstractArray{<:Distribution}) = x |
| 181 | +check_tilde_rhs(x::ReturnedModelWrapper) = x |
328 | 182 |
|
329 | 183 | """ |
330 | 184 | unwrap_right_vn(right, vn) |
@@ -574,34 +428,28 @@ function generate_tilde(left, right) |
574 | 428 | # more selective with our escape. Until that's the case, we remove them all. |
575 | 429 | return quote |
576 | 430 | $dist = $right |
577 | | - |
578 | | - if $dist isa $(SampleableModelWrapper) |
579 | | - $left, __varinfo__ = $(_evaluate!!)($dist.model, __varinfo__, __context__) |
580 | | - $left |
| 431 | + $vn = $(DynamicPPL.resolve_varnames)( |
| 432 | + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist |
| 433 | + ) |
| 434 | + $isassumption = $(DynamicPPL.isassumption(left, vn)) |
| 435 | + if $(DynamicPPL.isfixed(left, vn)) |
| 436 | + $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) |
| 437 | + elseif $isassumption |
| 438 | + $(generate_tilde_assume(left, dist, vn)) |
581 | 439 | else |
582 | | - $vn = $(DynamicPPL.resolve_varnames)( |
583 | | - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist |
584 | | - ) |
585 | | - $isassumption = $(DynamicPPL.isassumption(left, vn)) |
586 | | - if $(DynamicPPL.isfixed(left, vn)) |
587 | | - $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) |
588 | | - elseif $isassumption |
589 | | - $(generate_tilde_assume(left, dist, vn)) |
590 | | - else |
591 | | - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. |
592 | | - if !$(DynamicPPL.inargnames)($vn, __model__) |
593 | | - $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) |
594 | | - end |
595 | | - |
596 | | - $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( |
597 | | - __context__, |
598 | | - $(DynamicPPL.check_tilde_rhs)($dist), |
599 | | - $(maybe_view(left)), |
600 | | - $vn, |
601 | | - __varinfo__, |
602 | | - ) |
603 | | - $value |
| 440 | + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. |
| 441 | + if !$(DynamicPPL.inargnames)($vn, __model__) |
| 442 | + $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) |
604 | 443 | end |
| 444 | + |
| 445 | + $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( |
| 446 | + __context__, |
| 447 | + $(DynamicPPL.check_tilde_rhs)($dist), |
| 448 | + $(maybe_view(left)), |
| 449 | + $vn, |
| 450 | + __varinfo__, |
| 451 | + ) |
| 452 | + $value |
605 | 453 | end |
606 | 454 | end |
607 | 455 | end |
|
0 commit comments