Skip to content

Commit 45451f7

Browse files
committed
removed redundant SampleableModelWrapper in favour of
`ReturnedModelWrapper` + introduced `rand_like!!` to hide explicit calls to `_evaluate!!`
1 parent 5134ff7 commit 45451f7

File tree

3 files changed

+231
-177
lines changed

3 files changed

+231
-177
lines changed

src/compiler.jl

Lines changed: 21 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -1,152 +1,5 @@
11
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
22

3-
struct SampleableModelWrapper{M}
4-
model::M
5-
end
6-
7-
"""
8-
to_sampleable(model::Model)
9-
10-
Return a wrapper around `model` which indicates that this model can only be sampled from.
11-
12-
This is mainly meant to be used on the right-hand side of a `~` operator to indicate that
13-
the model can be sampled from but not necessarily evaluated for its log density.
14-
15-
!!! warning
16-
Note that other operations that one typically associate with expressions of the form `left ~ right`
17-
such as [`condition`](@ref) or [`fix`](@ref), will also not work with `to_sampleable`.
18-
19-
!!! warning
20-
It's generally recommended to use [`prefix(::Model, input)`](@ref) when working with submodels
21-
to ensure that the variables in `model` are unique and do not clash with other variables in the
22-
parent model or in other submodels.
23-
24-
# Examples
25-
26-
## Simple example
27-
```jldoctest submodel-to-sampleable; setup=:(using Distributions)
28-
julia> @model function demo1(x)
29-
x ~ Normal()
30-
return 1 + abs(x)
31-
end;
32-
33-
julia> @model function demo2(x, y)
34-
a ~ to_sampleable(demo1(x))
35-
return y ~ Uniform(0, a)
36-
end;
37-
```
38-
39-
When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled:
40-
```jldoctest submodel-to-sampleable
41-
julia> vi = VarInfo(demo2(missing, 0.4));
42-
43-
julia> @varname(x) in keys(vi)
44-
true
45-
```
46-
47-
Variable `a` is not tracked since it can be computed from the random variable `x` that was
48-
tracked when running `demo1`:
49-
```jldoctest submodel-to-sampleable
50-
julia> @varname(a) in keys(vi)
51-
false
52-
```
53-
54-
We can check that the log joint probability of the model accumulated in `vi` is correct:
55-
56-
```jldoctest submodel-to-sampleable
57-
julia> x = vi[@varname(x)];
58-
59-
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
60-
true
61-
```
62-
63-
## With prefixing
64-
```jldoctest submodel-to-sampleable-prefix; setup=:(using Distributions)
65-
julia> @model function demo1(x)
66-
x ~ Normal()
67-
return 1 + abs(x)
68-
end;
69-
70-
julia> @model function demo2(x, y, z)
71-
a ~ to_sampleable(prefix(demo1(x), :sub1))
72-
b ~ to_sampleable(prefix(demo1(y), :sub2))
73-
return z ~ Uniform(-a, b)
74-
end;
75-
```
76-
77-
When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and
78-
`sub2.x` will be sampled:
79-
```jldoctest submodel-to-sampleable-prefix
80-
julia> vi = VarInfo(demo2(missing, missing, 0.4));
81-
82-
julia> @varname(var"sub1.x") in keys(vi)
83-
true
84-
85-
julia> @varname(var"sub2.x") in keys(vi)
86-
true
87-
```
88-
89-
Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and
90-
`sub2.x` that were tracked when running `demo1`:
91-
```jldoctest submodel-to-sampleable-prefix
92-
julia> @varname(a) in keys(vi)
93-
false
94-
95-
julia> @varname(b) in keys(vi)
96-
false
97-
```
98-
99-
We can check that the log joint probability of the model accumulated in `vi` is correct:
100-
101-
```jldoctest submodel-to-sampleable-prefix
102-
julia> sub1_x = vi[@varname(var"sub1.x")];
103-
104-
julia> sub2_x = vi[@varname(var"sub2.x")];
105-
106-
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
107-
108-
julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);
109-
110-
julia> getlogp(vi) ≈ logprior + loglikelihood
111-
true
112-
```
113-
114-
## Different ways of setting the prefix
115-
```jldoctest submodel-to-sampleable-prefix-alts; setup=:(using DynamicPPL, Distributions)
116-
julia> @model inner() = x ~ Normal()
117-
inner (generic function with 2 methods)
118-
119-
julia> # When `prefix` is unspecified, no prefix is used.
120-
@model submodel_noprefix() = a ~ to_sampleable(inner())
121-
submodel_noprefix (generic function with 2 methods)
122-
123-
julia> @varname(x) in keys(VarInfo(submodel_noprefix()))
124-
true
125-
126-
julia> # Using a static string.
127-
@model submodel_prefix_string() = a ~ to_sampleable(prefix(inner(), "my prefix"))
128-
submodel_prefix_string (generic function with 2 methods)
129-
130-
julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string()))
131-
true
132-
133-
julia> # Using string interpolation.
134-
@model submodel_prefix_interpolation() = a = to_sampleable(prefix(inner(), "\$(nameof(inner()))"))
135-
submodel_prefix_interpolation (generic function with 2 methods)
136-
137-
julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation()))
138-
true
139-
140-
julia> # Or using some arbitrary expression.
141-
@model submodel_prefix_expr() = a ~ to_sampleable(prefix(inner(), 1 + 2))
142-
submodel_prefix_expr (generic function with 2 methods)
143-
144-
julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr()))
145-
true
146-
```
147-
"""
148-
to_sampleable(model::Model) = SampleableModelWrapper(model)
149-
1503
"""
1514
need_concretize(expr)
1525
@@ -325,6 +178,7 @@ function check_tilde_rhs(@nospecialize(x))
325178
end
326179
check_tilde_rhs(x::Distribution) = x
327180
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
181+
check_tilde_rhs(x::ReturnedModelWrapper) = x
328182

329183
"""
330184
unwrap_right_vn(right, vn)
@@ -574,34 +428,28 @@ function generate_tilde(left, right)
574428
# more selective with our escape. Until that's the case, we remove them all.
575429
return quote
576430
$dist = $right
577-
578-
if $dist isa $(SampleableModelWrapper)
579-
$left, __varinfo__ = $(_evaluate!!)($dist.model, __varinfo__, __context__)
580-
$left
431+
$vn = $(DynamicPPL.resolve_varnames)(
432+
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
433+
)
434+
$isassumption = $(DynamicPPL.isassumption(left, vn))
435+
if $(DynamicPPL.isfixed(left, vn))
436+
$left = $(DynamicPPL.getfixed_nested)(__context__, $vn)
437+
elseif $isassumption
438+
$(generate_tilde_assume(left, dist, vn))
581439
else
582-
$vn = $(DynamicPPL.resolve_varnames)(
583-
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
584-
)
585-
$isassumption = $(DynamicPPL.isassumption(left, vn))
586-
if $(DynamicPPL.isfixed(left, vn))
587-
$left = $(DynamicPPL.getfixed_nested)(__context__, $vn)
588-
elseif $isassumption
589-
$(generate_tilde_assume(left, dist, vn))
590-
else
591-
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
592-
if !$(DynamicPPL.inargnames)($vn, __model__)
593-
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
594-
end
595-
596-
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
597-
__context__,
598-
$(DynamicPPL.check_tilde_rhs)($dist),
599-
$(maybe_view(left)),
600-
$vn,
601-
__varinfo__,
602-
)
603-
$value
440+
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
441+
if !$(DynamicPPL.inargnames)($vn, __model__)
442+
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
604443
end
444+
445+
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
446+
__context__,
447+
$(DynamicPPL.check_tilde_rhs)($dist),
448+
$(maybe_view(left)),
449+
$vn,
450+
__varinfo__,
451+
)
452+
$value
605453
end
606454
end
607455
end

src/context_implementations.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,12 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
141141
probability of `vi` with the returned value.
142142
"""
143143
function tilde_assume!!(context, right, vn, vi)
144-
value, logp, vi = tilde_assume(context, right, vn, vi)
145-
return value, acclogp_assume!!(context, vi, logp)
144+
return if is_rhs_model(right)
145+
rand_like!!(right, context, vi)
146+
else
147+
value, logp, vi = tilde_assume(context, right, vn, vi)
148+
value, acclogp_assume!!(context, vi, logp)
149+
end
146150
end
147151

148152
# observe
@@ -197,6 +201,7 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati
197201
and indices; if needed, these can be accessed through this function, though.
198202
"""
199203
function tilde_observe!!(context, right, left, vname, vi)
204+
is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported"))
200205
return tilde_observe!!(context, right, left, vi)
201206
end
202207

@@ -210,6 +215,7 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the
210215
probability of `vi` with the returned value.
211216
"""
212217
function tilde_observe!!(context, right, left, vi)
218+
is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported"))
213219
logp, vi = tilde_observe(context, right, left, vi)
214220
return left, acclogp_observe!!(context, vi, logp)
215221
end
@@ -420,8 +426,12 @@ model inputs), accumulate the log probability, and return the sampled value and
420426
Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
421427
"""
422428
function dot_tilde_assume!!(context, right, left, vn, vi)
423-
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
424-
return value, acclogp_assume!!(context, vi, logp), vi
429+
return if is_rhs_model(right)
430+
rand_like!!(right, context, vi)
431+
else
432+
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
433+
value, acclogp_assume!!(context, vi, logp)
434+
end
425435
end
426436

427437
# `dot_assume`
@@ -672,6 +682,7 @@ Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the infor
672682
name and indices; if needed, these can be accessed through this function, though.
673683
"""
674684
function dot_tilde_observe!!(context, right, left, vn, vi)
685+
is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported"))
675686
return dot_tilde_observe!!(context, right, left, vi)
676687
end
677688

@@ -684,6 +695,7 @@ probability, and return the observed value and updated `vi`.
684695
Falls back to `dot_tilde_observe(context, right, left, vi)`.
685696
"""
686697
function dot_tilde_observe!!(context, right, left, vi)
698+
is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported"))
687699
logp, vi = dot_tilde_observe(context, right, left, vi)
688700
return left, acclogp_observe!!(context, vi, logp)
689701
end

0 commit comments

Comments
 (0)