Skip to content

Commit 312e7da

Browse files
authored
Merge pull request #54 from TuringLang/dottilde
Take @. seriously and avoid eval
2 parents 61c4825 + 43e8eee commit 312e7da

File tree

5 files changed

+189
-82
lines changed

5 files changed

+189
-82
lines changed

src/compiler.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -246,40 +246,40 @@ function replace_sampler!(model_info)
246246
return model_info
247247
end
248248

249-
# The next function is defined that way because .~ gives a parsing error in Julia 1.0
250249
"""
251-
\"""
252250
replace_tilde!(model_info)
253251
254-
Replaces `~` expressions with observation or assumption expressions, updating `model_info`.
255-
\"""
252+
Replace `~` and `.~` expressions with observation or assumption expressions, updating `model_info`.
253+
"""
256254
function replace_tilde!(model_info)
257-
ex = model_info[:main_body]
258-
ex = MacroTools.postwalk(ex) do x
259-
if @capture(x, @M_ L_ ~ R_) && M == Symbol("@__dot__")
260-
generate_dot_tilde(L, R, model_info)
261-
else
262-
x
263-
end
264-
end
265-
$(VERSION >= v"1.1" ? "ex = MacroTools.postwalk(ex) do x
266-
if @capture(x, L_ .~ R_)
267-
generate_dot_tilde(L, R, model_info)
268-
else
269-
x
255+
# Apply the `@.` macro first.
256+
expr = model_info[:main_body]
257+
dottedexpr = MacroTools.postwalk(apply_dotted, expr)
258+
259+
# Check for tilde operators.
260+
tildeexpr = MacroTools.postwalk(dottedexpr) do x
261+
# Check dot tilde first.
262+
dotargs = getargs_dottilde(x)
263+
if dotargs !== nothing
264+
L, R = dotargs
265+
return generate_dot_tilde(L, R, model_info)
270266
end
271-
end" : "")
272-
ex = MacroTools.postwalk(ex) do x
273-
if @capture(x, L_ ~ R_)
274-
generate_tilde(L, R, model_info)
275-
else
276-
x
267+
268+
# Check tilde.
269+
args = getargs_tilde(x)
270+
if args !== nothing
271+
L, R = args
272+
return generate_tilde(L, R, model_info)
277273
end
274+
275+
return x
278276
end
279-
model_info[:main_body] = ex
277+
278+
# Update the function body.
279+
model_info[:main_body] = tildeexpr
280+
280281
return model_info
281282
end
282-
""" |> Meta.parse |> eval
283283

284284
# """ Unbreak code highlighting in Emacs julia-mode
285285

src/utils.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,53 @@
1+
"""
2+
apply_dotted(x)
3+
4+
Apply the transformation of the `@.` macro if `x` is an expression of the form `@. X`.
5+
"""
6+
apply_dotted(x) = x
7+
function apply_dotted(expr::Expr)
8+
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
9+
expr.args[1] === Symbol("@__dot__")
10+
return Base.Broadcast.__dot__(expr.args[end])
11+
end
12+
return expr
13+
end
14+
15+
"""
16+
getargs_dottilde(x)
17+
18+
Return the arguments `L` and `R`, if `x` is an expression of the form `L .~ R` or
19+
`(~).(L, R)`, or `nothing` otherwise.
20+
"""
21+
getargs_dottilde(x) = nothing
22+
function getargs_dottilde(expr::Expr)
23+
# Check if the expression is of the form `L .~ R`.
24+
if Meta.isexpr(expr, :call, 3) && expr.args[1] === :.~
25+
return expr.args[2], expr.args[3]
26+
end
27+
28+
# Check if the expression is of the form `(~).(L, R)`.
29+
if Meta.isexpr(expr, :., 2) && expr.args[1] === :~ &&
30+
Meta.isexpr(expr.args[2], :tuple, 2)
31+
return expr.args[2].args[1], expr.args[2].args[2]
32+
end
33+
34+
return
35+
end
36+
37+
"""
38+
getargs_tilde(x)
39+
40+
Return the arguments `L` and `R`, if `x` is an expression of the form `L ~ R`, or `nothing`
41+
otherwise.
42+
"""
43+
getargs_tilde(x) = nothing
44+
function getargs_tilde(expr::Expr)
45+
if Meta.isexpr(expr, :call, 3) && expr.args[1] === :~
46+
return expr.args[2], expr.args[3]
47+
end
48+
return
49+
end
50+
151
############################################
252
# Julia 1.2 temporary fix - Julia PR 33303 #
353
############################################

test/compiler.jl

Lines changed: 63 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,23 @@ priors = 0 # See "new grammar" test.
327327
x = randn(100)
328328
res = sample(vdemo1(x), alg, 250)
329329

330+
@model vdemo1b(x) = begin
331+
s ~ InverseGamma(2,3)
332+
m ~ Normal(0, sqrt(s))
333+
@. x ~ Normal(m, $(sqrt(s)))
334+
return s, m
335+
end
336+
337+
res = sample(vdemo1b(x), alg, 250)
338+
330339
D = 2
331340
@model vdemo2(x) = begin
332341
μ ~ MvNormal(zeros(D), ones(D))
333-
@. x ~ MvNormal(μ, ones(D))
342+
@. x ~ $(MvNormal(μ, ones(D)))
334343
end
335344

336345
alg = HMC(0.01, 5)
337-
res = sample(vdemo2(randn(D,100)), alg, 250)
346+
res = sample(vdemo2(randn(D, 100)), alg, 250)
338347

339348
# Vector assumptions
340349
N = 10
@@ -386,78 +395,75 @@ priors = 0 # See "new grammar" test.
386395
sample(vdemo7(), alg, 1000)
387396
end
388397

389-
if VERSION >= v"1.1"
390-
"""
391-
@testset "vectorization .~" begin
392-
@model vdemo1(x) = begin
393-
s ~ InverseGamma(2,3)
394-
m ~ Normal(0, sqrt(s))
395-
x .~ Normal(m, sqrt(s))
396-
return s, m
397-
end
398+
# Notation is ugly since `x .~ Normal(μ, σ)` cannot be parsed in Julia 1.0
399+
@testset "vectorization .~" begin
400+
@model vdemo1(x) = begin
401+
s ~ InverseGamma(2,3)
402+
m ~ Normal(0, sqrt(s))
403+
(.~)(x, Normal(m, sqrt(s)))
404+
return s, m
405+
end
398406

399-
alg = HMC(0.01, 5)
400-
x = randn(100)
401-
res = sample(vdemo1(x), alg, 250)
407+
alg = HMC(0.01, 5)
408+
x = randn(100)
409+
res = sample(vdemo1(x), alg, 250)
402410

403-
D = 2
404-
@model vdemo2(x) = begin
405-
μ ~ MvNormal(zeros(D), ones(D))
406-
x .~ MvNormal(μ, ones(D))
407-
end
411+
D = 2
412+
@model vdemo2(x) = begin
413+
μ ~ MvNormal(zeros(D), ones(D))
414+
(.~)(x, MvNormal(μ, ones(D)))
415+
end
408416

409-
alg = HMC(0.01, 5)
410-
res = sample(vdemo2(randn(D,100)), alg, 250)
417+
alg = HMC(0.01, 5)
418+
res = sample(vdemo2(randn(D,100)), alg, 250)
411419

412-
# Vector assumptions
413-
N = 10
414-
setchunksize(N)
415-
alg = HMC(0.2, 4)
420+
# Vector assumptions
421+
N = 10
422+
setchunksize(N)
423+
alg = HMC(0.2, 4)
416424

417-
@model vdemo3() = begin
418-
x = Vector{Real}(undef, N)
419-
for i = 1:N
420-
x[i] ~ Normal(0, sqrt(4))
421-
end
425+
@model vdemo3() = begin
426+
x = Vector{Real}(undef, N)
427+
for i = 1:N
428+
x[i] ~ Normal(0, sqrt(4))
422429
end
430+
end
423431

424-
t_loop = @elapsed res = sample(vdemo3(), alg, 1000)
432+
t_loop = @elapsed res = sample(vdemo3(), alg, 1000)
425433

426-
# Test for vectorize UnivariateDistribution
427-
@model vdemo4() = begin
434+
# Test for vectorize UnivariateDistribution
435+
@model vdemo4() = begin
428436
x = Vector{Real}(undef, N)
429-
x .~ Normal(0, 2)
430-
end
437+
(.~)(x, Normal(0, 2))
438+
end
431439

432-
t_vec = @elapsed res = sample(vdemo4(), alg, 1000)
440+
t_vec = @elapsed res = sample(vdemo4(), alg, 1000)
433441

434-
@model vdemo5() = begin
435-
x ~ MvNormal(zeros(N), 2 * ones(N))
436-
end
442+
@model vdemo5() = begin
443+
x ~ MvNormal(zeros(N), 2 * ones(N))
444+
end
437445

438-
t_mv = @elapsed res = sample(vdemo5(), alg, 1000)
446+
t_mv = @elapsed res = sample(vdemo5(), alg, 1000)
439447

440-
println("Time for")
441-
println(" Loop : \$t_loop")
442-
println(" Vec : \$t_vec")
443-
println(" Mv : \$t_mv")
448+
println("Time for")
449+
println(" Loop : \$t_loop")
450+
println(" Vec : \$t_vec")
451+
println(" Mv : \$t_mv")
444452

445-
# Transformed test
446-
@model vdemo6() = begin
447-
x = Vector{Real}(undef, N)
448-
x .~ InverseGamma(2, 3)
449-
end
453+
# Transformed test
454+
@model vdemo6() = begin
455+
x = Vector{Real}(undef, N)
456+
(.~)(x, InverseGamma(2, 3))
457+
end
450458

451-
sample(vdemo6(), alg, 1000)
459+
sample(vdemo6(), alg, 1000)
452460

453-
@model vdemo7() = begin
454-
x = Array{Real}(undef, N, N)
455-
x .~ [InverseGamma(2, 3) for i in 1:N]
456-
end
457-
458-
sample(vdemo7(), alg, 1000)
461+
@model vdemo7() = begin
462+
x = Array{Real}(undef, N, N)
463+
(.~)(x, [InverseGamma(2, 3) for i in 1:N])
459464
end
460-
""" |> Meta.parse |> eval
465+
466+
sample(vdemo7(), alg, 1000)
461467
end
462468

463469
@testset "Type parameters" begin

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using .Turing
66
turnprogress(false)
77

88
@testset "DynamicPPL.jl" begin
9+
include("utils.jl")
910
include("compiler.jl")
1011
include("varinfo.jl")
1112
include("prob_macro.jl")

test/utils.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using DynamicPPL
2+
using DynamicPPL: apply_dotted, getargs_dottilde, getargs_tilde
3+
4+
using Test
5+
6+
@testset "apply_dotted" begin
7+
# Some things that are not expressions.
8+
@test apply_dotted(:x) === :x
9+
@test apply_dotted(1.0) === 1.0
10+
@test apply_dotted([1.0, 2.0, 4.0]) == [1.0, 2.0, 4.0]
11+
12+
# Some expressions.
13+
@test apply_dotted(:(x ~ Normal(μ, σ))) == :(x ~ Normal(μ, σ))
14+
@test apply_dotted(:((.~)(x, Normal(μ, σ)))) == :((.~)(x, Normal(μ, σ)))
15+
@test apply_dotted(:((~).(x, Normal(μ, σ)))) == :((~).(x, Normal(μ, σ)))
16+
@test apply_dotted(:(@. x ~ Normal(μ, σ))) == :((~).(x, Normal.(μ, σ)))
17+
@test apply_dotted(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) ==
18+
:((~).(x, Normal.(μ, sqrt(v))))
19+
@test apply_dotted(:(@~ Normal.(μ, σ))) == :(@~ Normal.(μ, σ))
20+
end
21+
22+
@testset "getargs_dottilde" begin
23+
# Some things that are not expressions.
24+
@test getargs_dottilde(:x) === nothing
25+
@test getargs_dottilde(1.0) === nothing
26+
@test getargs_dottilde([1.0, 2.0, 4.0]) === nothing
27+
28+
# Some expressions.
29+
@test getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing
30+
@test getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ)))
31+
@test getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ)))
32+
@test getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing
33+
@test getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
34+
@test getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing
35+
end
36+
37+
@testset "getargs_tilde" begin
38+
# Some things that are not expressions.
39+
@test getargs_tilde(:x) === nothing
40+
@test getargs_tilde(1.0) === nothing
41+
@test getargs_tilde([1.0, 2.0, 4.0]) === nothing
42+
43+
# Some expressions.
44+
@test getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ)))
45+
@test getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing
46+
@test getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing
47+
@test getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing
48+
@test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
49+
@test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
50+
end

0 commit comments

Comments
 (0)