Skip to content

Commit 3a8f14a

Browse files
committed
Do not modify interpolated expressions
1 parent 9fbe09a commit 3a8f14a

File tree

5 files changed

+72
-85
lines changed

5 files changed

+72
-85
lines changed

src/compiler.jl

Lines changed: 51 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,17 @@ end
6969
7070
To generate a `Model`, call `model_generator(x_value)`.
7171
"""
72-
macro model(input_expr)
73-
Base.replace_ref_end!(input_expr) |> build_model_info |> replace_tilde! |> replace_vi! |>
74-
replace_logpdf! |> replace_sampler! |> build_output
72+
macro model(expr)
73+
esc(model(expr))
74+
end
75+
76+
function model(expr)
77+
modelinfo = build_model_info(expr)
78+
79+
# Generate main body
80+
modelinfo[:main_body] = generate_mainbody(modelinfo)
81+
82+
return build_output(modelinfo)
7583
end
7684

7785
"""
@@ -170,62 +178,55 @@ function build_model_info(input_expr)
170178
return model_info
171179
end
172180

173-
174181
"""
175-
replace_vi!(model_info)
182+
generate_mainbody([expr, ]modelinfo)
176183
177-
Replaces `@varinfo()` expressions with a handle to the `VarInfo` struct.
184+
Generate the body of the main evaluation function.
178185
"""
179-
function replace_vi!(model_info)
180-
ex = model_info[:main_body]
181-
vi = model_info[:main_body_names][:vi]
182-
ex = MacroTools.postwalk(ex) do x
183-
if @capture(x, @varinfo())
184-
vi
185-
else
186-
x
187-
end
188-
end
189-
model_info[:main_body] = ex
190-
return model_info
191-
end
186+
generate_mainbody(modelinfo) = generate_mainbody(modelinfo[:main_body], modelinfo)
192187

193-
"""
194-
replace_logpdf!(model_info)
188+
generate_mainbody(x, modelinfo) = x
189+
function generate_mainbody(expr::Expr, modelinfo)
190+
# Do not touch interpolated expressions
191+
expr.head === :$ && return expr.args[1]
195192

196-
Replaces `@logpdf()` expressions with the value of the accumulated `logpdf` in the `VarInfo` struct.
197-
"""
198-
function replace_logpdf!(model_info)
199-
ex = model_info[:main_body]
200-
vi = model_info[:main_body_names][:vi]
201-
ex = MacroTools.postwalk(ex) do x
202-
if @capture(x, @logpdf())
203-
:($(vi).logp[])
204-
else
205-
x
193+
# Apply the `@.` macro first.
194+
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
195+
expr.args[1] === Symbol("@__dot__")
196+
return generate_mainbody(Base.Broadcast.__dot__(expr.args[end]), modelinfo)
197+
end
198+
199+
# Modify macro calls.
200+
if Meta.isexpr(expr, :macrocall) && !isempty(expr.args)
201+
name = expr.args[1]
202+
if name === Symbol("@varinfo")
203+
return modelinfo[:main_body_names][:vi]
204+
elseif name === Symbol("@logpdf")
205+
return :($(modelinfo[:main_body_names][:vi]).logp[])
206+
elseif name === Symbol("@sampler")
207+
return :($(modelinfo[:main_body_names][:sampler]))
206208
end
207209
end
208-
model_info[:main_body] = ex
209-
return model_info
210-
end
211210

212-
"""
213-
replace_sampler!(model_info)
211+
# Modify dotted tilde operators.
212+
args_dottilde = getargs_dottilde(expr)
213+
if args_dottilde !== nothing
214+
L, R = args_dottilde
215+
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody(L, modelinfo),
216+
generate_mainbody(R, modelinfo),
217+
modelinfo))
218+
end
214219

215-
Replaces `@sampler()` expressions with a handle to the sampler struct.
216-
"""
217-
function replace_sampler!(model_info)
218-
ex = model_info[:main_body]
219-
spl = model_info[:main_body_names][:sampler]
220-
ex = MacroTools.postwalk(ex) do x
221-
if @capture(x, @sampler())
222-
spl
223-
else
224-
x
225-
end
220+
# Modify tilde operators.
221+
args_tilde = getargs_tilde(expr)
222+
if args_tilde !== nothing
223+
L, R = args_tilde
224+
return Base.remove_linenums!(generate_tilde(generate_mainbody(L, modelinfo),
225+
generate_mainbody(R, modelinfo),
226+
modelinfo))
226227
end
227-
model_info[:main_body] = ex
228-
return model_info
228+
229+
return Expr(expr.head, map(x -> generate_mainbody(x, modelinfo), expr.args)...)
229230
end
230231

231232
"""
@@ -443,7 +444,7 @@ function build_output(model_info)
443444
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
444445
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
445446

446-
ex = quote
447+
return quote
447448
function $evaluator(
448449
$model::$(DynamicPPL.Model),
449450
$vi::$(DynamicPPL.VarInfo),
@@ -459,8 +460,6 @@ function build_output(model_info)
459460

460461
$model_gen = $model_gen_constructor
461462
end
462-
463-
return esc(ex)
464463
end
465464

466465

src/utils.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
"""
2-
apply_dotted(x)
3-
4-
Apply the transformation of the `@.` macro if `x` is an expression of the form `@. X`.
5-
"""
6-
apply_dotted(x) = x
7-
function apply_dotted(expr::Expr)
8-
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
9-
expr.args[1] === Symbol("@__dot__")
10-
return Base.Broadcast.__dot__(expr.args[end])
11-
end
12-
return expr
13-
end
14-
151
"""
162
getargs_dottilde(x)
173

src/varname.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,10 @@ end
187187
vinds(expr::Symbol) = Expr(:tuple)
188188
function vinds(expr::Expr)
189189
if Meta.isexpr(expr, :ref)
190-
last = Expr(:tuple, expr.args[2:end]...)
191-
init = vinds(expr.args[1]).args
190+
ex = copy(expr)
191+
Base.replace_ref_end!(ex)
192+
last = Expr(:tuple, ex.args[2:end]...)
193+
init = vinds(ex.args[1]).args
192194
return Expr(:tuple, init..., last)
193195
else
194196
throw("VarName: Mis-formed variable name $(expr)!")

test/compiler.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ Random.seed!(129)
88

99
priors = 0 # See "new grammar" test.
1010

11+
macro custom(expr)
12+
(Meta.isexpr(expr, :call, 3) && expr.args[1] === :~) ||
13+
error("incorrect macro usage")
14+
quote
15+
$(esc(expr.args[2])) = 0.0
16+
end
17+
end
18+
1119
@testset "compiler.jl" begin
1220
@testset "assume" begin
1321
@model test_assume() = begin
@@ -550,4 +558,12 @@ priors = 0 # See "new grammar" test.
550558
@test haskey(vi3.metadata, :y)
551559
@test vi3.metadata.y.vns[1] == VarName(:y, ((1,),))
552560
end
561+
@testset "custom tilde" begin
562+
@model demo() = begin
563+
$(@custom m ~ Normal())
564+
return m
565+
end
566+
model = demo()
567+
@test all(iszero(model()) for _ in 1:1000)
568+
end
553569
end

test/utils.jl

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,8 @@
11
using DynamicPPL
2-
using DynamicPPL: apply_dotted, getargs_dottilde, getargs_tilde
2+
using DynamicPPL: getargs_dottilde, getargs_tilde
33

44
using Test
55

6-
@testset "apply_dotted" begin
7-
# Some things that are not expressions.
8-
@test apply_dotted(:x) === :x
9-
@test apply_dotted(1.0) === 1.0
10-
@test apply_dotted([1.0, 2.0, 4.0]) == [1.0, 2.0, 4.0]
11-
12-
# Some expressions.
13-
@test apply_dotted(:(x ~ Normal(μ, σ))) == :(x ~ Normal(μ, σ))
14-
@test apply_dotted(:((.~)(x, Normal(μ, σ)))) == :((.~)(x, Normal(μ, σ)))
15-
@test apply_dotted(:((~).(x, Normal(μ, σ)))) == :((~).(x, Normal(μ, σ)))
16-
@test apply_dotted(:(@. x ~ Normal(μ, σ))) == :((~).(x, Normal.(μ, σ)))
17-
@test apply_dotted(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) ==
18-
:((~).(x, Normal.(μ, sqrt(v))))
19-
@test apply_dotted(:(@~ Normal.(μ, σ))) == :(@~ Normal.(μ, σ))
20-
end
21-
226
@testset "getargs_dottilde" begin
237
# Some things that are not expressions.
248
@test getargs_dottilde(:x) === nothing

0 commit comments

Comments
 (0)