Skip to content

Commit ce44402

Browse files
committed
Remove @varinfo, @logpdf, and @sampler
1 parent 9be5881 commit ce44402

File tree

5 files changed

+77
-119
lines changed

5 files changed

+77
-119
lines changed

src/DynamicPPL.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ export AbstractVarInfo,
5050
ModelGen,
5151
@model,
5252
@varname,
53-
@varinfo,
54-
@logpdf,
55-
@sampler,
5653
# Utilities
5754
vectorize,
5855
reconstruct,

src/compiler.jl

Lines changed: 66 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,30 @@
1-
macro varinfo()
2-
:(throw(_error_msg()))
3-
end
4-
macro logpdf()
5-
:(throw(_error_msg()))
6-
end
7-
macro sampler()
8-
:(throw(_error_msg()))
9-
end
10-
function _error_msg()
11-
return "This macro is only for use in the `@model` macro and not for external use."
12-
end
13-
141
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
152
"Distributions."
163

4+
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo)
5+
176
"""
18-
isassumption(model, expr)
7+
isassumption(expr)
198
209
Return an expression that can be evaluated to check if `expr` is an assumption in the
21-
`model`.
10+
model.
2211
2312
Let `expr` be `:(x[1])`. It is an assumption in the following cases:
24-
1. `x` is not among the input data to the `model`,
25-
2. `x` is among the input data to the `model` but with a value `missing`, or
26-
3. `x` is among the input data to the `model` with a value other than missing,
13+
1. `x` is not among the input data to the model,
14+
2. `x` is among the input data to the model but with a value `missing`, or
15+
3. `x` is among the input data to the model with a value other than missing,
2716
but `x[1] === missing`.
2817
2918
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
3019
"""
31-
function isassumption(model, expr::Union{Symbol, Expr})
20+
function isassumption(expr::Union{Symbol, Expr})
3221
vn = gensym(:vn)
3322

3423
return quote
3524
let $vn = $(varname(expr))
3625
# This branch should compile nicely in all cases except for partial missing data
3726
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
38-
if !$(DynamicPPL.inargnames)($vn, $model) || $(DynamicPPL.inmissings)($vn, $model)
27+
if !$(DynamicPPL.inargnames)($vn, _model) || $(DynamicPPL.inmissings)($vn, _model)
3928
true
4029
else
4130
# Evaluate the LHS
@@ -46,7 +35,7 @@ function isassumption(model, expr::Union{Symbol, Expr})
4635
end
4736

4837
# failsafe: a literal is never an assumption
49-
isassumption(model, expr) = :(false)
38+
isassumption(expr) = :(false)
5039

5140
#################
5241
# Main Compiler #
@@ -77,7 +66,7 @@ function model(expr)
7766
modelinfo = build_model_info(expr)
7867

7968
# Generate main body
80-
modelinfo[:main_body] = generate_mainbody(modelinfo)
69+
modelinfo[:main_body] = generate_mainbody(modelinfo[:main_body], modelinfo[:args])
8170

8271
return build_output(modelinfo)
8372
end
@@ -166,67 +155,57 @@ function build_model_info(input_expr)
166155
:args_nt => args_nt,
167156
:defaults_nt => defaults_nt,
168157
:args => args,
169-
:whereparams => modeldef[:whereparams],
170-
:main_body_names => Dict(
171-
:ctx => gensym(:ctx),
172-
:vi => gensym(:vi),
173-
:sampler => gensym(:sampler),
174-
:model => gensym(:model)
175-
)
158+
:whereparams => modeldef[:whereparams]
176159
)
177160

178161
return model_info
179162
end
180163

181164
"""
182-
generate_mainbody([expr, ]modelinfo)
165+
generate_mainbody(expr, args)
183166
184-
Generate the body of the main evaluation function.
167+
Generate the body of the main evaluation function from expression `expr` and arguments
168+
`args`.
185169
"""
186-
generate_mainbody(modelinfo) = generate_mainbody(modelinfo[:main_body], modelinfo)
170+
generate_mainbody(expr, args) = generate_mainbody!(Symbol[], expr, args)
187171

188-
generate_mainbody(x, modelinfo) = x
189-
function generate_mainbody(expr::Expr, modelinfo)
172+
generate_mainbody!(found, x, args) = x
173+
function generate_mainbody!(found, sym::Symbol, args)
174+
if sym in INTERNALNAMES && sym found
175+
@warn "you are using the internal variable `$(sym)`"
176+
push!(found, sym)
177+
end
178+
return sym
179+
end
180+
function generate_mainbody!(found, expr::Expr, args)
190181
# Do not touch interpolated expressions
191182
expr.head === :$ && return expr.args[1]
192183

193184
# Apply the `@.` macro first.
194185
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
195186
expr.args[1] === Symbol("@__dot__")
196-
return generate_mainbody(Base.Broadcast.__dot__(expr.args[end]), modelinfo)
197-
end
198-
199-
# Modify macro calls.
200-
if Meta.isexpr(expr, :macrocall) && !isempty(expr.args)
201-
name = expr.args[1]
202-
if name === Symbol("@varinfo")
203-
return modelinfo[:main_body_names][:vi]
204-
elseif name === Symbol("@logpdf")
205-
return :($(modelinfo[:main_body_names][:vi]).logp[])
206-
elseif name === Symbol("@sampler")
207-
return :($(modelinfo[:main_body_names][:sampler]))
208-
end
187+
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args)
209188
end
210189

211190
# Modify dotted tilde operators.
212191
args_dottilde = getargs_dottilde(expr)
213192
if args_dottilde !== nothing
214193
L, R = args_dottilde
215-
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody(L, modelinfo),
216-
generate_mainbody(R, modelinfo),
217-
modelinfo))
194+
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody!(found, L, args),
195+
generate_mainbody!(found, R, args),
196+
args))
218197
end
219198

220199
# Modify tilde operators.
221200
args_tilde = getargs_tilde(expr)
222201
if args_tilde !== nothing
223202
L, R = args_tilde
224-
return Base.remove_linenums!(generate_tilde(generate_mainbody(L, modelinfo),
225-
generate_mainbody(R, modelinfo),
226-
modelinfo))
203+
return Base.remove_linenums!(generate_tilde(generate_mainbody!(found, L, args),
204+
generate_mainbody!(found, R, args),
205+
args))
227206
end
228207

229-
return Expr(expr.head, map(x -> generate_mainbody(x, modelinfo), expr.args)...)
208+
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args), expr.args)...)
230209
end
231210

232211
"""
@@ -268,17 +247,12 @@ end
268247

269248

270249
"""
271-
generate_tilde(left, right, model_info)
250+
generate_tilde(left, right, args)
272251
273-
The `tilde` function generates `observe` expression for data variables and `assume`
274-
expressions for parameter variables, updating `model_info` in the process.
252+
Generate an `observe` expression for data variables and `assume` expression for parameter
253+
variables.
275254
"""
276-
function generate_tilde(left, right, model_info)
277-
model = model_info[:main_body_names][:model]
278-
vi = model_info[:main_body_names][:vi]
279-
ctx = model_info[:main_body_names][:ctx]
280-
sampler = model_info[:main_body_names][:sampler]
281-
255+
function generate_tilde(left, right, args)
282256
@gensym tmpright tmpleft
283257
top = [:($tmpright = $right),
284258
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
@@ -289,26 +263,26 @@ function generate_tilde(left, right, model_info)
289263
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
290264

291265
assumption = [
292-
:($out = $(DynamicPPL.tilde_assume)($ctx, $sampler, $tmpright, $vn, $inds,
293-
$vi)),
294-
:($(DynamicPPL.acclogp!)($vi, $out[2])),
266+
:($out = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
267+
_varinfo)),
268+
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
295269
:($left = $out[1])
296270
]
297271

298272
# It can only be an observation if the LHS is an argument of the model
299-
if vsym(left) in model_info[:args]
273+
if vsym(left) in args
300274
@gensym isassumption
301275
return quote
302276
$(top...)
303-
$isassumption = $(DynamicPPL.isassumption(model, left))
277+
$isassumption = $(DynamicPPL.isassumption(left))
304278
if $isassumption
305279
$(assumption...)
306280
else
307281
$tmpleft = $left
308282
$(DynamicPPL.acclogp!)(
309-
$vi,
310-
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vn,
311-
$inds, $vi)
283+
_varinfo,
284+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
285+
$vn, $inds, _varinfo)
312286
)
313287
$tmpleft
314288
end
@@ -326,26 +300,19 @@ function generate_tilde(left, right, model_info)
326300
$(top...)
327301
$tmpleft = $left
328302
$(DynamicPPL.acclogp!)(
329-
$vi,
330-
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
303+
_varinfo,
304+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft, _varinfo)
331305
)
332306
$tmpleft
333307
end
334308
end
335309

336310
"""
337-
generate_dot_tilde(left, right, model_info)
311+
generate_dot_tilde(left, right, args)
338312
339-
This function returns the expression that replaces `left .~ right` in the model body. If
340-
`preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
341-
will be run.
313+
Generate the expression that replaces `left .~ right` in the model body.
342314
"""
343-
function generate_dot_tilde(left, right, model_info)
344-
model = model_info[:main_body_names][:model]
345-
vi = model_info[:main_body_names][:vi]
346-
ctx = model_info[:main_body_names][:ctx]
347-
sampler = model_info[:main_body_names][:sampler]
348-
315+
function generate_dot_tilde(left, right, args)
349316
@gensym tmpright tmpleft
350317
top = [:($tmpright = $right),
351318
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
@@ -356,26 +323,26 @@ function generate_dot_tilde(left, right, model_info)
356323
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
357324

358325
assumption = [
359-
:($out = $(DynamicPPL.dot_tilde_assume)($ctx, $sampler, $tmpright, $left,
360-
$vn, $inds, $vi)),
361-
:($(DynamicPPL.acclogp!)($vi, $out[2])),
326+
:($out = $(DynamicPPL.dot_tilde_assume)(_context, _sampler, $tmpright, $left,
327+
$vn, $inds, _varinfo)),
328+
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
362329
:($left .= $out[1])
363330
]
364331

365332
# It can only be an observation if the LHS is an argument of the model
366-
if vsym(left) in model_info[:args]
333+
if vsym(left) in args
367334
@gensym isassumption
368335
return quote
369336
$(top...)
370-
$isassumption = $(DynamicPPL.isassumption(model, left))
337+
$isassumption = $(DynamicPPL.isassumption(left))
371338
if $isassumption
372339
$(assumption...)
373340
else
374341
$tmpleft = $left
375342
$(DynamicPPL.acclogp!)(
376-
$vi,
377-
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft,
378-
$vn, $inds, $vi)
343+
_varinfo,
344+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright,
345+
$tmpleft, $vn, $inds, _varinfo)
379346
)
380347
$tmpleft
381348
end
@@ -393,8 +360,9 @@ function generate_dot_tilde(left, right, model_info)
393360
$(top...)
394361
$tmpleft = $left
395362
$(DynamicPPL.acclogp!)(
396-
$vi,
397-
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
363+
_varinfo,
364+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
365+
_varinfo)
398366
)
399367
$tmpleft
400368
end
@@ -411,13 +379,6 @@ hasmissing(T::Type) = false
411379
Builds the output expression.
412380
"""
413381
function build_output(model_info)
414-
# Construct user-facing function
415-
main_body_names = model_info[:main_body_names]
416-
ctx = main_body_names[:ctx]
417-
vi = main_body_names[:vi]
418-
model = main_body_names[:model]
419-
sampler = main_body_names[:sampler]
420-
421382
# Arguments with default values
422383
args = model_info[:args]
423384
# Argument symbols without default values
@@ -437,7 +398,7 @@ function build_output(model_info)
437398
unwrap_data_expr = Expr(:block)
438399
for var in arg_syms
439400
push!(unwrap_data_expr.args,
440-
:($var = $(DynamicPPL.matchingvalue)($sampler, $vi, $(model).args.$var)))
401+
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
441402
end
442403

443404
@gensym(evaluator, generator)
@@ -446,10 +407,10 @@ function build_output(model_info)
446407

447408
return quote
448409
function $evaluator(
449-
$model::$(DynamicPPL.Model),
450-
$vi::$(DynamicPPL.VarInfo),
451-
$sampler::$(DynamicPPL.AbstractSampler),
452-
$ctx::$(DynamicPPL.AbstractContext),
410+
_model::$(DynamicPPL.Model),
411+
_varinfo::$(DynamicPPL.VarInfo),
412+
_sampler::$(DynamicPPL.AbstractSampler),
413+
_context::$(DynamicPPL.AbstractContext),
453414
)
454415
$unwrap_data_expr
455416
$main_body

test/Turing/Turing.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ end
6666
# Turing essentials - modelling macros and inference algorithms
6767
export @model, # modelling
6868
@varname,
69-
@varinfo,
70-
@logpdf,
71-
@sampler,
7269
DynamicPPL,
7370

7471
MH, # classic sampling

test/Turing/core/Core.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ export @model,
5252
ADBACKEND,
5353
setchunksize,
5454
verifygrad,
55-
@varinfo,
56-
@logpdf,
57-
@sampler,
5855
@logprob_str,
5956
@prob_str
6057

0 commit comments

Comments
 (0)