Skip to content

Commit d78d089

Browse files
committed
Fix dot macro and avoid eval
1 parent 08561a6 commit d78d089

File tree

3 files changed

+86
-27
lines changed

3 files changed

+86
-27
lines changed

src/compiler.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -284,40 +284,40 @@ function replace_sampler!(model_info)
284284
return model_info
285285
end
286286

287-
# The next function is defined that way because .~ gives a parsing error in Julia 1.0
288287
"""
289-
\"""
290288
replace_tilde!(model_info)
291289
292-
Replaces `~` expressions with observation or assumption expressions, updating `model_info`.
293-
\"""
290+
Replace `~` and `.~` expressions with observation or assumption expressions, updating `model_info`.
291+
"""
294292
function replace_tilde!(model_info)
295-
ex = model_info[:main_body]
296-
ex = MacroTools.postwalk(ex) do x
297-
if @capture(x, @M_ L_ ~ R_) && M == Symbol("@__dot__")
298-
generate_dot_tilde(L, R, model_info)
299-
else
300-
x
301-
end
302-
end
303-
$(VERSION >= v"1.1" ? "ex = MacroTools.postwalk(ex) do x
304-
if @capture(x, L_ .~ R_)
305-
generate_dot_tilde(L, R, model_info)
306-
else
307-
x
293+
# Apply the `@.` macro first.
294+
expr = model_info[:main_body]
295+
dottedexpr = MacroTools.postwalk(apply_dotted, expr)
296+
297+
# Check for tilde operators.
298+
tildeexpr = MacroTools.postwalk(dottedexpr) do x
299+
# Check dot tilde first.
300+
dotargs = getargs_dottilde(x)
301+
if dotargs !== nothing
302+
L, R = dotargs
303+
return generate_dot_tilde(L, R, model_info)
308304
end
309-
end" : "")
310-
ex = MacroTools.postwalk(ex) do x
311-
if @capture(x, L_ ~ R_)
312-
generate_tilde(L, R, model_info)
313-
else
314-
x
305+
306+
# Check tilde.
307+
args = getargs_tilde(x)
308+
if args !== nothing
309+
L, R = args
310+
return generate_tilde(L, R, model_info)
315311
end
312+
313+
return x
316314
end
317-
model_info[:main_body] = ex
315+
316+
# Update the function body.
317+
model_info[:main_body] = tildeexpr
318+
318319
return model_info
319320
end
320-
""" |> Meta.parse |> eval
321321

322322
"""
323323
generate_tilde(left, right, model_info)

src/utils.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,53 @@
1+
"""
2+
apply_dotted(x)
3+
4+
Apply the transformation of the `@.` macro if `x` is an expression of the form `@. X`.
5+
"""
6+
apply_dotted(x) = x
7+
function apply_dotted(expr::Expr)
8+
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
9+
expr.args[1] === Symbol("@__dot__")
10+
return Base.Broadcast.__dot__(expr.args[end])
11+
end
12+
return expr
13+
end
14+
15+
"""
16+
getargs_dottilde(x)
17+
18+
Return the arguments `L` and `R`, if `x` is an expression of the form `L .~ R` or
19+
`(~).(L, R)`, or `nothing` otherwise.
20+
"""
21+
getargs_dottilde(x) = nothing
22+
function getargs_dottilde(expr::Expr)
23+
# Check if the expression is of the form `L .~ R`.
24+
if Meta.isexpr(expr, :call, 3) && expr.args[1] === :.~
25+
return expr.args[2], expr.args[3]
26+
end
27+
28+
# Check if the expression is of the form `(~).(L, R)`.
29+
if Meta.isexpr(expr, :., 2) && expr.args[1] === :~ &&
30+
Meta.isexpr(expr.args[2], :tuple, 2)
31+
return expr.args[2].args[1], expr.args[2].args[2]
32+
end
33+
34+
return
35+
end
36+
37+
"""
38+
getargs_tilde(x)
39+
40+
Return the arguments `L` and `R`, if `x` is an expression of the form `L ~ R`, or `nothing`
41+
otherwise.
42+
"""
43+
getargs_tilde(x) = nothing
44+
function getargs_tilde(expr::Expr)
45+
if Meta.isexpr(expr, :call, 3) && expr.args[1] === :~
46+
return expr.args[2], expr.args[3]
47+
end
48+
return
49+
end
50+
151
############################################
252
# Julia 1.2 temporary fix - Julia PR 33303 #
353
############################################

test/compiler.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,23 @@ priors = 0 # See "new grammar" test.
313313
x = randn(100)
314314
res = sample(vdemo1(x), alg, 250)
315315

316+
@model vdemo1b(x) = begin
317+
s ~ InverseGamma(2,3)
318+
m ~ Normal(0, sqrt(s))
319+
@. x ~ Normal(m, $(sqrt(s)))
320+
return s, m
321+
end
322+
323+
res = sample(vdemo1b(x), alg, 250)
324+
316325
D = 2
317326
@model vdemo2(x) = begin
318327
μ ~ MvNormal(zeros(D), ones(D))
319-
@. x ~ MvNormal(μ, ones(D))
328+
@. x ~ $(MvNormal(μ, ones(D)))
320329
end
321330

322331
alg = HMC(0.01, 5)
323-
res = sample(vdemo2(randn(D,100)), alg, 250)
332+
res = sample(vdemo2(randn(D, 100)), alg, 250)
324333

325334
# Vector assumptions
326335
N = 10

0 commit comments

Comments
 (0)