1
- macro varinfo ()
2
- :(throw (_error_msg ()))
3
- end
4
- macro logpdf ()
5
- :(throw (_error_msg ()))
6
- end
7
- macro sampler ()
8
- :(throw (_error_msg ()))
9
- end
10
- function _error_msg ()
11
- return " This macro is only for use in the `@model` macro and not for external use."
12
- end
13
-
14
1
const DISTMSG = " Right-hand side of a ~ must be subtype of Distribution or a vector of " *
15
2
" Distributions."
16
3
4
+ const INTERNALNAMES = (:_model , :_sampler , :_context , :_varinfo )
5
+
17
6
"""
18
- isassumption(model, expr)
7
+ isassumption(expr)
19
8
20
9
Return an expression that can be evaluated to check if `expr` is an assumption in the
21
- ` model` .
10
+ model.
22
11
23
12
Let `expr` be `:(x[1])`. It is an assumption in the following cases:
24
- 1. `x` is not among the input data to the ` model` ,
25
- 2. `x` is among the input data to the ` model` but with a value `missing`, or
26
- 3. `x` is among the input data to the ` model` with a value other than missing,
13
+ 1. `x` is not among the input data to the model,
14
+ 2. `x` is among the input data to the model but with a value `missing`, or
15
+ 3. `x` is among the input data to the model with a value other than missing,
27
16
but `x[1] === missing`.
28
17
29
18
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
30
19
"""
31
- function isassumption (model, expr:: Union{Symbol, Expr} )
20
+ function isassumption (expr:: Union{Symbol, Expr} )
32
21
vn = gensym (:vn )
33
22
34
23
return quote
35
24
let $ vn = $ (varname (expr))
36
25
# This branch should compile nicely in all cases except for partial missing data
37
26
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
38
- if ! $ (DynamicPPL. inargnames)($ vn, $ model ) || $ (DynamicPPL. inmissings)($ vn, $ model )
27
+ if ! $ (DynamicPPL. inargnames)($ vn, _model ) || $ (DynamicPPL. inmissings)($ vn, _model )
39
28
true
40
29
else
41
30
# Evaluate the LHS
@@ -46,7 +35,7 @@ function isassumption(model, expr::Union{Symbol, Expr})
46
35
end
47
36
48
37
# failsafe: a literal is never an assumption
49
- isassumption (model, expr) = :(false )
38
+ isassumption (expr) = :(false )
50
39
51
40
# ################
52
41
# Main Compiler #
@@ -77,7 +66,7 @@ function model(expr)
77
66
modelinfo = build_model_info (expr)
78
67
79
68
# Generate main body
80
- modelinfo[:main_body ] = generate_mainbody (modelinfo)
69
+ modelinfo[:main_body ] = generate_mainbody (modelinfo[ :main_body ], modelinfo[ :args ] )
81
70
82
71
return build_output (modelinfo)
83
72
end
@@ -166,84 +155,69 @@ function build_model_info(input_expr)
166
155
:args_nt => args_nt,
167
156
:defaults_nt => defaults_nt,
168
157
:args => args,
169
- :whereparams => modeldef[:whereparams ],
170
- :main_body_names => Dict (
171
- :ctx => gensym (:ctx ),
172
- :vi => gensym (:vi ),
173
- :sampler => gensym (:sampler ),
174
- :model => gensym (:model )
175
- )
158
+ :whereparams => modeldef[:whereparams ]
176
159
)
177
160
178
161
return model_info
179
162
end
180
163
181
164
"""
182
- generate_mainbody([ expr, ]modelinfo )
165
+ generate_mainbody(expr, args )
183
166
184
- Generate the body of the main evaluation function.
167
+ Generate the body of the main evaluation function from expression `expr` and arguments
168
+ `args`.
185
169
"""
186
- generate_mainbody (modelinfo ) = generate_mainbody (modelinfo[ :main_body ], modelinfo )
170
+ generate_mainbody (expr, args ) = generate_mainbody! (Symbol[ ], expr, args )
187
171
188
- generate_mainbody (x, modelinfo) = x
189
- function generate_mainbody (expr:: Expr , modelinfo)
172
+ generate_mainbody! (found, x, args) = x
173
+ function generate_mainbody! (found, sym:: Symbol , args)
174
+ if sym in INTERNALNAMES && sym ∉ found
175
+ @warn " you are using the internal variable `$(sym) `"
176
+ push! (found, sym)
177
+ end
178
+ return sym
179
+ end
180
+ function generate_mainbody! (found, expr:: Expr , args)
190
181
# Do not touch interpolated expressions
191
182
expr. head === :$ && return expr. args[1 ]
192
183
193
184
# Apply the `@.` macro first.
194
185
if Meta. isexpr (expr, :macrocall ) && length (expr. args) > 1 &&
195
186
expr. args[1 ] === Symbol (" @__dot__" )
196
- return generate_mainbody (Base. Broadcast. __dot__ (expr. args[end ]), modelinfo)
197
- end
198
-
199
- # Modify macro calls.
200
- if Meta. isexpr (expr, :macrocall ) && ! isempty (expr. args)
201
- name = expr. args[1 ]
202
- if name === Symbol (" @varinfo" )
203
- return modelinfo[:main_body_names ][:vi ]
204
- elseif name === Symbol (" @logpdf" )
205
- return :($ (modelinfo[:main_body_names ][:vi ]). logp[])
206
- elseif name === Symbol (" @sampler" )
207
- return :($ (modelinfo[:main_body_names ][:sampler ]))
208
- end
187
+ return generate_mainbody! (found, Base. Broadcast. __dot__ (expr. args[end ]), args)
209
188
end
210
189
211
190
# Modify dotted tilde operators.
212
191
args_dottilde = getargs_dottilde (expr)
213
192
if args_dottilde != = nothing
214
193
L, R = args_dottilde
215
- return Base. remove_linenums! (generate_dot_tilde (generate_mainbody ( L, modelinfo ),
216
- generate_mainbody ( R, modelinfo ),
217
- modelinfo ))
194
+ return Base. remove_linenums! (generate_dot_tilde (generate_mainbody! (found, L, args ),
195
+ generate_mainbody! (found, R, args ),
196
+ args ))
218
197
end
219
198
220
199
# Modify tilde operators.
221
200
args_tilde = getargs_tilde (expr)
222
201
if args_tilde != = nothing
223
202
L, R = args_tilde
224
- return Base. remove_linenums! (generate_tilde (generate_mainbody ( L, modelinfo ),
225
- generate_mainbody ( R, modelinfo ),
226
- modelinfo ))
203
+ return Base. remove_linenums! (generate_tilde (generate_mainbody! (found, L, args ),
204
+ generate_mainbody! (found, R, args ),
205
+ args ))
227
206
end
228
207
229
- return Expr (expr. head, map (x -> generate_mainbody ( x, modelinfo ), expr. args)... )
208
+ return Expr (expr. head, map (x -> generate_mainbody! (found, x, args ), expr. args)... )
230
209
end
231
210
232
211
# """ Unbreak code highlighting in Emacs julia-mode
233
212
234
213
235
214
"""
236
- generate_tilde(left, right, model_info )
215
+ generate_tilde(left, right, args )
237
216
238
- The `tilde` function generates `observe` expression for data variables and `assume`
239
- expressions for parameter variables, updating `model_info` in the process .
217
+ Generate an `observe` expression for data variables and `assume` expression for parameter
218
+ variables.
240
219
"""
241
- function generate_tilde (left, right, model_info)
242
- model = model_info[:main_body_names ][:model ]
243
- vi = model_info[:main_body_names ][:vi ]
244
- ctx = model_info[:main_body_names ][:ctx ]
245
- sampler = model_info[:main_body_names ][:sampler ]
246
-
220
+ function generate_tilde (left, right, args)
247
221
@gensym tmpright tmpleft
248
222
top = [:($ tmpright = $ right),
249
223
:($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
@@ -254,26 +228,26 @@ function generate_tilde(left, right, model_info)
254
228
push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
255
229
256
230
assumption = [
257
- :($ out = $ (DynamicPPL. tilde_assume)($ ctx, $ sampler , $ tmpright, $ vn, $ inds,
258
- $ vi )),
259
- :($ (DynamicPPL. acclogp!)($ vi , $ out[2 ])),
231
+ :($ out = $ (DynamicPPL. tilde_assume)(_context, _sampler , $ tmpright, $ vn, $ inds,
232
+ _varinfo )),
233
+ :($ (DynamicPPL. acclogp!)(_varinfo , $ out[2 ])),
260
234
:($ left = $ out[1 ])
261
235
]
262
236
263
237
# It can only be an observation if the LHS is an argument of the model
264
- if vsym (left) in model_info[ : args]
238
+ if vsym (left) in args
265
239
@gensym isassumption
266
240
return quote
267
241
$ (top... )
268
- $ isassumption = $ (DynamicPPL. isassumption (model, left))
242
+ $ isassumption = $ (DynamicPPL. isassumption (left))
269
243
if $ isassumption
270
244
$ (assumption... )
271
245
else
272
246
$ tmpleft = $ left
273
247
$ (DynamicPPL. acclogp!)(
274
- $ vi ,
275
- $ (DynamicPPL. tilde_observe)($ ctx, $ sampler , $ tmpright, $ tmpleft, $ vn ,
276
- $ inds , $ vi )
248
+ _varinfo ,
249
+ $ (DynamicPPL. tilde_observe)(_context, _sampler , $ tmpright, $ tmpleft,
250
+ $ vn , $ inds, _varinfo )
277
251
)
278
252
$ tmpleft
279
253
end
@@ -291,26 +265,19 @@ function generate_tilde(left, right, model_info)
291
265
$ (top... )
292
266
$ tmpleft = $ left
293
267
$ (DynamicPPL. acclogp!)(
294
- $ vi ,
295
- $ (DynamicPPL. tilde_observe)($ ctx, $ sampler , $ tmpright, $ tmpleft, $ vi )
268
+ _varinfo ,
269
+ $ (DynamicPPL. tilde_observe)(_context, _sampler , $ tmpright, $ tmpleft, _varinfo )
296
270
)
297
271
$ tmpleft
298
272
end
299
273
end
300
274
301
275
"""
302
- generate_dot_tilde(left, right, model_info )
276
+ generate_dot_tilde(left, right, args )
303
277
304
- This function returns the expression that replaces `left .~ right` in the model body. If
305
- `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
306
- will be run.
278
+ Generate the expression that replaces `left .~ right` in the model body.
307
279
"""
308
- function generate_dot_tilde (left, right, model_info)
309
- model = model_info[:main_body_names ][:model ]
310
- vi = model_info[:main_body_names ][:vi ]
311
- ctx = model_info[:main_body_names ][:ctx ]
312
- sampler = model_info[:main_body_names ][:sampler ]
313
-
280
+ function generate_dot_tilde (left, right, args)
314
281
@gensym tmpright tmpleft
315
282
top = [:($ tmpright = $ right),
316
283
:($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
@@ -321,26 +288,26 @@ function generate_dot_tilde(left, right, model_info)
321
288
push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
322
289
323
290
assumption = [
324
- :($ out = $ (DynamicPPL. dot_tilde_assume)($ ctx, $ sampler , $ tmpright, $ left,
325
- $ vn, $ inds, $ vi )),
326
- :($ (DynamicPPL. acclogp!)($ vi , $ out[2 ])),
291
+ :($ out = $ (DynamicPPL. dot_tilde_assume)(_context, _sampler , $ tmpright, $ left,
292
+ $ vn, $ inds, _varinfo )),
293
+ :($ (DynamicPPL. acclogp!)(_varinfo , $ out[2 ])),
327
294
:($ left .= $ out[1 ])
328
295
]
329
296
330
297
# It can only be an observation if the LHS is an argument of the model
331
- if vsym (left) in model_info[ : args]
298
+ if vsym (left) in args
332
299
@gensym isassumption
333
300
return quote
334
301
$ (top... )
335
- $ isassumption = $ (DynamicPPL. isassumption (model, left))
302
+ $ isassumption = $ (DynamicPPL. isassumption (left))
336
303
if $ isassumption
337
304
$ (assumption... )
338
305
else
339
306
$ tmpleft = $ left
340
307
$ (DynamicPPL. acclogp!)(
341
- $ vi ,
342
- $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler , $ tmpright, $ tmpleft ,
343
- $ vn, $ inds, $ vi )
308
+ _varinfo ,
309
+ $ (DynamicPPL. dot_tilde_observe)(_context, _sampler , $ tmpright,
310
+ $ tmpleft, $ vn, $ inds, _varinfo )
344
311
)
345
312
$ tmpleft
346
313
end
@@ -358,8 +325,9 @@ function generate_dot_tilde(left, right, model_info)
358
325
$ (top... )
359
326
$ tmpleft = $ left
360
327
$ (DynamicPPL. acclogp!)(
361
- $ vi,
362
- $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler, $ tmpright, $ tmpleft, $ vi)
328
+ _varinfo,
329
+ $ (DynamicPPL. dot_tilde_observe)(_context, _sampler, $ tmpright, $ tmpleft,
330
+ _varinfo)
363
331
)
364
332
$ tmpleft
365
333
end
@@ -376,13 +344,6 @@ hasmissing(T::Type) = false
376
344
Builds the output expression.
377
345
"""
378
346
function build_output (model_info)
379
- # Construct user-facing function
380
- main_body_names = model_info[:main_body_names ]
381
- ctx = main_body_names[:ctx ]
382
- vi = main_body_names[:vi ]
383
- model = main_body_names[:model ]
384
- sampler = main_body_names[:sampler ]
385
-
386
347
# Arguments with default values
387
348
args = model_info[:args ]
388
349
# Argument symbols without default values
@@ -402,7 +363,7 @@ function build_output(model_info)
402
363
unwrap_data_expr = Expr (:block )
403
364
for var in arg_syms
404
365
push! (unwrap_data_expr. args,
405
- :($ var = $ (DynamicPPL. matchingvalue)($ sampler, $ vi, $ (model) . args.$ var)))
366
+ :($ var = $ (DynamicPPL. matchingvalue)(_sampler, _varinfo, _model . args.$ var)))
406
367
end
407
368
408
369
@gensym (evaluator, generator)
@@ -411,10 +372,10 @@ function build_output(model_info)
411
372
412
373
return quote
413
374
function $evaluator (
414
- $ model :: $ (DynamicPPL. Model),
415
- $ vi :: $ (DynamicPPL. VarInfo),
416
- $ sampler :: $ (DynamicPPL. AbstractSampler),
417
- $ ctx :: $ (DynamicPPL. AbstractContext),
375
+ _model :: $ (DynamicPPL. Model),
376
+ _varinfo :: $ (DynamicPPL. VarInfo),
377
+ _sampler :: $ (DynamicPPL. AbstractSampler),
378
+ _context :: $ (DynamicPPL. AbstractContext),
418
379
)
419
380
$ unwrap_data_expr
420
381
$ main_body
0 commit comments