Skip to content

Commit 7b62284

Browse files
authored
Merge branch 'master' into function_signature
2 parents fd86733 + 6b4be6b commit 7b62284

File tree

3 files changed

+44
-22
lines changed

3 files changed

+44
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
33
authors = ["mohamed82008 <[email protected]>"]
4-
version = "0.7.2"
4+
version = "0.7.3"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,31 +42,34 @@ isassumption(expr) = :(false)
4242
#################
4343

4444
"""
45-
@model(body)
45+
@model(expr[, warn = true])
4646
4747
Macro to specify a probabilistic model.
4848
49-
Example:
49+
If `warn` is `true`, a warning is displayed if internal variable names are used in the model
50+
definition.
51+
52+
# Example
5053
5154
Model definition:
5255
5356
```julia
54-
@model model_generator(x = default_x, y) = begin
57+
@model function model_generator(x = default_x, y)
5558
...
5659
end
5760
```
5861
5962
To generate a `Model`, call `model_generator(x_value)`.
6063
"""
61-
macro model(expr)
62-
esc(model(expr))
64+
macro model(expr, warn=true)
65+
esc(model(expr, warn))
6366
end
6467

65-
function model(expr)
68+
function model(expr, warn)
6669
modelinfo = build_model_info(expr)
6770

6871
# Generate main body
69-
modelinfo[:modelbody] = generate_mainbody(modelinfo[:body], modelinfo[:modelargs])
72+
modelinfo[:modelbody] = generate_mainbody(modelinfo[:body], modelinfo[:modelargs], warn)
7073

7174
return build_output(modelinfo)
7275
end
@@ -161,50 +164,53 @@ function build_model_info(input_expr)
161164
end
162165

163166
"""
164-
generate_mainbody(expr, args)
167+
generate_mainbody(expr, args, warn)
165168
166169
Generate the body of the main evaluation function from expression `expr` and arguments
167170
`args`.
171+
172+
If `warn` is true, a warning is displayed if internal variables are used in the model
173+
definition.
168174
"""
169-
generate_mainbody(expr, args) = generate_mainbody!(Symbol[], expr, args)
175+
generate_mainbody(expr, args, warn) = generate_mainbody!(Symbol[], expr, args, warn)
170176

171-
generate_mainbody!(found, x, args) = x
172-
function generate_mainbody!(found, sym::Symbol, args)
173-
if sym in INTERNALNAMES && sym found
177+
generate_mainbody!(found, x, args, warn) = x
178+
function generate_mainbody!(found, sym::Symbol, args, warn)
179+
if warn && sym in INTERNALNAMES && sym found
174180
@warn "you are using the internal variable `$(sym)`"
175181
push!(found, sym)
176182
end
177183
return sym
178184
end
179-
function generate_mainbody!(found, expr::Expr, args)
185+
function generate_mainbody!(found, expr::Expr, args, warn)
180186
# Do not touch interpolated expressions
181187
expr.head === :$ && return expr.args[1]
182188

183189
# Apply the `@.` macro first.
184190
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
185191
expr.args[1] === Symbol("@__dot__")
186-
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args)
192+
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn)
187193
end
188194

189195
# Modify dotted tilde operators.
190196
args_dottilde = getargs_dottilde(expr)
191197
if args_dottilde !== nothing
192198
L, R = args_dottilde
193-
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody!(found, L, args),
194-
generate_mainbody!(found, R, args),
195-
args))
199+
return generate_dot_tilde(generate_mainbody!(found, L, args, warn),
200+
generate_mainbody!(found, R, args, warn),
201+
args) |> Base.remove_linenums!
196202
end
197203

198204
# Modify tilde operators.
199205
args_tilde = getargs_tilde(expr)
200206
if args_tilde !== nothing
201207
L, R = args_tilde
202-
return Base.remove_linenums!(generate_tilde(generate_mainbody!(found, L, args),
203-
generate_mainbody!(found, R, args),
204-
args))
208+
return generate_tilde(generate_mainbody!(found, L, args, warn),
209+
generate_mainbody!(found, R, args, warn),
210+
args) |> Base.remove_linenums!
205211
end
206212

207-
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args), expr.args)...)
213+
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args, warn), expr.args)...)
208214
end
209215

210216

test/compiler.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,22 @@ end
242242
@test sampler_ === SampleFromPrior()
243243
@test context_ === DefaultContext()
244244

245+
# disable warnings
246+
@model testmodel(x) = begin
247+
x[1] ~  Bernoulli(0.5)
248+
global varinfo_ = _varinfo
249+
global sampler_ = _sampler
250+
global model_ = _model
251+
global context_ = _context
252+
global lp = getlogp(_varinfo)
253+
return x
254+
end false
255+
lpold = lp
256+
model = testmodel([1.0])
257+
varinfo = DynamicPPL.VarInfo(model)
258+
model(varinfo)
259+
@test getlogp(varinfo) == lp == lpold
260+
245261
# test DPPL#61
246262
@model testmodel(z) = begin
247263
m ~ Normal()

0 commit comments

Comments
 (0)