Skip to content

Commit a2b4ace

Browse files
committed
Improve interpolation and use function isassumption
1 parent 5aff615 commit a2b4ace

File tree

4 files changed

+155
-137
lines changed

4 files changed

+155
-137
lines changed

src/compiler.jl

Lines changed: 142 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -27,43 +27,42 @@ function wrong_dist_errormsg(l)
2727
end
2828

2929
"""
30-
@isassumption(model, expr)
30+
isassumption(model, expr)
3131
32-
Let `expr` be `x[1]`. `vn` is an assumption in the following cases:
33-
1. `x` was not among the input data to the model,
34-
2. `x` was among the input data to the model but with a value `missing`, or
35-
3. `x` was among the input data to the model with a value other than missing,
32+
Return an expression that can be evaluated to check if `expr` is an assumption in the
33+
`model`.
34+
35+
Let `expr` be `:(x[1])`. It is an assumption in the following cases:
36+
1. `x` is not among the input data to the `model`,
37+
2. `x` is among the input data to the `model` but with a value `missing`, or
38+
3. `x` is among the input data to the `model` with a value other than missing,
3639
but `x[1] === missing`.
40+
3741
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
3842
"""
39-
macro isassumption(model, expr::Union{Symbol, Expr})
40-
# Note: never put a return in this... don't forget it's a macro!
43+
function isassumption(model, expr::Union{Symbol, Expr})
4144
vn = gensym(:vn)
42-
45+
4346
return quote
44-
$vn = @varname($expr)
45-
46-
# This branch should compile nicely in all cases except for partial missing data
47-
# For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48-
if !$DynamicPPL.inargnames($vn, $model) || $DynamicPPL.inmissings($vn, $model)
49-
true
50-
else
51-
if $DynamicPPL.inargnames($vn, $model)
52-
# Evaluate the lhs
53-
$expr === missing
47+
let $vn = $(varname(expr))
48+
# This branch should compile nicely in all cases except for partial missing data
49+
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
50+
if !$(DynamicPPL.inargnames)($vn, $model) || $(DynamicPPL.inmissings)($vn, $model)
51+
true
5452
else
55-
throw("This point should not be reached. Please report this error.")
53+
if $(DynamicPPL.inargnames)($vn, $model)
54+
# Evaluate the lhs
55+
$expr === missing
56+
else
57+
throw("This point should not be reached. Please report this error.")
58+
end
5659
end
5760
end
58-
end |> esc
59-
end
60-
61-
macro isassumption(model, expr)
62-
# failsafe: a literal is never an assumption
63-
false
61+
end
6462
end
6563

66-
64+
# failsafe: a literal is never an assumption
65+
isassumption(model, expr) = :(false)
6766

6867
#################
6968
# Main Compiler #
@@ -128,7 +127,7 @@ function build_model_info(input_expr)
128127
Expr(:tuple, QuoteNode.(arg_syms)...),
129128
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)
130129
)
131-
args_nt = Expr(:call, :($DynamicPPL.namedtuple), nt_type, Expr(:tuple, arg_syms...))
130+
args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...))
132131
end
133132
args = map(modeldef[:args]) do arg
134133
if (arg isa Symbol)
@@ -217,7 +216,7 @@ function replace_logpdf!(model_info)
217216
vi = model_info[:main_body_names][:vi]
218217
ex = MacroTools.postwalk(ex) do x
219218
if @capture(x, @logpdf())
220-
:($vi.logp[])
219+
:(getlogp($vi))
221220
else
222221
x
223222
end
@@ -294,45 +293,58 @@ function generate_tilde(left, right, model_info)
294293
vi = model_info[:main_body_names][:vi]
295294
ctx = model_info[:main_body_names][:ctx]
296295
sampler = model_info[:main_body_names][:sampler]
297-
temp_right = gensym(:temp_right)
298-
out = gensym(:out)
299-
lp = gensym(:lp)
300-
vn = gensym(:vn)
301-
inds = gensym(:inds)
302-
isassumption = gensym(:isassumption)
303-
assert_ex = :($DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
296+
297+
@gensym tmpright
298+
expr = quote
299+
$tmpright = $right
300+
$(DynamicPPL.assert_dist)($tmpright, msg = $(wrong_dist_errormsg(@__LINE__)))
301+
end
304302

305303
if left isa Symbol || left isa Expr
306-
ex = quote
307-
$temp_right = $right
308-
$assert_ex
309-
310-
$vn, $inds = $(varname(left)), $(vinds(left))
311-
$isassumption = $DynamicPPL.@isassumption($model, $left)
312-
if $isassumption
313-
$out = $DynamicPPL.tilde_assume($ctx, $sampler, $temp_right, $vn, $inds, $vi)
314-
$left = $out[1]
315-
$DynamicPPL.acclogp!($vi, $out[2])
316-
else
317-
$DynamicPPL.acclogp!(
318-
$vi,
319-
$DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
320-
)
304+
@gensym out vn inds
305+
push!(expr.args,
306+
:($vn = $(varname(left))),
307+
:($inds = $(vinds(left))))
308+
309+
assumption = quote
310+
$out = $(DynamicPPL.tilde_assume)($ctx, $sampler, $tmpright, $vn, $inds,
311+
$vi)
312+
$left = $out[1]
313+
$(DynamicPPL.acclogp!)($vi, $out[2])
314+
end
315+
316+
# It can only be an observation if the LHS is an argument of the model
317+
if vsym(left) in model_info[:args]
318+
@gensym isassumption
319+
return quote
320+
$expr
321+
$isassumption = $(DynamicPPL.isassumption(model, left))
322+
if $isassumption
323+
$assumption
324+
else
325+
$(DynamicPPL.acclogp!)(
326+
$vi,
327+
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vn,
328+
$inds, $vi)
329+
)
330+
end
321331
end
322332
end
323-
else
324-
# we have a literal, which is automatically an observation
325-
ex = quote
326-
$temp_right = $right
327-
$assert_ex
328-
329-
$DynamicPPL.acclogp!(
330-
$vi,
331-
$DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
332-
)
333+
334+
return quote
335+
$expr
336+
$assumption
333337
end
334338
end
335-
return ex
339+
340+
# If the LHS is a literal, it is always an observation
341+
return quote
342+
$expr
343+
$(DynamicPPL.acclogp!)(
344+
$vi,
345+
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
346+
)
347+
end
336348
end
337349

338350
"""
@@ -347,46 +359,58 @@ function generate_dot_tilde(left, right, model_info)
347359
vi = model_info[:main_body_names][:vi]
348360
ctx = model_info[:main_body_names][:ctx]
349361
sampler = model_info[:main_body_names][:sampler]
350-
out = gensym(:out)
351-
temp_right = gensym(:temp_right)
352-
isassumption = gensym(:isassumption)
353-
lp = gensym(:lp)
354-
vn = gensym(:vn)
355-
inds = gensym(:inds)
356-
assert_ex = :($DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
357362

358-
if left isa Symbol || left isa Expr
359-
ex = quote
360-
$temp_right = $right
361-
$assert_ex
363+
@gensym tmpright
364+
expr = quote
365+
$tmpright = $right
366+
$(DynamicPPL.assert_dist)($tmpright, msg = $(wrong_dist_errormsg(@__LINE__)))
367+
end
362368

363-
$vn, $inds = $(varname(left)), $(vinds(left))
364-
$isassumption = $DynamicPPL.@isassumption($model, $left)
369+
if left isa Symbol || left isa Expr
370+
@gensym out vn inds
371+
push!(expr.args,
372+
:($vn = $(varname(left))),
373+
:($inds = $(vinds(left))))
374+
375+
assumption = quote
376+
$out = $(DynamicPPL.dot_tilde_assume)($ctx, $sampler, $tmpright, $left,
377+
$vn, $inds, $vi)
378+
$left .= $out[1]
379+
$(DynamicPPL.acclogp!)($vi, $out[2])
380+
end
365381

366-
if $isassumption
367-
$out = $DynamicPPL.dot_tilde_assume($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi)
368-
$left .= $out[1]
369-
$DynamicPPL.acclogp!($vi, $out[2])
370-
else
371-
$DynamicPPL.acclogp!(
372-
$vi,
373-
$DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
374-
)
382+
# It can only be an observation if the LHS is an argument of the model
383+
if vsym(left) in model_info[:args]
384+
@gensym isassumption
385+
return quote
386+
$expr
387+
$isassumption = $(DynamicPPL.isassumption(model, left))
388+
if $isassumption
389+
$assumption
390+
else
391+
$(DynamicPPL.acclogp!)(
392+
$vi,
393+
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left,
394+
$vn, $inds, $vi)
395+
)
396+
end
375397
end
376398
end
377-
else
378-
# we have a literal, which is automatically an observation
379-
ex = quote
380-
$temp_right = $right
381-
$assert_ex
382-
383-
$DynamicPPL.acclogp!(
384-
$vi,
385-
$DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
386-
)
399+
400+
return quote
401+
$expr
402+
$assumption
387403
end
388404
end
389-
return ex
405+
406+
# If the LHS is a literal, it is always an observation
407+
return quote
408+
$expr
409+
$(DynamicPPL.acclogp!)(
410+
$vi,
411+
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
412+
)
413+
end
390414
end
391415

392416
const FloatOrArrayType = Type{<:Union{AbstractFloat, AbstractArray}}
@@ -425,39 +449,27 @@ function build_output(model_info)
425449

426450
unwrap_data_expr = Expr(:block)
427451
for var in arg_syms
428-
temp_var = gensym(:temp_var)
429-
varT = gensym(:varT)
430-
push!(unwrap_data_expr.args, quote
431-
local $var
432-
$temp_var = $model.args.$var
433-
$varT = typeof($temp_var)
434-
if $temp_var isa $DynamicPPL.FloatOrArrayType
435-
$var = $DynamicPPL.get_matching_type($sampler, $vi, $temp_var)
436-
elseif $DynamicPPL.hasmissing($varT)
437-
$var = $DynamicPPL.get_matching_type($sampler, $vi, $varT)($temp_var)
438-
else
439-
$var = $temp_var
440-
end
441-
end)
452+
push!(unwrap_data_expr.args,
453+
:($var = $(DynamicPPL.matchingvalue)($sampler, $vi, $(model).args.$var)))
442454
end
443455

444456
@gensym(evaluator, generator)
445457
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
446-
model_gen_constructor = :($DynamicPPL.ModelGen{$(Tuple(arg_syms))}($generator, $defaults_nt))
458+
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
447459

448460
ex = quote
449461
function $evaluator(
450-
$model::$DynamicPPL.Model,
451-
$vi::$DynamicPPL.VarInfo,
452-
$sampler::$DynamicPPL.AbstractSampler,
453-
$ctx::$DynamicPPL.AbstractContext,
462+
$model::$(DynamicPPL.Model),
463+
$vi::$(DynamicPPL.VarInfo),
464+
$sampler::$(DynamicPPL.AbstractSampler),
465+
$ctx::$(DynamicPPL.AbstractContext),
454466
)
455467
$unwrap_data_expr
456-
$DynamicPPL.resetlogp!($vi)
468+
$(DynamicPPL.resetlogp!)($vi)
457469
$main_body
458470
end
459471

460-
$generator($(args...)) = $DynamicPPL.Model($evaluator, $args_nt, $model_gen_constructor)
472+
$generator($(args...)) = $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor)
461473
$(generator_kw_form...)
462474

463475
$model_gen = $model_gen_constructor
@@ -474,6 +486,21 @@ function warn_empty(body)
474486
return
475487
end
476488

489+
"""
490+
matchingvalue(sampler, vi, value)
491+
492+
Convert the `value` to the correct type for the `sampler` and the `vi` object.
493+
"""
494+
function matchingvalue(sampler, vi, value)
495+
T = typeof(value)
496+
if hasmissing(T)
497+
return get_matching_type(sampler, vi, T)(value)
498+
else
499+
return value
500+
end
501+
end
502+
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)
503+
477504
"""
478505
get_matching_type(spl, vi, ::Type{T}) where {T}
479506
Get the specialized version of type `T` for sampler `spl`. For example,

src/prob_macro.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
macro logprob_str(str)
22
expr1, expr2 = get_exprs(str)
3-
return :($DynamicPPL.logprob($expr1, $expr2)) |> esc
3+
return :(logprob($(esc(expr1)), $(esc(expr2))))
44
end
55
macro prob_str(str)
66
expr1, expr2 = get_exprs(str)
7-
return :(exp.($DynamicPPL.logprob($expr1, $expr2))) |> esc
7+
return :(exp.(logprob($(esc(expr1)), $(esc(expr2)))))
88
end
99

1010
function get_exprs(str::String)
@@ -169,7 +169,7 @@ end
169169
# `missings` is splatted into a tuple at compile time and inserted as literal
170170
return quote
171171
$(warnings...)
172-
$DynamicPPL.Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals)))
172+
Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals)))
173173
end
174174
end
175175

@@ -225,7 +225,7 @@ end
225225

226226
# `args` is inserted as properly typed NamedTuple expression;
227227
# `missings` is splatted into a tuple at compile time and inserted as literal
228-
return :($DynamicPPL.Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals))))
228+
return :(Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals))))
229229
end
230230

231231
_setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function to_namedtuple_expr(syms, vals=syms)
5959
Expr(:tuple, QuoteNode.(syms)...),
6060
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...)
6161
)
62-
nt = Expr(:call, :($DynamicPPL.namedtuple), nt_type, Expr(:tuple, vals...))
62+
nt = Expr(:call, :($(DynamicPPL.namedtuple)), nt_type, Expr(:tuple, vals...))
6363
end
6464
return nt
6565
end

0 commit comments

Comments
 (0)