Skip to content

Commit 7879726

Browse files
authored
Merge pull request #53 from phipsgabler/phg/split_tilde
Split up tilde functions based on usage
2 parents 312e7da + 63ca3f4 commit 7879726

File tree

4 files changed

+133
-52
lines changed

4 files changed

+133
-52
lines changed

src/compiler.jl

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,45 +27,44 @@ function wrong_dist_errormsg(l)
2727
end
2828

2929
"""
30-
@preprocess(data_vars, missing_vars, ex)
30+
@isassumption(model, expr)
3131
32-
Let `ex` be `x[1]`. This macro returns `@varname x[1]` in any of the following cases:
32+
Let `expr` be `x[1]`. `vn` is an assumption in the following cases:
3333
1. `x` was not among the input data to the model,
3434
2. `x` was among the input data to the model but with a value `missing`, or
3535
3. `x` was among the input data to the model with a value other than missing,
36-
but `x[1] === missing`.
37-
Otherwise, the value of `x[1]` is returned.
36+
but `x[1] === missing`.
37+
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
3838
"""
39-
macro preprocess(data_vars, missing_vars, ex)
40-
ex
41-
end
42-
macro preprocess(model, ex::Union{Symbol, Expr})
43-
sym = gensym(:sym)
44-
lhs = gensym(:lhs)
45-
return esc(quote
46-
# Extract symbol
47-
$sym = Val($(vsym(ex)))
39+
macro isassumption(model, expr::Union{Symbol, Expr})
40+
# Note: never put a return in this... don't forget it's a macro!
41+
vn = gensym(:vn)
42+
43+
return quote
44+
$vn = @varname($expr)
45+
4846
# This branch should compile nicely in all cases except for partial missing data
49-
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
50-
if !DynamicPPL.inargnames($sym, $model) || DynamicPPL.inmissings($sym, $model)
51-
$(varname(ex)), $(vinds(ex))
47+
# For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48+
if !DynamicPPL.inargnames($vn, $model) || DynamicPPL.inmissings($vn, $model)
49+
true
5250
else
53-
if DynamicPPL.inargnames($sym, $model)
51+
if DynamicPPL.inargnames($vn, $model)
5452
# Evaluate the lhs
55-
$lhs = $ex
56-
if $lhs === missing
57-
$(varname(ex)), $(vinds(ex))
58-
else
59-
$lhs
60-
end
53+
$expr === missing
6154
else
6255
throw("This point should not be reached. Please report this error.")
6356
end
6457
end
65-
end)
58+
end |> esc
59+
end
60+
61+
macro isassumption(model, expr)
62+
# failsafe: a literal is never an assumption
63+
false
6664
end
6765

6866

67+
6968
#################
7069
# Main Compiler #
7170
#################
@@ -300,32 +299,36 @@ function generate_tilde(left, right, model_info)
300299
lp = gensym(:lp)
301300
vn = gensym(:vn)
302301
inds = gensym(:inds)
303-
preprocessed = gensym(:preprocessed)
302+
isassumption = gensym(:isassumption)
304303
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
304+
305305
if left isa Symbol || left isa Expr
306306
ex = quote
307307
$temp_right = $right
308308
$assert_ex
309-
$preprocessed = DynamicPPL.@preprocess($model, $left)
310-
if $preprocessed isa Tuple
311-
$vn, $inds = $preprocessed
312-
$out = DynamicPPL.tilde($ctx, $sampler, $temp_right, $vn, $inds, $vi)
309+
310+
$vn, $inds = $(varname(left)), $(vinds(left))
311+
$isassumption = DynamicPPL.@isassumption($model, $left)
312+
if $isassumption
313+
$out = DynamicPPL.tilde_assume($ctx, $sampler, $temp_right, $vn, $inds, $vi)
313314
$left = $out[1]
314315
DynamicPPL.acclogp!($vi, $out[2])
315316
else
316317
DynamicPPL.acclogp!(
317318
$vi,
318-
DynamicPPL.tilde($ctx, $sampler, $temp_right, $preprocessed, $vi),
319+
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
319320
)
320321
end
321322
end
322323
else
324+
# we have a literal, which is automatically an observation
323325
ex = quote
324326
$temp_right = $right
325327
$assert_ex
328+
326329
DynamicPPL.acclogp!(
327330
$vi,
328-
DynamicPPL.tilde($ctx, $sampler, $temp_right, $left, $vi),
331+
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
329332
)
330333
end
331334
end
@@ -335,48 +338,51 @@ end
335338
"""
336339
generate_dot_tilde(left, right, model_info)
337340
338-
This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
341+
This function returns the expression that replaces `left .~ right` in the model body. If
342+
`preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
343+
will be run.
339344
"""
340345
function generate_dot_tilde(left, right, model_info)
341346
model = model_info[:main_body_names][:model]
342347
vi = model_info[:main_body_names][:vi]
343348
ctx = model_info[:main_body_names][:ctx]
344349
sampler = model_info[:main_body_names][:sampler]
345350
out = gensym(:out)
346-
temp_left = gensym(:temp_left)
347351
temp_right = gensym(:temp_right)
348-
preprocessed = gensym(:preprocessed)
352+
isassumption = gensym(:isassumption)
349353
lp = gensym(:lp)
350354
vn = gensym(:vn)
351355
inds = gensym(:inds)
352356
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
357+
353358
if left isa Symbol || left isa Expr
354359
ex = quote
355360
$temp_right = $right
356361
$assert_ex
357-
$preprocessed = DynamicPPL.@preprocess($model, $left)
358-
if $preprocessed isa Tuple
359-
$vn, $inds = $preprocessed
360-
$temp_left = $left
361-
$out = DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vn, $inds, $vi)
362+
363+
$vn, $inds = $(varname(left)), $(vinds(left))
364+
$isassumption = DynamicPPL.@isassumption($model, $left)
365+
366+
if $isassumption
367+
$out = DynamicPPL.dot_tilde_assume($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi)
362368
$left .= $out[1]
363369
DynamicPPL.acclogp!($vi, $out[2])
364370
else
365-
$temp_left = $preprocessed
366371
DynamicPPL.acclogp!(
367372
$vi,
368-
DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi),
373+
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
369374
)
370375
end
371376
end
372377
else
378+
# we have a literal, which is automatically an observation
373379
ex = quote
374-
$temp_left = $left
375380
$temp_right = $right
376381
$assert_ex
382+
377383
DynamicPPL.acclogp!(
378384
$vi,
379-
DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi),
385+
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
380386
)
381387
end
382388
end
@@ -416,7 +422,7 @@ function build_output(model_info)
416422
model_gen = model_info[:name]
417423
# Main body of the model
418424
main_body = model_info[:main_body]
419-
425+
420426
unwrap_data_expr = Expr(:block)
421427
for var in arg_syms
422428
temp_var = gensym(:temp_var)

src/context_implementations.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
3535
return tilde(ctx.ctx, sampler, right, left, inds, vi)
3636
end
3737

38+
"""
39+
tilde_assume(ctx, sampler, right, vn, inds, vi)
40+
41+
This method is applied in the generated code for assumed variables, e.g., `x ~ Normal()` where
42+
`x` does not occur in the model inputs.
43+
44+
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
45+
"""
46+
function tilde_assume(ctx, sampler, right, vn, inds, vi)
47+
return tilde(ctx, sampler, right, vn, inds, vi)
48+
end
49+
50+
3851
function _tilde(sampler, right, vn::VarName, vi)
3952
return assume(sampler, right, vn, vi)
4053
end
@@ -68,6 +81,30 @@ function tilde(ctx::MiniBatchContext, sampler, right, left, vi)
6881
return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi)
6982
end
7083

84+
"""
85+
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
86+
87+
This method is applied in the generated code for observed variables, e.g., `x ~ Normal()` where
88+
`x` does occur in the model inputs.
89+
90+
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
91+
name and indices; if needed, these can be accessed through this function, though.
92+
"""
93+
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
94+
return tilde(ctx, sampler, right, left, vi)
95+
end
96+
97+
"""
98+
tilde_observe(ctx, sampler, right, left, vi)
99+
100+
This method is applied in the generated code for observed constants, e.g., `1.0 ~ Normal()`.
101+
Falls back to `tilde(ctx, sampler, right, left, vi)`.
102+
"""
103+
function tilde_observe(ctx, sampler, right, left, vi)
104+
return tilde(ctx, sampler, right, left, vi)
105+
end
106+
107+
71108
_tilde(sampler, right, left, vi) = observe(sampler, right, left, vi)
72109

73110
function assume(spl::Sampler, dist)
@@ -163,6 +200,19 @@ function dot_tilde(
163200
return _dot_tilde(sampler, dist, left, vns, vi)
164201
end
165202

203+
"""
204+
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
205+
206+
This method is applied in the generated code for assumed vectorized variables, e.g., `x .~
207+
MvNormal()` where `x` does not occur in the model inputs.
208+
209+
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
210+
"""
211+
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
212+
return dot_tilde(ctx, sampler, right, left, vn, inds, vi)
213+
end
214+
215+
166216
function get_vns_and_dist(dist::NamedDist, var, vn::VarName)
167217
name = dist.name
168218
if name isa String
@@ -337,6 +387,30 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
337387
return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, left, vi)
338388
end
339389

390+
"""
391+
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
392+
393+
This method is applied in the generated code for vectorized observed variables, e.g., `x .~
394+
MvNormal()` where `x` does occur the model inputs.
395+
396+
Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
397+
name and indices; if needed, these can be accessed through this function, though.
398+
"""
399+
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
400+
return dot_tilde(ctx, sampler, right, left, vi)
401+
end
402+
403+
"""
404+
dot_tilde_observe(ctx, sampler, right, left, vi)
405+
406+
This method is applied in the generated code for vectorized observed constants, e.g., `[1.0] .~
407+
MvNormal()`. Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
408+
"""
409+
function dot_tilde_observe(ctx, sampler, right, left, vi)
410+
return dot_tilde(ctx, sampler, right, left, vi)
411+
end
412+
413+
340414
function _dot_tilde(sampler, right, left::AbstractArray, vi)
341415
return dot_observe(sampler, right, left, vi)
342416
end

src/model.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,6 @@ Get a tuple of the argument names of the `model`.
147147
"""
148148
getargnames(model::Model{_F, argnames}) where {argnames, _F} = argnames
149149

150-
@generated function inargnames(::Val{s}, ::Model{_F, argnames}) where {s, argnames, _F}
151-
return s in argnames
152-
end
153-
154150

155151
"""
156152
getmissings(model::Model)
@@ -162,10 +158,6 @@ getmissings(model::Model{_F, _a, _T, missings}) where {missings, _F, _a, _T} = m
162158
getmissing(model::Model) = getmissings(model)
163159
@deprecate getmissing(model) getmissings(model)
164160

165-
@generated function inmissings(::Val{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
166-
return s in missings
167-
end
168-
169161

170162
"""
171163
getgenerator(model::Model)

src/varname.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,12 @@ function split_var_str(var_str, inds_as = Vector)
134134
end
135135
return sym, inds
136136
end
137+
138+
139+
@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F}
140+
return s in argnames
141+
end
142+
143+
@generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
144+
return s in missings
145+
end

0 commit comments

Comments
 (0)