Skip to content

Commit 5aff615

Browse files
committed
Interpolate DynamicPPL in macros
1 parent 14a3a7d commit 5aff615

File tree

4 files changed

+44
-44
lines changed

4 files changed

+44
-44
lines changed

src/compiler.jl

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ macro isassumption(model, expr::Union{Symbol, Expr})
4545

4646
# This branch should compile nicely in all cases except for partial missing data
4747
# For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48-
if !DynamicPPL.inargnames($vn, $model) || DynamicPPL.inmissings($vn, $model)
48+
if !$DynamicPPL.inargnames($vn, $model) || $DynamicPPL.inmissings($vn, $model)
4949
true
5050
else
51-
if DynamicPPL.inargnames($vn, $model)
51+
if $DynamicPPL.inargnames($vn, $model)
5252
# Evaluate the lhs
5353
$expr === missing
5454
else
@@ -128,7 +128,7 @@ function build_model_info(input_expr)
128128
Expr(:tuple, QuoteNode.(arg_syms)...),
129129
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)
130130
)
131-
args_nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, arg_syms...))
131+
args_nt = Expr(:call, :($DynamicPPL.namedtuple), nt_type, Expr(:tuple, arg_syms...))
132132
end
133133
args = map(modeldef[:args]) do arg
134134
if (arg isa Symbol)
@@ -300,23 +300,23 @@ function generate_tilde(left, right, model_info)
300300
vn = gensym(:vn)
301301
inds = gensym(:inds)
302302
isassumption = gensym(:isassumption)
303-
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
304-
303+
assert_ex = :($DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
304+
305305
if left isa Symbol || left isa Expr
306306
ex = quote
307307
$temp_right = $right
308308
$assert_ex
309-
309+
310310
$vn, $inds = $(varname(left)), $(vinds(left))
311-
$isassumption = DynamicPPL.@isassumption($model, $left)
312-
if $isassumption
313-
$out = DynamicPPL.tilde_assume($ctx, $sampler, $temp_right, $vn, $inds, $vi)
311+
$isassumption = $DynamicPPL.@isassumption($model, $left)
312+
if $isassumption
313+
$out = $DynamicPPL.tilde_assume($ctx, $sampler, $temp_right, $vn, $inds, $vi)
314314
$left = $out[1]
315-
DynamicPPL.acclogp!($vi, $out[2])
315+
$DynamicPPL.acclogp!($vi, $out[2])
316316
else
317-
DynamicPPL.acclogp!(
317+
$DynamicPPL.acclogp!(
318318
$vi,
319-
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
319+
$DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
320320
)
321321
end
322322
end
@@ -325,10 +325,10 @@ function generate_tilde(left, right, model_info)
325325
ex = quote
326326
$temp_right = $right
327327
$assert_ex
328-
329-
DynamicPPL.acclogp!(
328+
329+
$DynamicPPL.acclogp!(
330330
$vi,
331-
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
331+
$DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
332332
)
333333
end
334334
end
@@ -353,24 +353,24 @@ function generate_dot_tilde(left, right, model_info)
353353
lp = gensym(:lp)
354354
vn = gensym(:vn)
355355
inds = gensym(:inds)
356-
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
357-
356+
assert_ex = :($DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
357+
358358
if left isa Symbol || left isa Expr
359359
ex = quote
360360
$temp_right = $right
361361
$assert_ex
362362

363363
$vn, $inds = $(varname(left)), $(vinds(left))
364-
$isassumption = DynamicPPL.@isassumption($model, $left)
365-
364+
$isassumption = $DynamicPPL.@isassumption($model, $left)
365+
366366
if $isassumption
367-
$out = DynamicPPL.dot_tilde_assume($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi)
367+
$out = $DynamicPPL.dot_tilde_assume($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi)
368368
$left .= $out[1]
369-
DynamicPPL.acclogp!($vi, $out[2])
369+
$DynamicPPL.acclogp!($vi, $out[2])
370370
else
371-
DynamicPPL.acclogp!(
371+
$DynamicPPL.acclogp!(
372372
$vi,
373-
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
373+
$DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
374374
)
375375
end
376376
end
@@ -379,10 +379,10 @@ function generate_dot_tilde(left, right, model_info)
379379
ex = quote
380380
$temp_right = $right
381381
$assert_ex
382-
383-
DynamicPPL.acclogp!(
382+
383+
$DynamicPPL.acclogp!(
384384
$vi,
385-
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
385+
$DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
386386
)
387387
end
388388
end
@@ -431,10 +431,10 @@ function build_output(model_info)
431431
local $var
432432
$temp_var = $model.args.$var
433433
$varT = typeof($temp_var)
434-
if $temp_var isa DynamicPPL.FloatOrArrayType
435-
$var = DynamicPPL.get_matching_type($sampler, $vi, $temp_var)
436-
elseif DynamicPPL.hasmissing($varT)
437-
$var = DynamicPPL.get_matching_type($sampler, $vi, $varT)($temp_var)
434+
if $temp_var isa $DynamicPPL.FloatOrArrayType
435+
$var = $DynamicPPL.get_matching_type($sampler, $vi, $temp_var)
436+
elseif $DynamicPPL.hasmissing($varT)
437+
$var = $DynamicPPL.get_matching_type($sampler, $vi, $varT)($temp_var)
438438
else
439439
$var = $temp_var
440440
end
@@ -443,21 +443,21 @@ function build_output(model_info)
443443

444444
@gensym(evaluator, generator)
445445
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
446-
model_gen_constructor = :(DynamicPPL.ModelGen{$(Tuple(arg_syms))}($generator, $defaults_nt))
446+
model_gen_constructor = :($DynamicPPL.ModelGen{$(Tuple(arg_syms))}($generator, $defaults_nt))
447447

448448
ex = quote
449449
function $evaluator(
450-
$model::DynamicPPL.Model,
451-
$vi::DynamicPPL.VarInfo,
452-
$sampler::DynamicPPL.AbstractSampler,
453-
$ctx::DynamicPPL.AbstractContext,
450+
$model::$DynamicPPL.Model,
451+
$vi::$DynamicPPL.VarInfo,
452+
$sampler::$DynamicPPL.AbstractSampler,
453+
$ctx::$DynamicPPL.AbstractContext,
454454
)
455455
$unwrap_data_expr
456-
DynamicPPL.resetlogp!($vi)
456+
$DynamicPPL.resetlogp!($vi)
457457
$main_body
458458
end
459459

460-
$generator($(args...)) = DynamicPPL.Model($evaluator, $args_nt, $model_gen_constructor)
460+
$generator($(args...)) = $DynamicPPL.Model($evaluator, $args_nt, $model_gen_constructor)
461461
$(generator_kw_form...)
462462

463463
$model_gen = $model_gen_constructor

src/prob_macro.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
macro logprob_str(str)
22
expr1, expr2 = get_exprs(str)
3-
return :(DynamicPPL.logprob($expr1, $expr2)) |> esc
3+
return :($DynamicPPL.logprob($expr1, $expr2)) |> esc
44
end
55
macro prob_str(str)
66
expr1, expr2 = get_exprs(str)
7-
return :(exp.(DynamicPPL.logprob($expr1, $expr2))) |> esc
7+
return :(exp.($DynamicPPL.logprob($expr1, $expr2))) |> esc
88
end
99

1010
function get_exprs(str::String)
@@ -169,7 +169,7 @@ end
169169
# `missings` is splatted into a tuple at compile time and inserted as literal
170170
return quote
171171
$(warnings...)
172-
DynamicPPL.Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals)))
172+
$DynamicPPL.Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals)))
173173
end
174174
end
175175

@@ -225,7 +225,7 @@ end
225225

226226
# `args` is inserted as properly typed NamedTuple expression;
227227
# `missings` is splatted into a tuple at compile time and inserted as literal
228-
return :(DynamicPPL.Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals))))
228+
return :($DynamicPPL.Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals))))
229229
end
230230

231231
_setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function to_namedtuple_expr(syms, vals=syms)
5959
Expr(:tuple, QuoteNode.(syms)...),
6060
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...)
6161
)
62-
nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, vals...))
62+
nt = Expr(:call, :($DynamicPPL.namedtuple), nt_type, Expr(:tuple, vals...))
6363
end
6464
return nt
6565
end

src/varname.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ end
3535
function varname(expr)
3636
ex = deepcopy(expr)
3737
(ex isa Symbol) && return quote
38-
DynamicPPL.VarName{$(QuoteNode(ex))}("")
38+
$DynamicPPL.VarName{$(QuoteNode(ex))}("")
3939
end
4040
(ex.head == :ref) || throw("VarName: Mis-formed variable name $(expr)!")
4141
inds = :(())
@@ -46,7 +46,7 @@ function varname(expr)
4646
end
4747
ex = ex.args[1]
4848
isa(ex, Symbol) && return quote
49-
DynamicPPL.VarName{$(QuoteNode(ex))}(foldl(*, $inds, init = ""))
49+
$DynamicPPL.VarName{$(QuoteNode(ex))}(foldl(*, $inds, init = ""))
5050
end
5151
end
5252
throw("VarName: Mis-formed variable name $(expr)!")

0 commit comments

Comments
 (0)