Skip to content

Commit 999bb05

Browse files
authored
Merge pull request #110 from TuringLang/warning
Allow to disable warnings
2 parents c3d430c + 2ae1a8f commit 999bb05

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

src/compiler.jl

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,31 +42,34 @@ isassumption(expr) = :(false)
4242
#################
4343

4444
"""
45-
@model(body)
45+
@model(expr[, warn = true])
4646
4747
Macro to specify a probabilistic model.
4848
49-
Example:
49+
If `warn` is `true`, a warning is displayed if internal variable names are used in the model
50+
definition.
51+
52+
# Example
5053
5154
Model definition:
5255
5356
```julia
54-
@model model_generator(x = default_x, y) = begin
57+
@model function model_generator(x = default_x, y)
5558
...
5659
end
5760
```
5861
5962
To generate a `Model`, call `model_generator(x_value)`.
6063
"""
61-
macro model(expr)
62-
esc(model(expr))
64+
macro model(expr, warn=true)
65+
esc(model(expr, warn))
6366
end
6467

65-
function model(expr)
68+
function model(expr, warn)
6669
modelinfo = build_model_info(expr)
6770

6871
# Generate main body
69-
modelinfo[:main_body] = generate_mainbody(modelinfo[:main_body], modelinfo[:args])
72+
modelinfo[:main_body] = generate_mainbody(modelinfo[:main_body], modelinfo[:args], warn)
7073

7174
return build_output(modelinfo)
7275
end
@@ -162,50 +165,53 @@ function build_model_info(input_expr)
162165
end
163166

164167
"""
165-
generate_mainbody(expr, args)
168+
generate_mainbody(expr, args, warn)
166169
167170
Generate the body of the main evaluation function from expression `expr` and arguments
168171
`args`.
172+
173+
If `warn` is true, a warning is displayed if internal variables are used in the model
174+
definition.
169175
"""
170-
generate_mainbody(expr, args) = generate_mainbody!(Symbol[], expr, args)
176+
generate_mainbody(expr, args, warn) = generate_mainbody!(Symbol[], expr, args, warn)
171177

172-
generate_mainbody!(found, x, args) = x
173-
function generate_mainbody!(found, sym::Symbol, args)
174-
if sym in INTERNALNAMES && sym found
178+
generate_mainbody!(found, x, args, warn) = x
179+
function generate_mainbody!(found, sym::Symbol, args, warn)
180+
if warn && sym in INTERNALNAMES && sym found
175181
@warn "you are using the internal variable `$(sym)`"
176182
push!(found, sym)
177183
end
178184
return sym
179185
end
180-
function generate_mainbody!(found, expr::Expr, args)
186+
function generate_mainbody!(found, expr::Expr, args, warn)
181187
# Do not touch interpolated expressions
182188
expr.head === :$ && return expr.args[1]
183189

184190
# Apply the `@.` macro first.
185191
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
186192
expr.args[1] === Symbol("@__dot__")
187-
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args)
193+
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn)
188194
end
189195

190196
# Modify dotted tilde operators.
191197
args_dottilde = getargs_dottilde(expr)
192198
if args_dottilde !== nothing
193199
L, R = args_dottilde
194-
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody!(found, L, args),
195-
generate_mainbody!(found, R, args),
196-
args))
200+
return generate_dot_tilde(generate_mainbody!(found, L, args, warn),
201+
generate_mainbody!(found, R, args, warn),
202+
args) |> Base.remove_linenums!
197203
end
198204

199205
# Modify tilde operators.
200206
args_tilde = getargs_tilde(expr)
201207
if args_tilde !== nothing
202208
L, R = args_tilde
203-
return Base.remove_linenums!(generate_tilde(generate_mainbody!(found, L, args),
204-
generate_mainbody!(found, R, args),
205-
args))
209+
return generate_tilde(generate_mainbody!(found, L, args, warn),
210+
generate_mainbody!(found, R, args, warn),
211+
args) |> Base.remove_linenums!
206212
end
207213

208-
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args), expr.args)...)
214+
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args, warn), expr.args)...)
209215
end
210216

211217

test/compiler.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,22 @@ end
227227
@test sampler_ === SampleFromPrior()
228228
@test context_ === DefaultContext()
229229

230+
# disable warnings
231+
@model testmodel(x) = begin
232+
x[1] ~  Bernoulli(0.5)
233+
global varinfo_ = _varinfo
234+
global sampler_ = _sampler
235+
global model_ = _model
236+
global context_ = _context
237+
global lp = getlogp(_varinfo)
238+
return x
239+
end false
240+
lpold = lp
241+
model = testmodel([1.0])
242+
varinfo = DynamicPPL.VarInfo(model)
243+
model(varinfo)
244+
@test getlogp(varinfo) == lp == lpold
245+
230246
# test DPPL#61
231247
@model testmodel(z) = begin
232248
m ~ Normal()

0 commit comments

Comments
 (0)