@@ -49,19 +49,19 @@ function wrong_dist_errormsg(l)
49
49
end
50
50
51
51
"""
52
- @preprocess (data_vars, missing_vars, ex)
52
+ @isassumption (data_vars, missing_vars, ex)
53
53
54
- Let `ex` be `x[1]`. This macro returns `@varname x[1] ` in any of the following cases:
54
+ Let `ex` be `x[1]`. This macro returns `true ` in any of the following cases:
55
55
1. `x` was not among the input data to the model,
56
56
2. `x` was among the input data to the model but with a value `missing`, or
57
57
3. `x` was among the input data to the model with a value other than missing,
58
- but `x[1] === missing`.
59
- Otherwise, the value of `x[1]` is returned .
58
+ but `x[1] === missing`.
59
+ When `ex` is not a variable (e.g., a literal), the function returns `false` as well .
60
60
"""
61
- macro preprocess (data_vars, missing_vars, ex)
62
- ex
61
+ macro isassumption (data_vars, missing_vars, ex)
62
+ :false
63
63
end
64
- macro preprocess (data_vars, missing_vars, ex:: Union{Symbol, Expr} )
64
+ macro isassumption (data_vars, missing_vars, ex:: Union{Symbol, Expr} )
65
65
sym = gensym (:sym )
66
66
lhs = gensym (:lhs )
67
67
return esc (quote
@@ -70,22 +70,23 @@ macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
70
70
# This branch should compile nicely in all cases except for partial missing data
71
71
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
72
72
if ! DynamicPPL. inparams ($ sym, $ data_vars) || DynamicPPL. inparams ($ sym, $ missing_vars)
73
- $ ( varname (ex)), $ ( vinds (ex))
73
+ true
74
74
else
75
75
if DynamicPPL. inparams ($ sym, $ data_vars)
76
76
# Evaluate the lhs
77
77
$ lhs = $ ex
78
78
if $ lhs === missing
79
- $ ( varname (ex)), $ ( vinds (ex))
79
+ true
80
80
else
81
- $ lhs
81
+ false
82
82
end
83
83
else
84
84
throw (" This point should not be reached. Please report this error." )
85
85
end
86
86
end
87
87
end )
88
88
end
89
+
89
90
@generated function inparams (:: Val{s} , :: Val{t} ) where {s, t}
90
91
return (s in t) ? :(true ) : :(false )
91
92
end
@@ -319,6 +320,9 @@ function replace_tilde!(model_info)
319
320
end
320
321
""" |> Meta. parse |> eval
321
322
323
+ # """ Unbreak code highlighting in Emacs julia-mode
324
+
325
+
322
326
"""
323
327
generate_tilde(left, right, model_info)
324
328
@@ -331,37 +335,43 @@ function generate_tilde(left, right, model_info)
331
335
vi = model_info[:main_body_names ][:vi ]
332
336
ctx = model_info[:main_body_names ][:ctx ]
333
337
sampler = model_info[:main_body_names ][:sampler ]
334
- temp_right = gensym (:temp_right )
335
- out = gensym (:out )
336
- lp = gensym (:lp )
337
- vn = gensym (:vn )
338
- inds = gensym (:inds )
339
- preprocessed = gensym (:preprocessed )
338
+
339
+ @gensym (out,
340
+ lp,
341
+ vn,
342
+ inds,
343
+ isassumption,
344
+ temp_right)
345
+
340
346
assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
347
+
341
348
if left isa Symbol || left isa Expr
342
349
ex = quote
343
350
$ temp_right = $ right
344
351
$ assert_ex
345
- $ preprocessed = DynamicPPL. @preprocess ($ arg_syms, DynamicPPL. getmissing ($ model), $ left)
346
- if $ preprocessed isa Tuple
347
- $ vn, $ inds = $ preprocessed
348
- $ out = DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
352
+
353
+ $ vn, $ inds = $ (varname (left)), $ (vinds (left))
354
+ $ isassumption = DynamicPPL. @isassumption ($ arg_syms, DynamicPPL. getmissing ($ model), $ left)
355
+ if $ isassumption
356
+ $ out = DynamicPPL. tilde_assume ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
349
357
$ left = $ out[1 ]
350
358
DynamicPPL. acclogp! ($ vi, $ out[2 ])
351
359
else
352
360
DynamicPPL. acclogp! (
353
361
$ vi,
354
- DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ preprocessed , $ vi),
362
+ DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds , $ vi),
355
363
)
356
364
end
357
365
end
358
366
else
367
+ # we have a literal, which is automatically an observation
359
368
ex = quote
360
369
$ temp_right = $ right
361
370
$ assert_ex
371
+
362
372
DynamicPPL. acclogp! (
363
373
$ vi,
364
- DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
374
+ DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
365
375
)
366
376
end
367
377
end
@@ -371,49 +381,55 @@ end
371
381
"""
372
382
generate_dot_tilde(left, right, model_info)
373
383
374
- This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
384
+ This function returns the expression that replaces `left .~ right` in the model body. If
385
+ `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
386
+ will be run.
375
387
"""
376
388
function generate_dot_tilde (left, right, model_info)
377
389
arg_syms = Val ((model_info[:arg_syms ]. .. ,))
378
390
model = model_info[:main_body_names ][:model ]
379
391
vi = model_info[:main_body_names ][:vi ]
380
392
ctx = model_info[:main_body_names ][:ctx ]
381
393
sampler = model_info[:main_body_names ][:sampler ]
382
- out = gensym (:out )
383
- temp_left = gensym (:temp_left )
384
- temp_right = gensym (:temp_right )
385
- preprocessed = gensym (:preprocessed )
386
- lp = gensym (:lp )
387
- vn = gensym (:vn )
388
- inds = gensym (:inds )
394
+
395
+ @gensym (out,
396
+ preprocessed,
397
+ lp,
398
+ vn,
399
+ inds,
400
+ isassumption,
401
+ temp_right)
402
+
389
403
assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
404
+
390
405
if left isa Symbol || left isa Expr
391
406
ex = quote
392
407
$ temp_right = $ right
393
408
$ assert_ex
394
- $ preprocessed = DynamicPPL. @preprocess ($ arg_syms, DynamicPPL. getmissing ($ model), $ left)
395
- if $ preprocessed isa Tuple
396
- $ vn, $ inds = $ preprocessed
397
- $ temp_left = $ left
398
- $ out = DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left, $ vn, $ inds, $ vi)
409
+
410
+ $ vn, $ inds = $ (varname (left)), $ (vinds (left))
411
+ $ isassumption = DynamicPPL. @isassumption ($ arg_syms, DynamicPPL. getmissing ($ model), $ left)
412
+
413
+ if $ isassumption
414
+ $ out = DynamicPPL. dot_tilde_assume ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi)
399
415
$ left .= $ out[1 ]
400
416
DynamicPPL. acclogp! ($ vi, $ out[2 ])
401
417
else
402
- $ temp_left = $ preprocessed
403
418
DynamicPPL. acclogp! (
404
419
$ vi,
405
- DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left , $ vi),
420
+ DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds , $ vi),
406
421
)
407
422
end
408
423
end
409
424
else
425
+ # we have a literal, which is automatically an observation
410
426
ex = quote
411
- $ temp_left = $ left
412
427
$ temp_right = $ right
413
428
$ assert_ex
429
+
414
430
DynamicPPL. acclogp! (
415
431
$ vi,
416
- DynamicPPL. dot_tilde ($ ctx, $ sampler, $ temp_right, $ temp_left , $ vi),
432
+ DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left , $ vi),
417
433
)
418
434
end
419
435
end
0 commit comments