Skip to content

Commit f531f12

Browse files
committed
Don't "compile" inputs of macros (#222)
At the moment we will actually call `generate_mainbody!` on inputs to macros inside the model, e.g. in a model `@mymacro x ~ Normal()` will actually result in code `@mymacro $(generate_mainbody!(:(x ~ Normal())))` (or something, you get the idea). IMO, this shouldn't be done for the following reasons: 1. Breaks with what you'd expect in Julia, IMO, which is that a macro eats the "raw" code. 2. Means that if we want to do stuff like `@reparam` from #220 (and a bunch of others, see #221 for a small list of possibilities), we need touch the compiler rather than just make a small macro that will perform transformations *after* the compiler has done it's job (referring to DynamicPPL compiler here). 3. If the user wants to use a macro on some variables, but they want the actual variable rather than messing around with the sample-statement, they can just separate it into two lines, e.g. `x ~ Normal(); @mymacro ...`. Also, to be completely honest, for the longest time I've just assumed that I'm not even allowed to do `@mymacro x ~ Normal()` and have things work 😅 I bet a lot of people have the same impression by default (though this might of course just not be true:) )
1 parent 2d6ef3f commit f531f12

File tree

3 files changed

+61
-18
lines changed

3 files changed

+61
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.9"
3+
version = "0.10.10"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
6464
macro model(expr, warn=true)
6565
# include `LineNumberNode` with information about the call site in the
6666
# generated function for easier debugging and interpretation of error messages
67-
esc(model(expr, __source__, warn))
67+
esc(model(__module__, __source__, expr, warn))
6868
end
6969

70-
function model(expr, linenumbernode, warn)
70+
function model(mod, linenumbernode, expr, warn)
7171
modelinfo = build_model_info(expr)
7272

7373
# Generate main body
7474
modelinfo[:body] = generate_mainbody(
75-
modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
75+
mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
7676
)
7777

7878
return build_output(modelinfo, linenumbernode)
@@ -155,53 +155,52 @@ function build_model_info(input_expr)
155155
end
156156

157157
"""
158-
generate_mainbody(expr, args, warn)
158+
generate_mainbody(mod, expr, args, warn)
159159
160160
Generate the body of the main evaluation function from expression `expr` and arguments
161161
`args`.
162162
163163
If `warn` is true, a warning is displayed if internal variables are used in the model
164164
definition.
165165
"""
166-
generate_mainbody(expr, args, warn) = generate_mainbody!(Symbol[], expr, args, warn)
166+
generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn)
167167

168-
generate_mainbody!(found, x, args, warn) = x
169-
function generate_mainbody!(found, sym::Symbol, args, warn)
168+
generate_mainbody!(mod, found, x, args, warn) = x
169+
function generate_mainbody!(mod, found, sym::Symbol, args, warn)
170170
if warn && sym in INTERNALNAMES && sym found
171171
@warn "you are using the internal variable `$(sym)`"
172172
push!(found, sym)
173173
end
174174
return sym
175175
end
176-
function generate_mainbody!(found, expr::Expr, args, warn)
176+
function generate_mainbody!(mod, found, expr::Expr, args, warn)
177177
# Do not touch interpolated expressions
178178
expr.head === :$ && return expr.args[1]
179179

180-
# Apply the `@.` macro first.
181-
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
182-
expr.args[1] === Symbol("@__dot__")
183-
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn)
180+
# If it's a macro, we expand it
181+
if Meta.isexpr(expr, :macrocall)
182+
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn)
184183
end
185184

186185
# Modify dotted tilde operators.
187186
args_dottilde = getargs_dottilde(expr)
188187
if args_dottilde !== nothing
189188
L, R = args_dottilde
190-
return generate_dot_tilde(generate_mainbody!(found, L, args, warn),
191-
generate_mainbody!(found, R, args, warn),
189+
return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn),
190+
generate_mainbody!(mod, found, R, args, warn),
192191
args) |> Base.remove_linenums!
193192
end
194193

195194
# Modify tilde operators.
196195
args_tilde = getargs_tilde(expr)
197196
if args_tilde !== nothing
198197
L, R = args_tilde
199-
return generate_tilde(generate_mainbody!(found, L, args, warn),
200-
generate_mainbody!(found, R, args, warn),
198+
return generate_tilde(generate_mainbody!(mod, found, L, args, warn),
199+
generate_mainbody!(mod, found, R, args, warn),
201200
args) |> Base.remove_linenums!
202201
end
203202

204-
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args, warn), expr.args)...)
203+
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...)
205204
end
206205

207206

test/compiler.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,29 @@ macro custom(expr)
66
end
77
end
88

9+
macro mymodel1(ex)
10+
# check if expression was modified by the DynamicPPL "compiler"
11+
if ex == :(y ~ Uniform())
12+
return esc(:(x ~ Normal()))
13+
else
14+
return esc(:(z ~ Exponential()))
15+
end
16+
end
17+
18+
struct MyModelStruct{T}
19+
x::T
20+
end
21+
Base.:~(x, y::MyModelStruct) = y.x
22+
macro mymodel2(ex)
23+
# check if expression was modified by the DynamicPPL "compiler"
24+
if ex == :(y ~ Uniform())
25+
# Just returns 42
26+
return :(4 ~ MyModelStruct(42))
27+
else
28+
return :(return -1)
29+
end
30+
end
31+
932
@testset "compiler.jl" begin
1033
@testset "model macro" begin
1134
@model function testmodel_comp(x, y)
@@ -269,4 +292,25 @@ end
269292
end
270293
@test isempty(VarInfo(demo_with(0.0)))
271294
end
295+
296+
@testset "macros within model" begin
297+
# Macro expansion
298+
@model function demo()
299+
@mymodel1(y ~ Uniform())
300+
end
301+
302+
@test haskey(VarInfo(demo()), @varname(x))
303+
304+
# Interpolation
305+
# Will fail if:
306+
# 1. Compiler expands `y ~ Uniform()` before expanding the macros
307+
# => returns -1.
308+
# 2. `@mymodel` is expanded before entire `@model` has been
309+
# expanded => errors since `MyModelStruct` is not a distribution,
310+
# and hence `tilde_observe` errors.
311+
@model function demo()
312+
$(@mymodel2(y ~ Uniform()))
313+
end
314+
@test demo()() == 42
315+
end
272316
end

0 commit comments

Comments
 (0)