Skip to content

Commit f6e5235

Browse files
committed
Merge branch 'master' into phg/varinfo-indices
2 parents 77b6828 + d3ff242 commit f6e5235

File tree

10 files changed

+337
-153
lines changed

10 files changed

+337
-153
lines changed

.travis.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ os:
88
- osx
99
julia:
1010
- 1.0
11-
- 1.1
12-
- 1.2
13-
- 1.3
11+
- 1
1412
- nightly
1513
matrix:
1614
allow_failures:
@@ -19,6 +17,6 @@ matrix:
1917
notifications:
2018
email: false
2119
after_success:
22-
- if [[ $TRAVIS_JULIA_VERSION = 1.3 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
20+
- if [[ $TRAVIS_JULIA_VERSION = 1 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
2321
julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())';
2422
fi

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111

1212
[compat]
13-
AbstractMCMC = "0.4, 0.5"
13+
AbstractMCMC = "0.4, 0.5, 1.0"
1414
Bijectors = "0.5.2, 0.6"
1515
Distributions = "0.22, 0.23"
1616
MacroTools = "0.5.1"

src/compiler.jl

Lines changed: 75 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -27,45 +27,44 @@ function wrong_dist_errormsg(l)
2727
end
2828

2929
"""
30-
@preprocess(data_vars, missing_vars, ex)
30+
@isassumption(model, expr)
3131
32-
Let `ex` be `x[1]`. This macro returns `@varname x[1]` in any of the following cases:
32+
Let `expr` be `x[1]`. `vn` is an assumption in the following cases:
3333
1. `x` was not among the input data to the model,
3434
2. `x` was among the input data to the model but with a value `missing`, or
3535
3. `x` was among the input data to the model with a value other than missing,
36-
but `x[1] === missing`.
37-
Otherwise, the value of `x[1]` is returned.
36+
but `x[1] === missing`.
37+
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
3838
"""
39-
macro preprocess(data_vars, missing_vars, ex)
40-
ex
41-
end
42-
macro preprocess(model, ex::Union{Symbol, Expr})
43-
sym = gensym(:sym)
44-
lhs = gensym(:lhs)
45-
return esc(quote
46-
# Extract symbol
47-
$sym = Val($(vsym(ex)))
39+
macro isassumption(model, expr::Union{Symbol, Expr})
40+
# Note: never put a return in this... don't forget it's a macro!
41+
vn = gensym(:vn)
42+
43+
return quote
44+
$vn = @varname($expr)
45+
4846
# This branch should compile nicely in all cases except for partial missing data
49-
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
50-
if !DynamicPPL.inargnames($sym, $model) || DynamicPPL.inmissings($sym, $model)
51-
$(varname(ex)), $(vinds(ex))
47+
# For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48+
if !DynamicPPL.inargnames($vn, $model) || DynamicPPL.inmissings($vn, $model)
49+
true
5250
else
53-
if DynamicPPL.inargnames($sym, $model)
51+
if DynamicPPL.inargnames($vn, $model)
5452
# Evaluate the lhs
55-
$lhs = $ex
56-
if $lhs === missing
57-
$(varname(ex)), $(vinds(ex))
58-
else
59-
$lhs
60-
end
53+
$expr === missing
6154
else
6255
throw("This point should not be reached. Please report this error.")
6356
end
6457
end
65-
end)
58+
end |> esc
59+
end
60+
61+
macro isassumption(model, expr)
62+
# failsafe: a literal is never an assumption
63+
false
6664
end
6765

6866

67+
6968
#################
7069
# Main Compiler #
7170
#################
@@ -246,40 +245,40 @@ function replace_sampler!(model_info)
246245
return model_info
247246
end
248247

249-
# The next function is defined that way because .~ gives a parsing error in Julia 1.0
250248
"""
251-
\"""
252249
replace_tilde!(model_info)
253250
254-
Replaces `~` expressions with observation or assumption expressions, updating `model_info`.
255-
\"""
251+
Replace `~` and `.~` expressions with observation or assumption expressions, updating `model_info`.
252+
"""
256253
function replace_tilde!(model_info)
257-
ex = model_info[:main_body]
258-
ex = MacroTools.postwalk(ex) do x
259-
if @capture(x, @M_ L_ ~ R_) && M == Symbol("@__dot__")
260-
generate_dot_tilde(L, R, model_info)
261-
else
262-
x
254+
# Apply the `@.` macro first.
255+
expr = model_info[:main_body]
256+
dottedexpr = MacroTools.postwalk(apply_dotted, expr)
257+
258+
# Check for tilde operators.
259+
tildeexpr = MacroTools.postwalk(dottedexpr) do x
260+
# Check dot tilde first.
261+
dotargs = getargs_dottilde(x)
262+
if dotargs !== nothing
263+
L, R = dotargs
264+
return generate_dot_tilde(L, R, model_info)
263265
end
264-
end
265-
$(VERSION >= v"1.1" ? "ex = MacroTools.postwalk(ex) do x
266-
if @capture(x, L_ .~ R_)
267-
generate_dot_tilde(L, R, model_info)
268-
else
269-
x
270-
end
271-
end" : "")
272-
ex = MacroTools.postwalk(ex) do x
273-
if @capture(x, L_ ~ R_)
274-
generate_tilde(L, R, model_info)
275-
else
276-
x
266+
267+
# Check tilde.
268+
args = getargs_tilde(x)
269+
if args !== nothing
270+
L, R = args
271+
return generate_tilde(L, R, model_info)
277272
end
273+
274+
return x
278275
end
279-
model_info[:main_body] = ex
276+
277+
# Update the function body.
278+
model_info[:main_body] = tildeexpr
279+
280280
return model_info
281281
end
282-
""" |> Meta.parse |> eval
283282

284283
# """ Unbreak code highlighting in Emacs julia-mode
285284

@@ -300,32 +299,36 @@ function generate_tilde(left, right, model_info)
300299
lp = gensym(:lp)
301300
vn = gensym(:vn)
302301
inds = gensym(:inds)
303-
preprocessed = gensym(:preprocessed)
302+
isassumption = gensym(:isassumption)
304303
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
304+
305305
if left isa Symbol || left isa Expr
306306
ex = quote
307307
$temp_right = $right
308308
$assert_ex
309-
$preprocessed = DynamicPPL.@preprocess($model, $left)
310-
if $preprocessed isa Tuple
311-
$vn, $inds = $preprocessed
312-
$out = DynamicPPL.tilde($ctx, $sampler, $temp_right, $vn, $inds, $vi)
309+
310+
$vn, $inds = $(varname(left)), $(vinds(left))
311+
$isassumption = DynamicPPL.@isassumption($model, $left)
312+
if $isassumption
313+
$out = DynamicPPL.tilde_assume($ctx, $sampler, $temp_right, $vn, $inds, $vi)
313314
$left = $out[1]
314315
DynamicPPL.acclogp!($vi, $out[2])
315316
else
316317
DynamicPPL.acclogp!(
317318
$vi,
318-
DynamicPPL.tilde($ctx, $sampler, $temp_right, $preprocessed, $vi),
319+
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
319320
)
320321
end
321322
end
322323
else
324+
# we have a literal, which is automatically an observation
323325
ex = quote
324326
$temp_right = $right
325327
$assert_ex
328+
326329
DynamicPPL.acclogp!(
327330
$vi,
328-
DynamicPPL.tilde($ctx, $sampler, $temp_right, $left, $vi),
331+
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
329332
)
330333
end
331334
end
@@ -335,48 +338,51 @@ end
335338
"""
336339
generate_dot_tilde(left, right, model_info)
337340
338-
This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
341+
This function returns the expression that replaces `left .~ right` in the model body. If
342+
`preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
343+
will be run.
339344
"""
340345
function generate_dot_tilde(left, right, model_info)
341346
model = model_info[:main_body_names][:model]
342347
vi = model_info[:main_body_names][:vi]
343348
ctx = model_info[:main_body_names][:ctx]
344349
sampler = model_info[:main_body_names][:sampler]
345350
out = gensym(:out)
346-
temp_left = gensym(:temp_left)
347351
temp_right = gensym(:temp_right)
348-
preprocessed = gensym(:preprocessed)
352+
isassumption = gensym(:isassumption)
349353
lp = gensym(:lp)
350354
vn = gensym(:vn)
351355
inds = gensym(:inds)
352356
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
357+
353358
if left isa Symbol || left isa Expr
354359
ex = quote
355360
$temp_right = $right
356361
$assert_ex
357-
$preprocessed = DynamicPPL.@preprocess($model, $left)
358-
if $preprocessed isa Tuple
359-
$vn, $inds = $preprocessed
360-
$temp_left = $left
361-
$out = DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vn, $inds, $vi)
362+
363+
$vn, $inds = $(varname(left)), $(vinds(left))
364+
$isassumption = DynamicPPL.@isassumption($model, $left)
365+
366+
if $isassumption
367+
$out = DynamicPPL.dot_tilde_assume($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi)
362368
$left .= $out[1]
363369
DynamicPPL.acclogp!($vi, $out[2])
364370
else
365-
$temp_left = $preprocessed
366371
DynamicPPL.acclogp!(
367372
$vi,
368-
DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi),
373+
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
369374
)
370375
end
371376
end
372377
else
378+
# we have a literal, which is automatically an observation
373379
ex = quote
374-
$temp_left = $left
375380
$temp_right = $right
376381
$assert_ex
382+
377383
DynamicPPL.acclogp!(
378384
$vi,
379-
DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi),
385+
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
380386
)
381387
end
382388
end
@@ -416,7 +422,7 @@ function build_output(model_info)
416422
model_gen = model_info[:name]
417423
# Main body of the model
418424
main_body = model_info[:main_body]
419-
425+
420426
unwrap_data_expr = Expr(:block)
421427
for var in arg_syms
422428
temp_var = gensym(:temp_var)

src/context_implementations.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
3535
return tilde(ctx.ctx, sampler, right, left, inds, vi)
3636
end
3737

38+
"""
39+
tilde_assume(ctx, sampler, right, vn, inds, vi)
40+
41+
This method is applied in the generated code for assumed variables, e.g., `x ~ Normal()` where
42+
`x` does not occur in the model inputs.
43+
44+
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
45+
"""
46+
function tilde_assume(ctx, sampler, right, vn, inds, vi)
47+
return tilde(ctx, sampler, right, vn, inds, vi)
48+
end
49+
50+
3851
function _tilde(sampler, right, vn::VarName, vi)
3952
return assume(sampler, right, vn, vi)
4053
end
@@ -56,6 +69,30 @@ function tilde(ctx::MiniBatchContext, sampler, right, left, vi)
5669
return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi)
5770
end
5871

72+
"""
73+
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
74+
75+
This method is applied in the generated code for observed variables, e.g., `x ~ Normal()` where
76+
`x` does occur in the model inputs.
77+
78+
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
79+
name and indices; if needed, these can be accessed through this function, though.
80+
"""
81+
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
82+
return tilde(ctx, sampler, right, left, vi)
83+
end
84+
85+
"""
86+
tilde_observe(ctx, sampler, right, left, vi)
87+
88+
This method is applied in the generated code for observed constants, e.g., `1.0 ~ Normal()`.
89+
Falls back to `tilde(ctx, sampler, right, left, vi)`.
90+
"""
91+
function tilde_observe(ctx, sampler, right, left, vi)
92+
return tilde(ctx, sampler, right, left, vi)
93+
end
94+
95+
5996
_tilde(sampler, right, left, vi) = observe(sampler, right, left, vi)
6097

6198
function assume(spl::Sampler, dist)
@@ -151,6 +188,19 @@ function dot_tilde(
151188
return _dot_tilde(sampler, dist, left, vns, vi)
152189
end
153190

191+
"""
192+
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
193+
194+
This method is applied in the generated code for assumed vectorized variables, e.g., `x .~
195+
MvNormal()` where `x` does not occur in the model inputs.
196+
197+
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
198+
"""
199+
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
200+
return dot_tilde(ctx, sampler, right, left, vn, inds, vi)
201+
end
202+
203+
154204
function get_vns_and_dist(dist::NamedDist, var, vn::VarName)
155205
return get_vns_and_dist(dist.dist, var, dist.name)
156206
end
@@ -314,6 +364,30 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
314364
return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, left, vi)
315365
end
316366

367+
"""
368+
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
369+
370+
This method is applied in the generated code for vectorized observed variables, e.g., `x .~
371+
MvNormal()` where `x` does occur the model inputs.
372+
373+
Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
374+
name and indices; if needed, these can be accessed through this function, though.
375+
"""
376+
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
377+
return dot_tilde(ctx, sampler, right, left, vi)
378+
end
379+
380+
"""
381+
dot_tilde_observe(ctx, sampler, right, left, vi)
382+
383+
This method is applied in the generated code for vectorized observed constants, e.g., `[1.0] .~
384+
MvNormal()`. Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
385+
"""
386+
function dot_tilde_observe(ctx, sampler, right, left, vi)
387+
return dot_tilde(ctx, sampler, right, left, vi)
388+
end
389+
390+
317391
function _dot_tilde(sampler, right, left::AbstractArray, vi)
318392
return dot_observe(sampler, right, left, vi)
319393
end

0 commit comments

Comments
 (0)