@@ -27,45 +27,44 @@ function wrong_dist_errormsg(l)
27
27
end
28
28
29
29
"""
30
- @preprocess(data_vars, missing_vars, ex )
30
+ @isassumption(model, expr )
31
31
32
- Let `ex ` be `x[1]`. This macro returns `@varname x[1]` in any of the following cases:
32
+ Let `expr ` be `x[1]`. `vn` is an assumption in the following cases:
33
33
1. `x` was not among the input data to the model,
34
34
2. `x` was among the input data to the model but with a value `missing`, or
35
35
3. `x` was among the input data to the model with a value other than missing,
36
- but `x[1] === missing`.
37
- Otherwise, the value of `x[1]` is returned .
36
+ but `x[1] === missing`.
37
+ When `expr` is not an expression or symbol (i.e., a literal), this expands to `false` .
38
38
"""
39
- macro preprocess (data_vars, missing_vars, ex)
40
- ex
41
- end
42
- macro preprocess (model, ex:: Union{Symbol, Expr} )
43
- sym = gensym (:sym )
44
- lhs = gensym (:lhs )
45
- return esc (quote
46
- # Extract symbol
47
- $ sym = Val ($ (vsym (ex)))
39
+ macro isassumption (model, expr:: Union{Symbol, Expr} )
40
+ # Note: never put a return in this... don't forget it's a macro!
41
+ vn = gensym (:vn )
42
+
43
+ return quote
44
+ $ vn = @varname ($ expr)
45
+
48
46
# This branch should compile nicely in all cases except for partial missing data
49
- # For example, when `ex ` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
50
- if ! DynamicPPL. inargnames ($ sym , $ model) || DynamicPPL. inmissings ($ sym , $ model)
51
- $ ( varname (ex)), $ ( vinds (ex))
47
+ # For example, when `expr ` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48
+ if ! DynamicPPL. inargnames ($ vn , $ model) || DynamicPPL. inmissings ($ vn , $ model)
49
+ true
52
50
else
53
- if DynamicPPL. inargnames ($ sym , $ model)
51
+ if DynamicPPL. inargnames ($ vn , $ model)
54
52
# Evaluate the lhs
55
- $ lhs = $ ex
56
- if $ lhs === missing
57
- $ (varname (ex)), $ (vinds (ex))
58
- else
59
- $ lhs
60
- end
53
+ $ expr === missing
61
54
else
62
55
throw (" This point should not be reached. Please report this error." )
63
56
end
64
57
end
65
- end )
58
+ end |> esc
59
+ end
60
+
61
+ macro isassumption (model, expr)
62
+ # failsafe: a literal is never an assumption
63
+ false
66
64
end
67
65
68
66
67
+
69
68
# ################
70
69
# Main Compiler #
71
70
# ################
@@ -300,32 +299,36 @@ function generate_tilde(left, right, model_info)
300
299
lp = gensym (:lp )
301
300
vn = gensym (:vn )
302
301
inds = gensym (:inds )
303
- preprocessed = gensym (:preprocessed )
302
+ isassumption = gensym (:isassumption )
304
303
assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
304
+
305
305
if left isa Symbol || left isa Expr
306
306
ex = quote
307
307
$ temp_right = $ right
308
308
$ assert_ex
309
- $ preprocessed = DynamicPPL. @preprocess ($ model, $ left)
310
- if $ preprocessed isa Tuple
311
- $ vn, $ inds = $ preprocessed
312
- $ out = DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
309
+
310
+ $ vn, $ inds = $ (varname (left)), $ (vinds (left))
311
+ $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
312
+ if $ isassumption
313
+ $ out = DynamicPPL. tilde_assume ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
313
314
$ left = $ out[1 ]
314
315
DynamicPPL. acclogp! ($ vi, $ out[2 ])
315
316
else
316
317
DynamicPPL. acclogp! (
317
318
$ vi,
318
- DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ preprocessed , $ vi),
319
+ DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds , $ vi),
319
320
)
320
321
end
321
322
end
322
323
else
324
+ # we have a literal, which is automatically an observation
323
325
ex = quote
324
326
$ temp_right = $ right
325
327
$ assert_ex
328
+
326
329
DynamicPPL. acclogp! (
327
330
$ vi,
328
- DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
331
+ DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
329
332
)
330
333
end
331
334
end
@@ -335,48 +338,51 @@ end
335
338
"""
336
339
generate_dot_tilde(left, right, model_info)
337
340
338
- This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
341
+ This function returns the expression that replaces `left .~ right` in the model body. If
342
+ `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
343
+ will be run.
339
344
"""
340
345
function generate_dot_tilde (left, right, model_info)
341
346
model = model_info[:main_body_names ][:model ]
342
347
vi = model_info[:main_body_names ][:vi ]
343
348
ctx = model_info[:main_body_names ][:ctx ]
344
349
sampler = model_info[:main_body_names ][:sampler ]
345
350
out = gensym (:out )
346
- temp_left = gensym (:temp_left )
347
351
temp_right = gensym (:temp_right )
348
- preprocessed = gensym (:preprocessed )
352
+ isassumption = gensym (:isassumption )
349
353
lp = gensym (:lp )
350
354
vn = gensym (:vn )
351
355
inds = gensym (:inds )
352
356
assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
357
+
353
358
if left isa Symbol || left isa Expr
354
359
ex = quote
355
360
$ temp_right = $ right
356
361
$ assert_ex
357
- $ preprocessed = DynamicPPL. @preprocess ($ model, $ left)
358
- if $ preprocessed isa Tuple
359
- $ vn, $ inds = $ preprocessed
360
- $ temp_left = $ left
361
- $ out = DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left, $ vn, $ inds, $ vi)
362
+
363
+ $ vn, $ inds = $ (varname (left)), $ (vinds (left))
364
+ $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
365
+
366
+ if $ isassumption
367
+ $ out = DynamicPPL. dot_tilde_assume ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi)
362
368
$ left .= $ out[1 ]
363
369
DynamicPPL. acclogp! ($ vi, $ out[2 ])
364
370
else
365
- $ temp_left = $ preprocessed
366
371
DynamicPPL. acclogp! (
367
372
$ vi,
368
- DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left , $ vi),
373
+ DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds , $ vi),
369
374
)
370
375
end
371
376
end
372
377
else
378
+ # we have a literal, which is automatically an observation
373
379
ex = quote
374
- $ temp_left = $ left
375
380
$ temp_right = $ right
376
381
$ assert_ex
382
+
377
383
DynamicPPL. acclogp! (
378
384
$ vi,
379
- DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left , $ vi),
385
+ DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left , $ vi),
380
386
)
381
387
end
382
388
end
@@ -416,7 +422,7 @@ function build_output(model_info)
416
422
model_gen = model_info[:name ]
417
423
# Main body of the model
418
424
main_body = model_info[:main_body ]
419
-
425
+
420
426
unwrap_data_expr = Expr (:block )
421
427
for var in arg_syms
422
428
temp_var = gensym (:temp_var )
0 commit comments