-
Notifications
You must be signed in to change notification settings - Fork 36
Depreciate@submodel l ~ m
in favour of l ~ to_submodel(m)
; rename generated_quantities
to returned
#696
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Depreciate@submodel l ~ m
in favour of l ~ to_submodel(m)
; rename generated_quantities
to returned
#696
Changes from 3 commits
5c746c4
0b081b7
dc699a5
7067695
8cb0796
2d887c9
692cfff
32fd6b9
5478fb3
5fe65b3
9e0730f
cc3af46
720053a
fe0403f
55b95a1
34fb6bd
9a7e18f
7aef65b
5ee727b
d92141c
64b519d
1b48f65
db2102c
da95aba
c8d567f
d477137
4896793
946fa6d
bf35de4
0f20624
99d99b3
0597b2a
0c6bada
5134ff7
45451f7
c00a9ae
f0af1d5
1b231a9
1faa627
92ac6b9
f73d1b0
b7b2e1d
ed4bb76
36f02f6
98538c5
d316306
0e05901
f073b25
2ec03c1
1f70dfc
f645259
23355ea
0e82a60
b9017c4
6e149a3
933e4ed
4fc7b76
b421687
c71242f
5c289c5
3c204d9
044f6c3
f716296
76aebc5
c150a87
b95e7d5
ecb4737
1e238ca
3525765
ed0cec3
13a2bf7
f94e07a
86d0e4c
d03eb4c
7c7ecc3
341b6b8
b467c75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -248,3 +248,185 @@ | |
end | ||
end | ||
end | ||
|
||
""" | ||
@returned_quantities [prefix=...] model | ||
|
||
Run `model` nested inside of another model and return the return-values of the `model`. | ||
|
||
Valid expressions for `prefix=...` are: | ||
- `prefix=false`: no prefix is used. This is the default. | ||
- `prefix=expression`: results in the prefix `Symbol(expression)`. | ||
|
||
Prefixing makes it possible to run the same model multiple times while keeping track of | ||
all random variables correctly, i.e. without name clashes. | ||
|
||
# Examples | ||
|
||
## Simple example | ||
```jldoctest submodel-returned-quantities; setup=:(using Distributions) | ||
julia> @model function demo1(x) | ||
x ~ Normal() | ||
return 1 + abs(x) | ||
end; | ||
|
||
julia> @model function demo2(x, y) | ||
a = @returned_quantities(demo1(x)) | ||
return y ~ Uniform(0, a) | ||
end; | ||
``` | ||
|
||
When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: | ||
```jldoctest submodel-returned-quantities | ||
julia> vi = VarInfo(demo2(missing, 0.4)); | ||
|
||
julia> @varname(x) in keys(vi) | ||
true | ||
``` | ||
|
||
Variable `a` is not tracked since it can be computed from the random variable `x` that was | ||
tracked when running `demo1`: | ||
```jldoctest submodel-returned-quantities | ||
julia> @varname(a) in keys(vi) | ||
false | ||
``` | ||
|
||
We can check that the log joint probability of the model accumulated in `vi` is correct: | ||
|
||
```jldoctest submodel-returned-quantities | ||
julia> x = vi[@varname(x)]; | ||
|
||
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) | ||
true | ||
``` | ||
|
||
## With prefixing | ||
```jldoctest submodel-returned-quantities-prefix; setup=:(using Distributions) | ||
julia> @model function demo1(x) | ||
x ~ Normal() | ||
return 1 + abs(x) | ||
end; | ||
|
||
julia> @model function demo2(x, y, z) | ||
a = @returned_quantities prefix="sub1" demo1(x) | ||
b = @returned_quantities prefix="sub2" demo1(y) | ||
return z ~ Uniform(-a, b) | ||
end; | ||
``` | ||
|
||
When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and | ||
`sub2.x` will be sampled: | ||
```jldoctest submodel-returned-quantities-prefix | ||
julia> vi = VarInfo(demo2(missing, missing, 0.4)); | ||
|
||
julia> @varname(var"sub1.x") in keys(vi) | ||
true | ||
|
||
julia> @varname(var"sub2.x") in keys(vi) | ||
true | ||
``` | ||
|
||
Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and | ||
`sub2.x` that were tracked when running `demo1`: | ||
```jldoctest submodel-returned-quantities-prefix | ||
julia> @varname(a) in keys(vi) | ||
false | ||
|
||
julia> @varname(b) in keys(vi) | ||
false | ||
``` | ||
|
||
We can check that the log joint probability of the model accumulated in `vi` is correct: | ||
|
||
```jldoctest submodel-returned-quantities-prefix | ||
julia> sub1_x = vi[@varname(var"sub1.x")]; | ||
|
||
julia> sub2_x = vi[@varname(var"sub2.x")]; | ||
|
||
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); | ||
|
||
julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); | ||
|
||
julia> getlogp(vi) ≈ logprior + loglikelihood | ||
true | ||
``` | ||
|
||
## Different ways of setting the prefix | ||
```jldoctest submodel-returned-quantities-prefix-alts; setup=:(using DynamicPPL, Distributions) | ||
julia> @model inner() = x ~ Normal() | ||
inner (generic function with 2 methods) | ||
|
||
julia> # When `prefix` is unspecified, no prefix is used. | ||
@model submodel_noprefix() = a = @returned_quantities inner() | ||
submodel_noprefix (generic function with 2 methods) | ||
|
||
julia> @varname(x) in keys(VarInfo(submodel_noprefix())) | ||
true | ||
|
||
julia> # Explicitely don't use any prefix. | ||
@model submodel_prefix_false() = a = @returned_quantities prefix=false inner() | ||
submodel_prefix_false (generic function with 2 methods) | ||
|
||
julia> @varname(x) in keys(VarInfo(submodel_prefix_false())) | ||
true | ||
|
||
julia> # Using a static string. | ||
@model submodel_prefix_string() = a = @returned_quantities prefix="my prefix" inner() | ||
submodel_prefix_string (generic function with 2 methods) | ||
|
||
julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) | ||
true | ||
|
||
julia> # Using string interpolation. | ||
@model submodel_prefix_interpolation() = a = @returned_quantities prefix="\$(nameof(inner()))" inner() | ||
submodel_prefix_interpolation (generic function with 2 methods) | ||
|
||
julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) | ||
true | ||
|
||
julia> # Or using some arbitrary expression. | ||
@model submodel_prefix_expr() = a = @returned_quantities prefix=1 + 2 inner() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found
hard and unintuitive to parse. I think
would be much clearer. Not sure if this a documentation issue, or if we should disallow the former. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a documentation issue IMO, as this is not doing any special parsing but reliying on Julia's expression parsing. |
||
submodel_prefix_expr (generic function with 2 methods) | ||
|
||
julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) | ||
true | ||
``` | ||
""" | ||
macro returned_quantities(expr) | ||
return returned_quantities_expr(:(prefix = false), expr) | ||
end | ||
|
||
macro returned_quantities(prefix_expr, expr) | ||
return returned_quantities_expr(prefix_expr, expr) | ||
end | ||
|
||
""" | ||
@returned_quantities_expr model | ||
|
||
Returns an expression that captures the return-values of a model in addition to the varinfo. | ||
""" | ||
function returned_quantities_expr(prefix_expr, expr, ctx=esc(:__context__)) | ||
mhauru marked this conversation as resolved.
Show resolved
Hide resolved
|
||
prefix_left, prefix = getargs_assignment(prefix_expr) | ||
if prefix_left !== :prefix | ||
error("$(prefix_left) is not a valid kwarg") | ||
end | ||
|
||
# The user expects `@submodel ...` to return the | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# return-value of the `...`, hence we need to capture | ||
# the return-value and handle it correctly. | ||
@gensym retval | ||
|
||
# Prefix. | ||
if prefix !== nothing | ||
ctx = prefix_submodel_context(prefix, ctx) | ||
end | ||
return quote | ||
# Evaluate the model and capture the return values + varinfo. | ||
$retval, $(esc(:__varinfo__)) = $(_evaluate!!)( | ||
$(esc(expr)), $(esc(:__varinfo__)), $(ctx) | ||
) | ||
|
||
# Return the return-value of the model. | ||
$retval | ||
end | ||
end |
Uh oh!
There was an error while loading. Please reload this page.