Skip to content

Commit 581be52

Browse files
authored
Merge pull request #71 from TuringLang/internal
Remove `@varinfo`, `@logpdf`, and `@sampler`
2 parents 13a1ae0 + ce44402 commit 581be52

File tree

5 files changed

+77
-119
lines changed

5 files changed

+77
-119
lines changed

src/DynamicPPL.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ export AbstractVarInfo,
5050
ModelGen,
5151
@model,
5252
@varname,
53-
@varinfo,
54-
@logpdf,
55-
@sampler,
5653
# Utilities
5754
vectorize,
5855
reconstruct,

src/compiler.jl

Lines changed: 66 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,30 @@
1-
macro varinfo()
2-
:(throw(_error_msg()))
3-
end
4-
macro logpdf()
5-
:(throw(_error_msg()))
6-
end
7-
macro sampler()
8-
:(throw(_error_msg()))
9-
end
10-
function _error_msg()
11-
return "This macro is only for use in the `@model` macro and not for external use."
12-
end
13-
141
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
152
"Distributions."
163

4+
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo)
5+
176
"""
18-
isassumption(model, expr)
7+
isassumption(expr)
198
209
Return an expression that can be evaluated to check if `expr` is an assumption in the
21-
`model`.
10+
model.
2211
2312
Let `expr` be `:(x[1])`. It is an assumption in the following cases:
24-
1. `x` is not among the input data to the `model`,
25-
2. `x` is among the input data to the `model` but with a value `missing`, or
26-
3. `x` is among the input data to the `model` with a value other than missing,
13+
1. `x` is not among the input data to the model,
14+
2. `x` is among the input data to the model but with a value `missing`, or
15+
3. `x` is among the input data to the model with a value other than missing,
2716
but `x[1] === missing`.
2817
2918
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
3019
"""
31-
function isassumption(model, expr::Union{Symbol, Expr})
20+
function isassumption(expr::Union{Symbol, Expr})
3221
vn = gensym(:vn)
3322

3423
return quote
3524
let $vn = $(varname(expr))
3625
# This branch should compile nicely in all cases except for partial missing data
3726
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
38-
if !$(DynamicPPL.inargnames)($vn, $model) || $(DynamicPPL.inmissings)($vn, $model)
27+
if !$(DynamicPPL.inargnames)($vn, _model) || $(DynamicPPL.inmissings)($vn, _model)
3928
true
4029
else
4130
# Evaluate the LHS
@@ -46,7 +35,7 @@ function isassumption(model, expr::Union{Symbol, Expr})
4635
end
4736

4837
# failsafe: a literal is never an assumption
49-
isassumption(model, expr) = :(false)
38+
isassumption(expr) = :(false)
5039

5140
#################
5241
# Main Compiler #
@@ -77,7 +66,7 @@ function model(expr)
7766
modelinfo = build_model_info(expr)
7867

7968
# Generate main body
80-
modelinfo[:main_body] = generate_mainbody(modelinfo)
69+
modelinfo[:main_body] = generate_mainbody(modelinfo[:main_body], modelinfo[:args])
8170

8271
return build_output(modelinfo)
8372
end
@@ -166,84 +155,69 @@ function build_model_info(input_expr)
166155
:args_nt => args_nt,
167156
:defaults_nt => defaults_nt,
168157
:args => args,
169-
:whereparams => modeldef[:whereparams],
170-
:main_body_names => Dict(
171-
:ctx => gensym(:ctx),
172-
:vi => gensym(:vi),
173-
:sampler => gensym(:sampler),
174-
:model => gensym(:model)
175-
)
158+
:whereparams => modeldef[:whereparams]
176159
)
177160

178161
return model_info
179162
end
180163

181164
"""
182-
generate_mainbody([expr, ]modelinfo)
165+
generate_mainbody(expr, args)
183166
184-
Generate the body of the main evaluation function.
167+
Generate the body of the main evaluation function from expression `expr` and arguments
168+
`args`.
185169
"""
186-
generate_mainbody(modelinfo) = generate_mainbody(modelinfo[:main_body], modelinfo)
170+
generate_mainbody(expr, args) = generate_mainbody!(Symbol[], expr, args)
187171

188-
generate_mainbody(x, modelinfo) = x
189-
function generate_mainbody(expr::Expr, modelinfo)
172+
generate_mainbody!(found, x, args) = x
173+
function generate_mainbody!(found, sym::Symbol, args)
174+
if sym in INTERNALNAMES && sym found
175+
@warn "you are using the internal variable `$(sym)`"
176+
push!(found, sym)
177+
end
178+
return sym
179+
end
180+
function generate_mainbody!(found, expr::Expr, args)
190181
# Do not touch interpolated expressions
191182
expr.head === :$ && return expr.args[1]
192183

193184
# Apply the `@.` macro first.
194185
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
195186
expr.args[1] === Symbol("@__dot__")
196-
return generate_mainbody(Base.Broadcast.__dot__(expr.args[end]), modelinfo)
197-
end
198-
199-
# Modify macro calls.
200-
if Meta.isexpr(expr, :macrocall) && !isempty(expr.args)
201-
name = expr.args[1]
202-
if name === Symbol("@varinfo")
203-
return modelinfo[:main_body_names][:vi]
204-
elseif name === Symbol("@logpdf")
205-
return :($(modelinfo[:main_body_names][:vi]).logp[])
206-
elseif name === Symbol("@sampler")
207-
return :($(modelinfo[:main_body_names][:sampler]))
208-
end
187+
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args)
209188
end
210189

211190
# Modify dotted tilde operators.
212191
args_dottilde = getargs_dottilde(expr)
213192
if args_dottilde !== nothing
214193
L, R = args_dottilde
215-
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody(L, modelinfo),
216-
generate_mainbody(R, modelinfo),
217-
modelinfo))
194+
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody!(found, L, args),
195+
generate_mainbody!(found, R, args),
196+
args))
218197
end
219198

220199
# Modify tilde operators.
221200
args_tilde = getargs_tilde(expr)
222201
if args_tilde !== nothing
223202
L, R = args_tilde
224-
return Base.remove_linenums!(generate_tilde(generate_mainbody(L, modelinfo),
225-
generate_mainbody(R, modelinfo),
226-
modelinfo))
203+
return Base.remove_linenums!(generate_tilde(generate_mainbody!(found, L, args),
204+
generate_mainbody!(found, R, args),
205+
args))
227206
end
228207

229-
return Expr(expr.head, map(x -> generate_mainbody(x, modelinfo), expr.args)...)
208+
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args), expr.args)...)
230209
end
231210

232211
# """ Unbreak code highlighting in Emacs julia-mode
233212

234213

235214
"""
236-
generate_tilde(left, right, model_info)
215+
generate_tilde(left, right, args)
237216
238-
The `tilde` function generates `observe` expression for data variables and `assume`
239-
expressions for parameter variables, updating `model_info` in the process.
217+
Generate an `observe` expression for data variables and `assume` expression for parameter
218+
variables.
240219
"""
241-
function generate_tilde(left, right, model_info)
242-
model = model_info[:main_body_names][:model]
243-
vi = model_info[:main_body_names][:vi]
244-
ctx = model_info[:main_body_names][:ctx]
245-
sampler = model_info[:main_body_names][:sampler]
246-
220+
function generate_tilde(left, right, args)
247221
@gensym tmpright tmpleft
248222
top = [:($tmpright = $right),
249223
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
@@ -254,26 +228,26 @@ function generate_tilde(left, right, model_info)
254228
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
255229

256230
assumption = [
257-
:($out = $(DynamicPPL.tilde_assume)($ctx, $sampler, $tmpright, $vn, $inds,
258-
$vi)),
259-
:($(DynamicPPL.acclogp!)($vi, $out[2])),
231+
:($out = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
232+
_varinfo)),
233+
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
260234
:($left = $out[1])
261235
]
262236

263237
# It can only be an observation if the LHS is an argument of the model
264-
if vsym(left) in model_info[:args]
238+
if vsym(left) in args
265239
@gensym isassumption
266240
return quote
267241
$(top...)
268-
$isassumption = $(DynamicPPL.isassumption(model, left))
242+
$isassumption = $(DynamicPPL.isassumption(left))
269243
if $isassumption
270244
$(assumption...)
271245
else
272246
$tmpleft = $left
273247
$(DynamicPPL.acclogp!)(
274-
$vi,
275-
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vn,
276-
$inds, $vi)
248+
_varinfo,
249+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
250+
$vn, $inds, _varinfo)
277251
)
278252
$tmpleft
279253
end
@@ -291,26 +265,19 @@ function generate_tilde(left, right, model_info)
291265
$(top...)
292266
$tmpleft = $left
293267
$(DynamicPPL.acclogp!)(
294-
$vi,
295-
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
268+
_varinfo,
269+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft, _varinfo)
296270
)
297271
$tmpleft
298272
end
299273
end
300274

301275
"""
302-
generate_dot_tilde(left, right, model_info)
276+
generate_dot_tilde(left, right, args)
303277
304-
This function returns the expression that replaces `left .~ right` in the model body. If
305-
`preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
306-
will be run.
278+
Generate the expression that replaces `left .~ right` in the model body.
307279
"""
308-
function generate_dot_tilde(left, right, model_info)
309-
model = model_info[:main_body_names][:model]
310-
vi = model_info[:main_body_names][:vi]
311-
ctx = model_info[:main_body_names][:ctx]
312-
sampler = model_info[:main_body_names][:sampler]
313-
280+
function generate_dot_tilde(left, right, args)
314281
@gensym tmpright tmpleft
315282
top = [:($tmpright = $right),
316283
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
@@ -321,26 +288,26 @@ function generate_dot_tilde(left, right, model_info)
321288
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
322289

323290
assumption = [
324-
:($out = $(DynamicPPL.dot_tilde_assume)($ctx, $sampler, $tmpright, $left,
325-
$vn, $inds, $vi)),
326-
:($(DynamicPPL.acclogp!)($vi, $out[2])),
291+
:($out = $(DynamicPPL.dot_tilde_assume)(_context, _sampler, $tmpright, $left,
292+
$vn, $inds, _varinfo)),
293+
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
327294
:($left .= $out[1])
328295
]
329296

330297
# It can only be an observation if the LHS is an argument of the model
331-
if vsym(left) in model_info[:args]
298+
if vsym(left) in args
332299
@gensym isassumption
333300
return quote
334301
$(top...)
335-
$isassumption = $(DynamicPPL.isassumption(model, left))
302+
$isassumption = $(DynamicPPL.isassumption(left))
336303
if $isassumption
337304
$(assumption...)
338305
else
339306
$tmpleft = $left
340307
$(DynamicPPL.acclogp!)(
341-
$vi,
342-
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft,
343-
$vn, $inds, $vi)
308+
_varinfo,
309+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright,
310+
$tmpleft, $vn, $inds, _varinfo)
344311
)
345312
$tmpleft
346313
end
@@ -358,8 +325,9 @@ function generate_dot_tilde(left, right, model_info)
358325
$(top...)
359326
$tmpleft = $left
360327
$(DynamicPPL.acclogp!)(
361-
$vi,
362-
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
328+
_varinfo,
329+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
330+
_varinfo)
363331
)
364332
$tmpleft
365333
end
@@ -376,13 +344,6 @@ hasmissing(T::Type) = false
376344
Builds the output expression.
377345
"""
378346
function build_output(model_info)
379-
# Construct user-facing function
380-
main_body_names = model_info[:main_body_names]
381-
ctx = main_body_names[:ctx]
382-
vi = main_body_names[:vi]
383-
model = main_body_names[:model]
384-
sampler = main_body_names[:sampler]
385-
386347
# Arguments with default values
387348
args = model_info[:args]
388349
# Argument symbols without default values
@@ -402,7 +363,7 @@ function build_output(model_info)
402363
unwrap_data_expr = Expr(:block)
403364
for var in arg_syms
404365
push!(unwrap_data_expr.args,
405-
:($var = $(DynamicPPL.matchingvalue)($sampler, $vi, $(model).args.$var)))
366+
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
406367
end
407368

408369
@gensym(evaluator, generator)
@@ -411,10 +372,10 @@ function build_output(model_info)
411372

412373
return quote
413374
function $evaluator(
414-
$model::$(DynamicPPL.Model),
415-
$vi::$(DynamicPPL.VarInfo),
416-
$sampler::$(DynamicPPL.AbstractSampler),
417-
$ctx::$(DynamicPPL.AbstractContext),
375+
_model::$(DynamicPPL.Model),
376+
_varinfo::$(DynamicPPL.VarInfo),
377+
_sampler::$(DynamicPPL.AbstractSampler),
378+
_context::$(DynamicPPL.AbstractContext),
418379
)
419380
$unwrap_data_expr
420381
$main_body

test/Turing/Turing.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ end
6666
# Turing essentials - modelling macros and inference algorithms
6767
export @model, # modelling
6868
@varname,
69-
@varinfo,
70-
@logpdf,
71-
@sampler,
7269
DynamicPPL,
7370

7471
MH, # classic sampling

test/Turing/core/Core.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ export @model,
5252
ADBACKEND,
5353
setchunksize,
5454
verifygrad,
55-
@varinfo,
56-
@logpdf,
57-
@sampler,
5855
@logprob_str,
5956
@prob_str
6057

test/compiler.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,18 +208,24 @@ end
208208
end
209209
@test_throws MethodError testmodel(missing)()
210210

211-
# Test @varinfo() and @logpdf()
211+
# Test use of internal names
212212
@model testmodel(x) = begin
213213
x[1] ~ Bernoulli(0.5)
214-
global _varinfo = @varinfo()
215-
global lp = @logpdf()
214+
global varinfo_ = _varinfo
215+
global sampler_ = _sampler
216+
global model_ = _model
217+
global context_ = _context
218+
global lp = getlogp(_varinfo)
216219
return x
217220
end
218221
model = testmodel([1.0])
219222
varinfo = DynamicPPL.VarInfo(model)
220223
model(varinfo)
221224
@test getlogp(varinfo) == lp
222-
@test varinfo === _varinfo
225+
@test varinfo_ === varinfo
226+
@test model_ === model
227+
@test sampler_ === SampleFromPrior()
228+
@test context_ === DefaultContext()
223229

224230
# test DPPL#61
225231
@model testmodel(z) = begin
@@ -234,7 +240,7 @@ end
234240
function makemodel(p)
235241
@model testmodel(x) = begin
236242
x[1] ~ Bernoulli(p)
237-
global lp = @logpdf()
243+
global lp = getlogp(_varinfo)
238244
return x
239245
end
240246
return testmodel

0 commit comments

Comments
 (0)