Skip to content

Commit b25fdab

Browse files
torfjeldedevmotion
andcommitted
Small simplification of compiler (#221)
## Overview At the moment, we perform a check at model-expansion as to whether or not `vsym(left) in args`, where `args` is the arguments of the model. 1. If `true`, we return a block of code which uses `DynamicPPL.isassumption` to check whether or not to call `assume` or `observe` for the the variable present in `args`. 2. Otherwise, we generate a block which is identical to the `assume` block in the if-statement mentioned in (1). The thing is, `DynamicPPL.isassumption` performs exactly the same check as above but using `DynamicPPL.inargnames`, i.e. at runtime. So if we're using `TypedVarInfo`, the check at macro-expansion vs. at runtime is completely redundant since all the information necessary to determine `DynamicPPL.inargnames` is available at compile-time. Therefore I suggest we remove this check at model-expansion, and simply handle it using `DynamicPPL.isassumption`. ## Pros & cons Pros: - No need to pass `args` around everywhere - `generate_tilde` and `generate_dot_tilde` are much simpler: two possible blocks we can generate, either a) assume/observe, or b) observe literal. Cons: - We need to perform _one_ more check at runtime when using `UntypedVarInfo`. **IMO, this is really worth it.** ## Motivation (sort of) The main motivation behind this PR is simplification, but there's a different reason why I came across this. I came to this because I was thinking about trying to "customize" the behavior of `~`, and I was thinking of using a macro to do it, e.g. `@mymacro x ~ Normal()`. Atm we're actually performing model-expansion on the code passed to the macro and thus trying to alter the way DynamicPPL treats `~` using a macro is veeeery difficult since you actually have to work with the *expanded* code, but let's ignore that issue for now (and take that discussion somewhere else, because IMO we shouldn't do this). Suppose we didn't perform model-expansions of the code fed to the macros, then you can just copy-paste `generate_tilde`, customize it do what you want, and BAM, you got yourself a working `@mymacro x ~ Normal()` which can do neat stuff! This is *not* possible atm because we don't have access to `args`, and so you have to take the approach in this PR to get there. That means that it's of course possible to do atm, but it's a bit icky since it ends up looking fundamentally different from `generate_tilde` rather than just slightly different. Then we can implement things like a `@tilde` which will expand to `generate_tilde` which can be used *internally* in functions (if the "internal" variables are present in the functions of course, but we can also simplify this in different ways), actually allowing people to modularize their models a bit, and `@reparam` from #220 using very similar pieces of code, a `@track` macro can be introduced to deal with the explicit tracking of variables rather than putting this directly in the compiler, etc. Endless opportunities! (Of course, I'm not suggesting we add these, but this makes it a bit easier to explore.) Co-authored-by: David Widmann <[email protected]>
1 parent 068e5d3 commit b25fdab

File tree

3 files changed

+40
-58
lines changed

3 files changed

+40
-58
lines changed

.github/workflows/IntegrationTest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- uses: actions/checkout@v2
2525
- uses: julia-actions/setup-julia@v1
2626
with:
27-
version: 1
27+
version: 1.5
2828
arch: x64
2929
- uses: julia-actions/julia-buildpkg@latest
3030
- name: Clone Downstream

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.11"
3+
version = "0.10.12"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 38 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function model(mod, linenumbernode, expr, warn)
7272

7373
# Generate main body
7474
modelinfo[:body] = generate_mainbody(
75-
mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
75+
mod, modelinfo[:modeldef][:body], warn
7676
)
7777

7878
return build_output(modelinfo, linenumbernode)
@@ -155,92 +155,84 @@ function build_model_info(input_expr)
155155
end
156156

157157
"""
158-
generate_mainbody(mod, expr, args, warn)
158+
generate_mainbody(mod, expr, warn)
159159
160160
Generate the body of the main evaluation function from expression `expr` and arguments
161161
`args`.
162162
163163
If `warn` is true, a warning is displayed if internal variables are used in the model
164164
definition.
165165
"""
166-
generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn)
166+
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
167167

168-
generate_mainbody!(mod, found, x, args, warn) = x
169-
function generate_mainbody!(mod, found, sym::Symbol, args, warn)
168+
generate_mainbody!(mod, found, x, warn) = x
169+
function generate_mainbody!(mod, found, sym::Symbol, warn)
170170
if warn && sym in INTERNALNAMES && sym found
171171
@warn "you are using the internal variable `$(sym)`"
172172
push!(found, sym)
173173
end
174174
return sym
175175
end
176-
function generate_mainbody!(mod, found, expr::Expr, args, warn)
176+
function generate_mainbody!(mod, found, expr::Expr, warn)
177177
# Do not touch interpolated expressions
178178
expr.head === :$ && return expr.args[1]
179179

180180
# If it's a macro, we expand it
181181
if Meta.isexpr(expr, :macrocall)
182-
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn)
182+
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
183183
end
184184

185185
# Modify dotted tilde operators.
186186
args_dottilde = getargs_dottilde(expr)
187187
if args_dottilde !== nothing
188188
L, R = args_dottilde
189-
return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn),
190-
generate_mainbody!(mod, found, R, args, warn),
191-
args) |> Base.remove_linenums!
189+
return generate_dot_tilde(
190+
generate_mainbody!(mod, found, L, warn),
191+
generate_mainbody!(mod, found, R, warn),
192+
) |> Base.remove_linenums!
192193
end
193194

194195
# Modify tilde operators.
195196
args_tilde = getargs_tilde(expr)
196197
if args_tilde !== nothing
197198
L, R = args_tilde
198-
return generate_tilde(generate_mainbody!(mod, found, L, args, warn),
199-
generate_mainbody!(mod, found, R, args, warn),
200-
args) |> Base.remove_linenums!
199+
return generate_tilde(
200+
generate_mainbody!(mod, found, L, warn),
201+
generate_mainbody!(mod, found, R, warn),
202+
) |> Base.remove_linenums!
201203
end
202204

203-
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...)
205+
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
204206
end
205207

206208

207209

208210
"""
209-
generate_tilde(left, right, args)
211+
generate_tilde(left, right)
210212
211213
Generate an `observe` expression for data variables and `assume` expression for parameter
212214
variables.
213215
"""
214-
function generate_tilde(left, right, args)
216+
function generate_tilde(left, right)
215217
@gensym tmpright
216218
top = [:($tmpright = $right),
217219
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
218220
|| throw(ArgumentError($DISTMSG)))]
219221

220222
if left isa Symbol || left isa Expr
221-
@gensym out vn inds
223+
@gensym out vn inds isassumption
222224
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
223225

224-
# It can only be an observation if the LHS is an argument of the model
225-
if vsym(left) in args
226-
@gensym isassumption
227-
return quote
228-
$(top...)
229-
$isassumption = $(DynamicPPL.isassumption(left))
230-
if $isassumption
231-
$left = $(DynamicPPL.tilde_assume)(
232-
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
233-
else
234-
$(DynamicPPL.tilde_observe)(
235-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
236-
end
237-
end
238-
end
239-
240226
return quote
241227
$(top...)
242-
$left = $(DynamicPPL.tilde_assume)(_rng, _context, _sampler, $tmpright, $vn,
243-
$inds, _varinfo)
228+
$isassumption = $(DynamicPPL.isassumption(left))
229+
if $isassumption
230+
$left = $(DynamicPPL.tilde_assume)(
231+
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
232+
else
233+
$(DynamicPPL.tilde_observe)(
234+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
235+
end
244236
end
245237
end
246238

@@ -252,40 +244,30 @@ function generate_tilde(left, right, args)
252244
end
253245

254246
"""
255-
generate_dot_tilde(left, right, args)
247+
generate_dot_tilde(left, right)
256248
257249
Generate the expression that replaces `left .~ right` in the model body.
258250
"""
259-
function generate_dot_tilde(left, right, args)
251+
function generate_dot_tilde(left, right)
260252
@gensym tmpright
261253
top = [:($tmpright = $right),
262254
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
263255
|| throw(ArgumentError($DISTMSG)))]
264256

265257
if left isa Symbol || left isa Expr
266-
@gensym out vn inds
258+
@gensym out vn inds isassumption
267259
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
268260

269-
# It can only be an observation if the LHS is an argument of the model
270-
if vsym(left) in args
271-
@gensym isassumption
272-
return quote
273-
$(top...)
274-
$isassumption = $(DynamicPPL.isassumption(left))
275-
if $isassumption
276-
$left .= $(DynamicPPL.dot_tilde_assume)(
277-
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
278-
else
279-
$(DynamicPPL.dot_tilde_observe)(
280-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
281-
end
282-
end
283-
end
284-
285261
return quote
286262
$(top...)
287-
$left .= $(DynamicPPL.dot_tilde_assume)(
288-
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
263+
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
264+
if $isassumption
265+
$left .= $(DynamicPPL.dot_tilde_assume)(
266+
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
267+
else
268+
$(DynamicPPL.dot_tilde_observe)(
269+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
270+
end
289271
end
290272
end
291273

0 commit comments

Comments
 (0)