@@ -27,43 +27,42 @@ function wrong_dist_errormsg(l)
27
27
end
28
28
29
29
"""
30
- @ isassumption(model, expr)
30
+ isassumption(model, expr)
31
31
32
- Let `expr` be `x[1]`. `vn` is an assumption in the following cases:
33
- 1. `x` was not among the input data to the model,
34
- 2. `x` was among the input data to the model but with a value `missing`, or
35
- 3. `x` was among the input data to the model with a value other than missing,
32
+ Return an expression that can be evaluated to check if `expr` is an assumption in the
33
+ `model`.
34
+
35
+ Let `expr` be `:(x[1])`. It is an assumption in the following cases:
36
+ 1. `x` is not among the input data to the `model`,
37
+ 2. `x` is among the input data to the `model` but with a value `missing`, or
38
+ 3. `x` is among the input data to the `model` with a value other than missing,
36
39
but `x[1] === missing`.
40
+
37
41
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
38
42
"""
39
- macro isassumption (model, expr:: Union{Symbol, Expr} )
40
- # Note: never put a return in this... don't forget it's a macro!
43
+ function isassumption (model, expr:: Union{Symbol, Expr} )
41
44
vn = gensym (:vn )
42
-
45
+
43
46
return quote
44
- $ vn = @varname ($ expr)
45
-
46
- # This branch should compile nicely in all cases except for partial missing data
47
- # For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48
- if ! $ DynamicPPL. inargnames ($ vn, $ model) || $ DynamicPPL. inmissings ($ vn, $ model)
49
- true
50
- else
51
- if $ DynamicPPL. inargnames ($ vn, $ model)
52
- # Evaluate the lhs
53
- $ expr === missing
47
+ let $ vn = $ (varname (expr))
48
+ # This branch should compile nicely in all cases except for partial missing data
49
+ # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
50
+ if ! $ (DynamicPPL. inargnames)($ vn, $ model) || $ (DynamicPPL. inmissings)($ vn, $ model)
51
+ true
54
52
else
55
- throw (" This point should not be reached. Please report this error." )
53
+ if $ (DynamicPPL. inargnames)($ vn, $ model)
54
+ # Evaluate the lhs
55
+ $ expr === missing
56
+ else
57
+ throw (" This point should not be reached. Please report this error." )
58
+ end
56
59
end
57
60
end
58
- end |> esc
59
- end
60
-
61
- macro isassumption (model, expr)
62
- # failsafe: a literal is never an assumption
63
- false
61
+ end
64
62
end
65
63
66
-
64
+ # failsafe: a literal is never an assumption
65
+ isassumption (model, expr) = :(false )
67
66
68
67
# ################
69
68
# Main Compiler #
@@ -128,7 +127,7 @@ function build_model_info(input_expr)
128
127
Expr (:tuple , QuoteNode .(arg_syms)... ),
129
128
Expr (:curly , :Tuple , [:(Core. Typeof ($ x)) for x in arg_syms]. .. )
130
129
)
131
- args_nt = Expr (:call , :($ DynamicPPL . namedtuple), nt_type, Expr (:tuple , arg_syms... ))
130
+ args_nt = Expr (:call , :($ namedtuple), nt_type, Expr (:tuple , arg_syms... ))
132
131
end
133
132
args = map (modeldef[:args ]) do arg
134
133
if (arg isa Symbol)
@@ -217,7 +216,7 @@ function replace_logpdf!(model_info)
217
216
vi = model_info[:main_body_names ][:vi ]
218
217
ex = MacroTools. postwalk (ex) do x
219
218
if @capture (x, @logpdf ())
220
- :($ vi. logp[] )
219
+ :(getlogp ( $ vi) )
221
220
else
222
221
x
223
222
end
@@ -294,45 +293,58 @@ function generate_tilde(left, right, model_info)
294
293
vi = model_info[:main_body_names ][:vi ]
295
294
ctx = model_info[:main_body_names ][:ctx ]
296
295
sampler = model_info[:main_body_names ][:sampler ]
297
- temp_right = gensym (:temp_right )
298
- out = gensym (:out )
299
- lp = gensym (:lp )
300
- vn = gensym (:vn )
301
- inds = gensym (:inds )
302
- isassumption = gensym (:isassumption )
303
- assert_ex = :($ DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
296
+
297
+ @gensym tmpright
298
+ expr = quote
299
+ $ tmpright = $ right
300
+ $ (DynamicPPL. assert_dist)($ tmpright, msg = $ (wrong_dist_errormsg (@__LINE__ )))
301
+ end
304
302
305
303
if left isa Symbol || left isa Expr
306
- ex = quote
307
- $ temp_right = $ right
308
- $ assert_ex
309
-
310
- $ vn, $ inds = $ (varname (left)), $ (vinds (left))
311
- $ isassumption = $ DynamicPPL. @isassumption ($ model, $ left)
312
- if $ isassumption
313
- $ out = $ DynamicPPL. tilde_assume ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
314
- $ left = $ out[1 ]
315
- $ DynamicPPL. acclogp! ($ vi, $ out[2 ])
316
- else
317
- $ DynamicPPL. acclogp! (
318
- $ vi,
319
- $ DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi),
320
- )
304
+ @gensym out vn inds
305
+ push! (expr. args,
306
+ :($ vn = $ (varname (left))),
307
+ :($ inds = $ (vinds (left))))
308
+
309
+ assumption = quote
310
+ $ out = $ (DynamicPPL. tilde_assume)($ ctx, $ sampler, $ tmpright, $ vn, $ inds,
311
+ $ vi)
312
+ $ left = $ out[1 ]
313
+ $ (DynamicPPL. acclogp!)($ vi, $ out[2 ])
314
+ end
315
+
316
+ # It can only be an observation if the LHS is an argument of the model
317
+ if vsym (left) in model_info[:args ]
318
+ @gensym isassumption
319
+ return quote
320
+ $ expr
321
+ $ isassumption = $ (DynamicPPL. isassumption (model, left))
322
+ if $ isassumption
323
+ $ assumption
324
+ else
325
+ $ (DynamicPPL. acclogp!)(
326
+ $ vi,
327
+ $ (DynamicPPL. tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vn,
328
+ $ inds, $ vi)
329
+ )
330
+ end
321
331
end
322
332
end
323
- else
324
- # we have a literal, which is automatically an observation
325
- ex = quote
326
- $ temp_right = $ right
327
- $ assert_ex
328
-
329
- $ DynamicPPL. acclogp! (
330
- $ vi,
331
- $ DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
332
- )
333
+
334
+ return quote
335
+ $ expr
336
+ $ assumption
333
337
end
334
338
end
335
- return ex
339
+
340
+ # If the LHS is a literal, it is always an observation
341
+ return quote
342
+ $ expr
343
+ $ (DynamicPPL. acclogp!)(
344
+ $ vi,
345
+ $ (DynamicPPL. tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vi)
346
+ )
347
+ end
336
348
end
337
349
338
350
"""
@@ -347,46 +359,58 @@ function generate_dot_tilde(left, right, model_info)
347
359
vi = model_info[:main_body_names ][:vi ]
348
360
ctx = model_info[:main_body_names ][:ctx ]
349
361
sampler = model_info[:main_body_names ][:sampler ]
350
- out = gensym (:out )
351
- temp_right = gensym (:temp_right )
352
- isassumption = gensym (:isassumption )
353
- lp = gensym (:lp )
354
- vn = gensym (:vn )
355
- inds = gensym (:inds )
356
- assert_ex = :($ DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
357
362
358
- if left isa Symbol || left isa Expr
359
- ex = quote
360
- $ temp_right = $ right
361
- $ assert_ex
363
+ @gensym tmpright
364
+ expr = quote
365
+ $ tmpright = $ right
366
+ $ (DynamicPPL. assert_dist)($ tmpright, msg = $ (wrong_dist_errormsg (@__LINE__ )))
367
+ end
362
368
363
- $ vn, $ inds = $ (varname (left)), $ (vinds (left))
364
- $ isassumption = $ DynamicPPL. @isassumption ($ model, $ left)
369
+ if left isa Symbol || left isa Expr
370
+ @gensym out vn inds
371
+ push! (expr. args,
372
+ :($ vn = $ (varname (left))),
373
+ :($ inds = $ (vinds (left))))
374
+
375
+ assumption = quote
376
+ $ out = $ (DynamicPPL. dot_tilde_assume)($ ctx, $ sampler, $ tmpright, $ left,
377
+ $ vn, $ inds, $ vi)
378
+ $ left .= $ out[1 ]
379
+ $ (DynamicPPL. acclogp!)($ vi, $ out[2 ])
380
+ end
365
381
366
- if $ isassumption
367
- $ out = $ DynamicPPL. dot_tilde_assume ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi)
368
- $ left .= $ out[1 ]
369
- $ DynamicPPL. acclogp! ($ vi, $ out[2 ])
370
- else
371
- $ DynamicPPL. acclogp! (
372
- $ vi,
373
- $ DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi),
374
- )
382
+ # It can only be an observation if the LHS is an argument of the model
383
+ if vsym (left) in model_info[:args ]
384
+ @gensym isassumption
385
+ return quote
386
+ $ expr
387
+ $ isassumption = $ (DynamicPPL. isassumption (model, left))
388
+ if $ isassumption
389
+ $ assumption
390
+ else
391
+ $ (DynamicPPL. acclogp!)(
392
+ $ vi,
393
+ $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler, $ tmpright, $ left,
394
+ $ vn, $ inds, $ vi)
395
+ )
396
+ end
375
397
end
376
398
end
377
- else
378
- # we have a literal, which is automatically an observation
379
- ex = quote
380
- $ temp_right = $ right
381
- $ assert_ex
382
-
383
- $ DynamicPPL. acclogp! (
384
- $ vi,
385
- $ DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
386
- )
399
+
400
+ return quote
401
+ $ expr
402
+ $ assumption
387
403
end
388
404
end
389
- return ex
405
+
406
+ # If the LHS is a literal, it is always an observation
407
+ return quote
408
+ $ expr
409
+ $ (DynamicPPL. acclogp!)(
410
+ $ vi,
411
+ $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vi)
412
+ )
413
+ end
390
414
end
391
415
392
416
const FloatOrArrayType = Type{<: Union{AbstractFloat, AbstractArray} }
@@ -425,39 +449,27 @@ function build_output(model_info)
425
449
426
450
unwrap_data_expr = Expr (:block )
427
451
for var in arg_syms
428
- temp_var = gensym (:temp_var )
429
- varT = gensym (:varT )
430
- push! (unwrap_data_expr. args, quote
431
- local $ var
432
- $ temp_var = $ model. args.$ var
433
- $ varT = typeof ($ temp_var)
434
- if $ temp_var isa $ DynamicPPL. FloatOrArrayType
435
- $ var = $ DynamicPPL. get_matching_type ($ sampler, $ vi, $ temp_var)
436
- elseif $ DynamicPPL. hasmissing ($ varT)
437
- $ var = $ DynamicPPL. get_matching_type ($ sampler, $ vi, $ varT)($ temp_var)
438
- else
439
- $ var = $ temp_var
440
- end
441
- end )
452
+ push! (unwrap_data_expr. args,
453
+ :($ var = $ (DynamicPPL. matchingvalue)($ sampler, $ vi, $ (model). args.$ var)))
442
454
end
443
455
444
456
@gensym (evaluator, generator)
445
457
generator_kw_form = isempty (args) ? () : (:($ generator (;$ (args... )) = $ generator ($ (arg_syms... ))),)
446
- model_gen_constructor = :($ DynamicPPL. ModelGen {$(Tuple(arg_syms))} ($ generator, $ defaults_nt))
458
+ model_gen_constructor = :($ ( DynamicPPL. ModelGen) {$ (Tuple (arg_syms))}($ generator, $ defaults_nt))
447
459
448
460
ex = quote
449
461
function $evaluator (
450
- $ model:: $DynamicPPL.Model ,
451
- $ vi:: $DynamicPPL.VarInfo ,
452
- $ sampler:: $DynamicPPL.AbstractSampler ,
453
- $ ctx:: $DynamicPPL.AbstractContext ,
462
+ $ model:: $ ( DynamicPPL. Model) ,
463
+ $ vi:: $ ( DynamicPPL. VarInfo) ,
464
+ $ sampler:: $ ( DynamicPPL. AbstractSampler) ,
465
+ $ ctx:: $ ( DynamicPPL. AbstractContext) ,
454
466
)
455
467
$ unwrap_data_expr
456
- $ DynamicPPL. resetlogp! ($ vi)
468
+ $ ( DynamicPPL. resetlogp!) ($ vi)
457
469
$ main_body
458
470
end
459
471
460
- $ generator ($ (args... )) = $ DynamicPPL. Model ($ evaluator, $ args_nt, $ model_gen_constructor)
472
+ $ generator ($ (args... )) = $ ( DynamicPPL. Model) ($ evaluator, $ args_nt, $ model_gen_constructor)
461
473
$ (generator_kw_form... )
462
474
463
475
$ model_gen = $ model_gen_constructor
@@ -474,6 +486,21 @@ function warn_empty(body)
474
486
return
475
487
end
476
488
489
+ """
490
+ matchingvalue(sampler, vi, value)
491
+
492
+ Convert the `value` to the correct type for the `sampler` and the `vi` object.
493
+ """
494
+ function matchingvalue (sampler, vi, value)
495
+ T = typeof (value)
496
+ if hasmissing (T)
497
+ return get_matching_type (sampler, vi, T)(value)
498
+ else
499
+ return value
500
+ end
501
+ end
502
+ matchingvalue (sampler, vi, value:: FloatOrArrayType ) = get_matching_type (sampler, vi, value)
503
+
477
504
"""
478
505
get_matching_type(spl, vi, ::Type{T}) where {T}
479
506
Get the specialized version of type `T` for sampler `spl`. For example,
0 commit comments