Skip to content

Commit 54b2977

Browse files
committed
Split up tilde namings in backwards-compatile way; some refactoring
1 parent 08561a6 commit 54b2977

File tree

2 files changed

+82
-39
lines changed

2 files changed

+82
-39
lines changed

src/compiler.jl

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,19 @@ function wrong_dist_errormsg(l)
4949
end
5050

5151
"""
52-
@preprocess(data_vars, missing_vars, ex)
52+
@isassumption(data_vars, missing_vars, ex)
5353
54-
Let `ex` be `x[1]`. This macro returns `@varname x[1]` in any of the following cases:
54+
Let `ex` be `x[1]`. This macro returns `true` in any of the following cases:
5555
1. `x` was not among the input data to the model,
5656
2. `x` was among the input data to the model but with a value `missing`, or
5757
3. `x` was among the input data to the model with a value other than missing,
58-
but `x[1] === missing`.
59-
Otherwise, the value of `x[1]` is returned.
58+
but `x[1] === missing`.
59+
When `ex` is not a variable (e.g., a literal), the function returns `false` as well.
6060
"""
61-
macro preprocess(data_vars, missing_vars, ex)
62-
ex
61+
macro isassumption(data_vars, missing_vars, ex)
62+
:false
6363
end
64-
macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
64+
macro isassumption(data_vars, missing_vars, ex::Union{Symbol, Expr})
6565
sym = gensym(:sym)
6666
lhs = gensym(:lhs)
6767
return esc(quote
@@ -70,22 +70,23 @@ macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
7070
# This branch should compile nicely in all cases except for partial missing data
7171
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
7272
if !DynamicPPL.inparams($sym, $data_vars) || DynamicPPL.inparams($sym, $missing_vars)
73-
$(varname(ex)), $(vinds(ex))
73+
true
7474
else
7575
if DynamicPPL.inparams($sym, $data_vars)
7676
# Evaluate the lhs
7777
$lhs = $ex
7878
if $lhs === missing
79-
$(varname(ex)), $(vinds(ex))
79+
true
8080
else
81-
$lhs
81+
false
8282
end
8383
else
8484
throw("This point should not be reached. Please report this error.")
8585
end
8686
end
8787
end)
8888
end
89+
8990
@generated function inparams(::Val{s}, ::Val{t}) where {s, t}
9091
return (s in t) ? :(true) : :(false)
9192
end
@@ -319,6 +320,9 @@ function replace_tilde!(model_info)
319320
end
320321
""" |> Meta.parse |> eval
321322

323+
# """ Unbreak code highlighting in Emacs julia-mode
324+
325+
322326
"""
323327
generate_tilde(left, right, model_info)
324328
@@ -331,37 +335,43 @@ function generate_tilde(left, right, model_info)
331335
vi = model_info[:main_body_names][:vi]
332336
ctx = model_info[:main_body_names][:ctx]
333337
sampler = model_info[:main_body_names][:sampler]
334-
temp_right = gensym(:temp_right)
335-
out = gensym(:out)
336-
lp = gensym(:lp)
337-
vn = gensym(:vn)
338-
inds = gensym(:inds)
339-
preprocessed = gensym(:preprocessed)
338+
339+
@gensym(out,
340+
lp,
341+
vn,
342+
inds,
343+
isassumption,
344+
temp_right)
345+
340346
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
347+
341348
if left isa Symbol || left isa Expr
342349
ex = quote
343350
$temp_right = $right
344351
$assert_ex
345-
$preprocessed = DynamicPPL.@preprocess($arg_syms, DynamicPPL.getmissing($model), $left)
346-
if $preprocessed isa Tuple
347-
$vn, $inds = $preprocessed
348-
$out = DynamicPPL.tilde($ctx, $sampler, $temp_right, $vn, $inds, $vi)
352+
353+
$vn, $inds = $(varname(left)), $(vinds(left))
354+
$isassumption = DynamicPPL.@isassumption($arg_syms, DynamicPPL.getmissing($model), $left)
355+
if $isassumption
356+
$out = DynamicPPL.tilde_assume($ctx, $sampler, $temp_right, $vn, $inds, $vi)
349357
$left = $out[1]
350358
DynamicPPL.acclogp!($vi, $out[2])
351359
else
352360
DynamicPPL.acclogp!(
353361
$vi,
354-
DynamicPPL.tilde($ctx, $sampler, $temp_right, $preprocessed, $vi),
362+
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
355363
)
356364
end
357365
end
358366
else
367+
# we have a literal, which is automatically an observation
359368
ex = quote
360369
$temp_right = $right
361370
$assert_ex
371+
362372
DynamicPPL.acclogp!(
363373
$vi,
364-
DynamicPPL.tilde($ctx, $sampler, $temp_right, $left, $vi),
374+
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
365375
)
366376
end
367377
end
@@ -371,49 +381,55 @@ end
371381
"""
372382
generate_dot_tilde(left, right, model_info)
373383
374-
This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
384+
This function returns the expression that replaces `left .~ right` in the model body. If
385+
`preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
386+
will be run.
375387
"""
376388
function generate_dot_tilde(left, right, model_info)
377389
arg_syms = Val((model_info[:arg_syms]...,))
378390
model = model_info[:main_body_names][:model]
379391
vi = model_info[:main_body_names][:vi]
380392
ctx = model_info[:main_body_names][:ctx]
381393
sampler = model_info[:main_body_names][:sampler]
382-
out = gensym(:out)
383-
temp_left = gensym(:temp_left)
384-
temp_right = gensym(:temp_right)
385-
preprocessed = gensym(:preprocessed)
386-
lp = gensym(:lp)
387-
vn = gensym(:vn)
388-
inds = gensym(:inds)
394+
395+
@gensym(out,
396+
preprocessed,
397+
lp,
398+
vn,
399+
inds,
400+
isassumption,
401+
temp_right)
402+
389403
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
404+
390405
if left isa Symbol || left isa Expr
391406
ex = quote
392407
$temp_right = $right
393408
$assert_ex
394-
$preprocessed = DynamicPPL.@preprocess($arg_syms, DynamicPPL.getmissing($model), $left)
395-
if $preprocessed isa Tuple
396-
$vn, $inds = $preprocessed
397-
$temp_left = $left
398-
$out = DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vn, $inds, $vi)
409+
410+
$vn, $inds = $(varname(left)), $(vinds(left))
411+
$isassumption = DynamicPPL.@isassumption($arg_syms, DynamicPPL.getmissing($model), $left)
412+
413+
if $isassumption
414+
$out = DynamicPPL.dot_tilde_assume($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi)
399415
$left .= $out[1]
400416
DynamicPPL.acclogp!($vi, $out[2])
401417
else
402-
$temp_left = $preprocessed
403418
DynamicPPL.acclogp!(
404419
$vi,
405-
DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi),
420+
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
406421
)
407422
end
408423
end
409424
else
425+
# we have a literal, which is automatically an observation
410426
ex = quote
411-
$temp_left = $left
412427
$temp_right = $right
413428
$assert_ex
429+
414430
DynamicPPL.acclogp!(
415431
$vi,
416-
DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi),
432+
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
417433
)
418434
end
419435
end

src/context_implementations.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
3535
return tilde(ctx.ctx, sampler, right, left, inds, vi)
3636
end
3737

38+
39+
function tilde_assume(ctx, sampler, right, vn, inds, vi)
40+
return tilde(ctx, sampler, right, vn, inds, vi)
41+
end
42+
43+
3844
function _tilde(sampler, right, vn::VarName, vi)
3945
return assume(sampler, right, vn, vi)
4046
end
@@ -68,6 +74,14 @@ function tilde(ctx::MiniBatchContext, sampler, right, left, vi)
6874
return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi)
6975
end
7076

77+
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
78+
return tilde(ctx, sampler, right, left, vi)
79+
end
80+
function tilde_observe(ctx, sampler, right, left, vi)
81+
return tilde(ctx, sampler, right, left, vi)
82+
end
83+
84+
7185
_tilde(sampler, right, left, vi) = observe(sampler, right, left, vi)
7286

7387
function assume(spl::Sampler, dist)
@@ -163,6 +177,11 @@ function dot_tilde(
163177
return _dot_tilde(sampler, dist, left, vns, vi)
164178
end
165179

180+
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
181+
return dot_tilde(ctx, sampler, right, left, vn, inds, vi)
182+
end
183+
184+
166185
function get_vns_and_dist(dist::NamedDist, var, vn::VarName)
167186
name = dist.name
168187
if name isa String
@@ -337,6 +356,14 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
337356
return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, left, vi)
338357
end
339358

359+
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
360+
return dot_tilde(ctx, sampler, right, left, vi)
361+
end
362+
function dot_tilde_observe(ctx, sampler, right, left, vi)
363+
return dot_tilde(ctx, sampler, right, left, vi)
364+
end
365+
366+
340367
function _dot_tilde(sampler, right, left::AbstractArray, vi)
341368
return dot_observe(sampler, right, left, vi)
342369
end

0 commit comments

Comments
 (0)