@@ -27,45 +27,44 @@ function wrong_dist_errormsg(l)
27
27
end
28
28
29
29
"""
30
- @preprocess(data_vars, missing_vars, ex )
30
+ @isassumption(model, expr )
31
31
32
- Let `ex ` be `x[1]`. This macro returns `@varname x[1]` in any of the following cases:
32
+ Let `expr ` be `x[1]`. `vn` is an assumption in the following cases:
33
33
1. `x` was not among the input data to the model,
34
34
2. `x` was among the input data to the model but with a value `missing`, or
35
35
3. `x` was among the input data to the model with a value other than missing,
36
- but `x[1] === missing`.
37
- Otherwise, the value of `x[1]` is returned .
36
+ but `x[1] === missing`.
37
+ When `expr` is not an expression or symbol (i.e., a literal), this expands to `false` .
38
38
"""
39
- macro preprocess (data_vars, missing_vars, ex)
40
- ex
41
- end
42
- macro preprocess (model, ex:: Union{Symbol, Expr} )
43
- sym = gensym (:sym )
44
- lhs = gensym (:lhs )
45
- return esc (quote
46
- # Extract symbol
47
- $ sym = Val ($ (vsym (ex)))
39
+ macro isassumption (model, expr:: Union{Symbol, Expr} )
40
+ # Note: never put a return in this... don't forget it's a macro!
41
+ vn = gensym (:vn )
42
+
43
+ return quote
44
+ $ vn = @varname ($ expr)
45
+
48
46
# This branch should compile nicely in all cases except for partial missing data
49
- # For example, when `ex ` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
50
- if ! DynamicPPL. inargnames ($ sym , $ model) || DynamicPPL. inmissings ($ sym , $ model)
51
- $ ( varname (ex)), $ ( vinds (ex))
47
+ # For example, when `expr ` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48
+ if ! DynamicPPL. inargnames ($ vn , $ model) || DynamicPPL. inmissings ($ vn , $ model)
49
+ true
52
50
else
53
- if DynamicPPL. inargnames ($ sym , $ model)
51
+ if DynamicPPL. inargnames ($ vn , $ model)
54
52
# Evaluate the lhs
55
- $ lhs = $ ex
56
- if $ lhs === missing
57
- $ (varname (ex)), $ (vinds (ex))
58
- else
59
- $ lhs
60
- end
53
+ $ expr === missing
61
54
else
62
55
throw (" This point should not be reached. Please report this error." )
63
56
end
64
57
end
65
- end )
58
+ end |> esc
59
+ end
60
+
61
+ macro isassumption (model, expr)
62
+ # failsafe: a literal is never an assumption
63
+ false
66
64
end
67
65
68
66
67
+
69
68
# ################
70
69
# Main Compiler #
71
70
# ################
@@ -246,40 +245,40 @@ function replace_sampler!(model_info)
246
245
return model_info
247
246
end
248
247
249
- # The next function is defined that way because .~ gives a parsing error in Julia 1.0
250
248
"""
251
- \" ""
252
249
replace_tilde!(model_info)
253
250
254
- Replaces ` ~` expressions with observation or assumption expressions, updating `model_info`.
255
- \ " ""
251
+ Replace `~` and `. ~` expressions with observation or assumption expressions, updating `model_info`.
252
+ """
256
253
function replace_tilde! (model_info)
257
- ex = model_info[:main_body]
258
- ex = MacroTools.postwalk(ex) do x
259
- if @capture(x, @M_ L_ ~ R_) && M == Symbol("@__dot__")
260
- generate_dot_tilde(L, R, model_info)
261
- else
262
- x
254
+ # Apply the `@.` macro first.
255
+ expr = model_info[:main_body ]
256
+ dottedexpr = MacroTools. postwalk (apply_dotted, expr)
257
+
258
+ # Check for tilde operators.
259
+ tildeexpr = MacroTools. postwalk (dottedexpr) do x
260
+ # Check dot tilde first.
261
+ dotargs = getargs_dottilde (x)
262
+ if dotargs != = nothing
263
+ L, R = dotargs
264
+ return generate_dot_tilde (L, R, model_info)
263
265
end
264
- end
265
- $(VERSION >= v " 1.1" ? " ex = MacroTools.postwalk(ex) do x
266
- if @capture(x, L_ .~ R_)
267
- generate_dot_tilde(L, R, model_info)
268
- else
269
- x
270
- end
271
- end" : " " )
272
- ex = MacroTools.postwalk(ex) do x
273
- if @capture(x, L_ ~ R_)
274
- generate_tilde(L, R, model_info)
275
- else
276
- x
266
+
267
+ # Check tilde.
268
+ args = getargs_tilde (x)
269
+ if args != = nothing
270
+ L, R = args
271
+ return generate_tilde (L, R, model_info)
277
272
end
273
+
274
+ return x
278
275
end
279
- model_info[:main_body] = ex
276
+
277
+ # Update the function body.
278
+ model_info[:main_body ] = tildeexpr
279
+
280
280
return model_info
281
281
end
282
- """ |> Meta. parse |> eval
283
282
284
283
# """ Unbreak code highlighting in Emacs julia-mode
285
284
@@ -300,32 +299,36 @@ function generate_tilde(left, right, model_info)
300
299
lp = gensym (:lp )
301
300
vn = gensym (:vn )
302
301
inds = gensym (:inds )
303
- preprocessed = gensym (:preprocessed )
302
+ isassumption = gensym (:isassumption )
304
303
assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
304
+
305
305
if left isa Symbol || left isa Expr
306
306
ex = quote
307
307
$ temp_right = $ right
308
308
$ assert_ex
309
- $ preprocessed = DynamicPPL. @preprocess ($ model, $ left)
310
- if $ preprocessed isa Tuple
311
- $ vn, $ inds = $ preprocessed
312
- $ out = DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
309
+
310
+ $ vn, $ inds = $ (varname (left)), $ (vinds (left))
311
+ $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
312
+ if $ isassumption
313
+ $ out = DynamicPPL. tilde_assume ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
313
314
$ left = $ out[1 ]
314
315
DynamicPPL. acclogp! ($ vi, $ out[2 ])
315
316
else
316
317
DynamicPPL. acclogp! (
317
318
$ vi,
318
- DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ preprocessed , $ vi),
319
+ DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds , $ vi),
319
320
)
320
321
end
321
322
end
322
323
else
324
+ # we have a literal, which is automatically an observation
323
325
ex = quote
324
326
$ temp_right = $ right
325
327
$ assert_ex
328
+
326
329
DynamicPPL. acclogp! (
327
330
$ vi,
328
- DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
331
+ DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
329
332
)
330
333
end
331
334
end
@@ -335,48 +338,51 @@ end
335
338
"""
336
339
generate_dot_tilde(left, right, model_info)
337
340
338
- This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
341
+ This function returns the expression that replaces `left .~ right` in the model body. If
342
+ `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
343
+ will be run.
339
344
"""
340
345
function generate_dot_tilde (left, right, model_info)
341
346
model = model_info[:main_body_names ][:model ]
342
347
vi = model_info[:main_body_names ][:vi ]
343
348
ctx = model_info[:main_body_names ][:ctx ]
344
349
sampler = model_info[:main_body_names ][:sampler ]
345
350
out = gensym (:out )
346
- temp_left = gensym (:temp_left )
347
351
temp_right = gensym (:temp_right )
348
- preprocessed = gensym (:preprocessed )
352
+ isassumption = gensym (:isassumption )
349
353
lp = gensym (:lp )
350
354
vn = gensym (:vn )
351
355
inds = gensym (:inds )
352
356
assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
357
+
353
358
if left isa Symbol || left isa Expr
354
359
ex = quote
355
360
$ temp_right = $ right
356
361
$ assert_ex
357
- $ preprocessed = DynamicPPL. @preprocess ($ model, $ left)
358
- if $ preprocessed isa Tuple
359
- $ vn, $ inds = $ preprocessed
360
- $ temp_left = $ left
361
- $ out = DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left, $ vn, $ inds, $ vi)
362
+
363
+ $ vn, $ inds = $ (varname (left)), $ (vinds (left))
364
+ $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
365
+
366
+ if $ isassumption
367
+ $ out = DynamicPPL. dot_tilde_assume ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi)
362
368
$ left .= $ out[1 ]
363
369
DynamicPPL. acclogp! ($ vi, $ out[2 ])
364
370
else
365
- $ temp_left = $ preprocessed
366
371
DynamicPPL. acclogp! (
367
372
$ vi,
368
- DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left , $ vi),
373
+ DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds , $ vi),
369
374
)
370
375
end
371
376
end
372
377
else
378
+ # we have a literal, which is automatically an observation
373
379
ex = quote
374
- $ temp_left = $ left
375
380
$ temp_right = $ right
376
381
$ assert_ex
382
+
377
383
DynamicPPL. acclogp! (
378
384
$ vi,
379
- DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left , $ vi),
385
+ DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left , $ vi),
380
386
)
381
387
end
382
388
end
@@ -416,7 +422,7 @@ function build_output(model_info)
416
422
model_gen = model_info[:name ]
417
423
# Main body of the model
418
424
main_body = model_info[:main_body ]
419
-
425
+
420
426
unwrap_data_expr = Expr (:block )
421
427
for var in arg_syms
422
428
temp_var = gensym (:temp_var )
0 commit comments