1
- macro varinfo ()
2
- :(throw (_error_msg ()))
3
- end
4
- macro logpdf ()
5
- :(throw (_error_msg ()))
6
- end
7
- macro sampler ()
8
- :(throw (_error_msg ()))
9
- end
10
- function _error_msg ()
11
- return " This macro is only for use in the `@model` macro and not for external use."
12
- end
13
-
14
1
const DISTMSG = " Right-hand side of a ~ must be subtype of Distribution or a vector of " *
15
2
" Distributions."
16
3
4
+ const INTERNALNAMES = (:_model , :_sampler , :_context , :_varinfo )
5
+
17
6
"""
18
- isassumption(model, expr)
7
+ isassumption(expr)
19
8
20
9
Return an expression that can be evaluated to check if `expr` is an assumption in the
21
- ` model` .
10
+ model.
22
11
23
12
Let `expr` be `:(x[1])`. It is an assumption in the following cases:
24
- 1. `x` is not among the input data to the ` model` ,
25
- 2. `x` is among the input data to the ` model` but with a value `missing`, or
26
- 3. `x` is among the input data to the ` model` with a value other than missing,
13
+ 1. `x` is not among the input data to the model,
14
+ 2. `x` is among the input data to the model but with a value `missing`, or
15
+ 3. `x` is among the input data to the model with a value other than missing,
27
16
but `x[1] === missing`.
28
17
29
18
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
30
19
"""
31
- function isassumption (model, expr:: Union{Symbol, Expr} )
20
+ function isassumption (expr:: Union{Symbol, Expr} )
32
21
vn = gensym (:vn )
33
22
34
23
return quote
35
24
let $ vn = $ (varname (expr))
36
25
# This branch should compile nicely in all cases except for partial missing data
37
26
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
38
- if ! $ (DynamicPPL. inargnames)($ vn, $ model ) || $ (DynamicPPL. inmissings)($ vn, $ model )
27
+ if ! $ (DynamicPPL. inargnames)($ vn, _model ) || $ (DynamicPPL. inmissings)($ vn, _model )
39
28
true
40
29
else
41
30
# Evaluate the LHS
@@ -46,7 +35,7 @@ function isassumption(model, expr::Union{Symbol, Expr})
46
35
end
47
36
48
37
# failsafe: a literal is never an assumption
49
- isassumption (model, expr) = :(false )
38
+ isassumption (expr) = :(false )
50
39
51
40
# ################
52
41
# Main Compiler #
@@ -77,7 +66,7 @@ function model(expr)
77
66
modelinfo = build_model_info (expr)
78
67
79
68
# Generate main body
80
- modelinfo[:main_body ] = generate_mainbody (modelinfo)
69
+ modelinfo[:main_body ] = generate_mainbody (modelinfo[ :main_body ], modelinfo[ :args ] )
81
70
82
71
return build_output (modelinfo)
83
72
end
@@ -166,67 +155,57 @@ function build_model_info(input_expr)
166
155
:args_nt => args_nt,
167
156
:defaults_nt => defaults_nt,
168
157
:args => args,
169
- :whereparams => modeldef[:whereparams ],
170
- :main_body_names => Dict (
171
- :ctx => gensym (:ctx ),
172
- :vi => gensym (:vi ),
173
- :sampler => gensym (:sampler ),
174
- :model => gensym (:model )
175
- )
158
+ :whereparams => modeldef[:whereparams ]
176
159
)
177
160
178
161
return model_info
179
162
end
180
163
181
164
"""
182
- generate_mainbody([ expr, ]modelinfo )
165
+ generate_mainbody(expr, args )
183
166
184
- Generate the body of the main evaluation function.
167
+ Generate the body of the main evaluation function from expression `expr` and arguments
168
+ `args`.
185
169
"""
186
- generate_mainbody (modelinfo ) = generate_mainbody (modelinfo[ :main_body ], modelinfo )
170
+ generate_mainbody (expr, args ) = generate_mainbody! (Symbol[ ], expr, args )
187
171
188
- generate_mainbody (x, modelinfo) = x
189
- function generate_mainbody (expr:: Expr , modelinfo)
172
+ generate_mainbody! (found, x, args) = x
173
+ function generate_mainbody! (found, sym:: Symbol , args)
174
+ if sym in INTERNALNAMES && sym ∉ found
175
+ @warn " you are using the internal variable `$(sym) `"
176
+ push! (found, sym)
177
+ end
178
+ return sym
179
+ end
180
+ function generate_mainbody! (found, expr:: Expr , args)
190
181
# Do not touch interpolated expressions
191
182
expr. head === :$ && return expr. args[1 ]
192
183
193
184
# Apply the `@.` macro first.
194
185
if Meta. isexpr (expr, :macrocall ) && length (expr. args) > 1 &&
195
186
expr. args[1 ] === Symbol (" @__dot__" )
196
- return generate_mainbody (Base. Broadcast. __dot__ (expr. args[end ]), modelinfo)
197
- end
198
-
199
- # Modify macro calls.
200
- if Meta. isexpr (expr, :macrocall ) && ! isempty (expr. args)
201
- name = expr. args[1 ]
202
- if name === Symbol (" @varinfo" )
203
- return modelinfo[:main_body_names ][:vi ]
204
- elseif name === Symbol (" @logpdf" )
205
- return :($ (modelinfo[:main_body_names ][:vi ]). logp[])
206
- elseif name === Symbol (" @sampler" )
207
- return :($ (modelinfo[:main_body_names ][:sampler ]))
208
- end
187
+ return generate_mainbody! (found, Base. Broadcast. __dot__ (expr. args[end ]), args)
209
188
end
210
189
211
190
# Modify dotted tilde operators.
212
191
args_dottilde = getargs_dottilde (expr)
213
192
if args_dottilde != = nothing
214
193
L, R = args_dottilde
215
- return Base. remove_linenums! (generate_dot_tilde (generate_mainbody ( L, modelinfo ),
216
- generate_mainbody ( R, modelinfo ),
217
- modelinfo ))
194
+ return Base. remove_linenums! (generate_dot_tilde (generate_mainbody! (found, L, args ),
195
+ generate_mainbody! (found, R, args ),
196
+ args ))
218
197
end
219
198
220
199
# Modify tilde operators.
221
200
args_tilde = getargs_tilde (expr)
222
201
if args_tilde != = nothing
223
202
L, R = args_tilde
224
- return Base. remove_linenums! (generate_tilde (generate_mainbody ( L, modelinfo ),
225
- generate_mainbody ( R, modelinfo ),
226
- modelinfo ))
203
+ return Base. remove_linenums! (generate_tilde (generate_mainbody! (found, L, args ),
204
+ generate_mainbody! (found, R, args ),
205
+ args ))
227
206
end
228
207
229
- return Expr (expr. head, map (x -> generate_mainbody ( x, modelinfo ), expr. args)... )
208
+ return Expr (expr. head, map (x -> generate_mainbody! (found, x, args ), expr. args)... )
230
209
end
231
210
232
211
"""
@@ -268,17 +247,12 @@ end
268
247
269
248
270
249
"""
271
- generate_tilde(left, right, model_info )
250
+ generate_tilde(left, right, args )
272
251
273
- The `tilde` function generates `observe` expression for data variables and `assume`
274
- expressions for parameter variables, updating `model_info` in the process .
252
+ Generate an `observe` expression for data variables and `assume` expression for parameter
253
+ variables.
275
254
"""
276
- function generate_tilde (left, right, model_info)
277
- model = model_info[:main_body_names ][:model ]
278
- vi = model_info[:main_body_names ][:vi ]
279
- ctx = model_info[:main_body_names ][:ctx ]
280
- sampler = model_info[:main_body_names ][:sampler ]
281
-
255
+ function generate_tilde (left, right, args)
282
256
@gensym tmpright tmpleft
283
257
top = [:($ tmpright = $ right),
284
258
:($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
@@ -289,26 +263,26 @@ function generate_tilde(left, right, model_info)
289
263
push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
290
264
291
265
assumption = [
292
- :($ out = $ (DynamicPPL. tilde_assume)($ ctx, $ sampler , $ tmpright, $ vn, $ inds,
293
- $ vi )),
294
- :($ (DynamicPPL. acclogp!)($ vi , $ out[2 ])),
266
+ :($ out = $ (DynamicPPL. tilde_assume)(_context, _sampler , $ tmpright, $ vn, $ inds,
267
+ _varinfo )),
268
+ :($ (DynamicPPL. acclogp!)(_varinfo , $ out[2 ])),
295
269
:($ left = $ out[1 ])
296
270
]
297
271
298
272
# It can only be an observation if the LHS is an argument of the model
299
- if vsym (left) in model_info[ : args]
273
+ if vsym (left) in args
300
274
@gensym isassumption
301
275
return quote
302
276
$ (top... )
303
- $ isassumption = $ (DynamicPPL. isassumption (model, left))
277
+ $ isassumption = $ (DynamicPPL. isassumption (left))
304
278
if $ isassumption
305
279
$ (assumption... )
306
280
else
307
281
$ tmpleft = $ left
308
282
$ (DynamicPPL. acclogp!)(
309
- $ vi ,
310
- $ (DynamicPPL. tilde_observe)($ ctx, $ sampler , $ tmpright, $ tmpleft, $ vn ,
311
- $ inds , $ vi )
283
+ _varinfo ,
284
+ $ (DynamicPPL. tilde_observe)(_context, _sampler , $ tmpright, $ tmpleft,
285
+ $ vn , $ inds, _varinfo )
312
286
)
313
287
$ tmpleft
314
288
end
@@ -326,26 +300,19 @@ function generate_tilde(left, right, model_info)
326
300
$ (top... )
327
301
$ tmpleft = $ left
328
302
$ (DynamicPPL. acclogp!)(
329
- $ vi ,
330
- $ (DynamicPPL. tilde_observe)($ ctx, $ sampler , $ tmpright, $ tmpleft, $ vi )
303
+ _varinfo ,
304
+ $ (DynamicPPL. tilde_observe)(_context, _sampler , $ tmpright, $ tmpleft, _varinfo )
331
305
)
332
306
$ tmpleft
333
307
end
334
308
end
335
309
336
310
"""
337
- generate_dot_tilde(left, right, model_info )
311
+ generate_dot_tilde(left, right, args )
338
312
339
- This function returns the expression that replaces `left .~ right` in the model body. If
340
- `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
341
- will be run.
313
+ Generate the expression that replaces `left .~ right` in the model body.
342
314
"""
343
- function generate_dot_tilde (left, right, model_info)
344
- model = model_info[:main_body_names ][:model ]
345
- vi = model_info[:main_body_names ][:vi ]
346
- ctx = model_info[:main_body_names ][:ctx ]
347
- sampler = model_info[:main_body_names ][:sampler ]
348
-
315
+ function generate_dot_tilde (left, right, args)
349
316
@gensym tmpright tmpleft
350
317
top = [:($ tmpright = $ right),
351
318
:($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
@@ -356,26 +323,26 @@ function generate_dot_tilde(left, right, model_info)
356
323
push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
357
324
358
325
assumption = [
359
- :($ out = $ (DynamicPPL. dot_tilde_assume)($ ctx, $ sampler , $ tmpright, $ left,
360
- $ vn, $ inds, $ vi )),
361
- :($ (DynamicPPL. acclogp!)($ vi , $ out[2 ])),
326
+ :($ out = $ (DynamicPPL. dot_tilde_assume)(_context, _sampler , $ tmpright, $ left,
327
+ $ vn, $ inds, _varinfo )),
328
+ :($ (DynamicPPL. acclogp!)(_varinfo , $ out[2 ])),
362
329
:($ left .= $ out[1 ])
363
330
]
364
331
365
332
# It can only be an observation if the LHS is an argument of the model
366
- if vsym (left) in model_info[ : args]
333
+ if vsym (left) in args
367
334
@gensym isassumption
368
335
return quote
369
336
$ (top... )
370
- $ isassumption = $ (DynamicPPL. isassumption (model, left))
337
+ $ isassumption = $ (DynamicPPL. isassumption (left))
371
338
if $ isassumption
372
339
$ (assumption... )
373
340
else
374
341
$ tmpleft = $ left
375
342
$ (DynamicPPL. acclogp!)(
376
- $ vi ,
377
- $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler , $ tmpright, $ tmpleft ,
378
- $ vn, $ inds, $ vi )
343
+ _varinfo ,
344
+ $ (DynamicPPL. dot_tilde_observe)(_context, _sampler , $ tmpright,
345
+ $ tmpleft, $ vn, $ inds, _varinfo )
379
346
)
380
347
$ tmpleft
381
348
end
@@ -393,8 +360,9 @@ function generate_dot_tilde(left, right, model_info)
393
360
$ (top... )
394
361
$ tmpleft = $ left
395
362
$ (DynamicPPL. acclogp!)(
396
- $ vi,
397
- $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler, $ tmpright, $ tmpleft, $ vi)
363
+ _varinfo,
364
+ $ (DynamicPPL. dot_tilde_observe)(_context, _sampler, $ tmpright, $ tmpleft,
365
+ _varinfo)
398
366
)
399
367
$ tmpleft
400
368
end
@@ -411,13 +379,6 @@ hasmissing(T::Type) = false
411
379
Builds the output expression.
412
380
"""
413
381
function build_output (model_info)
414
- # Construct user-facing function
415
- main_body_names = model_info[:main_body_names ]
416
- ctx = main_body_names[:ctx ]
417
- vi = main_body_names[:vi ]
418
- model = main_body_names[:model ]
419
- sampler = main_body_names[:sampler ]
420
-
421
382
# Arguments with default values
422
383
args = model_info[:args ]
423
384
# Argument symbols without default values
@@ -437,7 +398,7 @@ function build_output(model_info)
437
398
unwrap_data_expr = Expr (:block )
438
399
for var in arg_syms
439
400
push! (unwrap_data_expr. args,
440
- :($ var = $ (DynamicPPL. matchingvalue)($ sampler, $ vi, $ (model) . args.$ var)))
401
+ :($ var = $ (DynamicPPL. matchingvalue)(_sampler, _varinfo, _model . args.$ var)))
441
402
end
442
403
443
404
@gensym (evaluator, generator)
@@ -446,10 +407,10 @@ function build_output(model_info)
446
407
447
408
return quote
448
409
function $evaluator (
449
- $ model :: $ (DynamicPPL. Model),
450
- $ vi :: $ (DynamicPPL. VarInfo),
451
- $ sampler :: $ (DynamicPPL. AbstractSampler),
452
- $ ctx :: $ (DynamicPPL. AbstractContext),
410
+ _model :: $ (DynamicPPL. Model),
411
+ _varinfo :: $ (DynamicPPL. VarInfo),
412
+ _sampler :: $ (DynamicPPL. AbstractSampler),
413
+ _context :: $ (DynamicPPL. AbstractContext),
453
414
)
454
415
$ unwrap_data_expr
455
416
$ main_body
0 commit comments