@@ -11,59 +11,42 @@ function _error_msg()
11
11
return " This macro is only for use in the `@model` macro and not for external use."
12
12
end
13
13
14
-
15
-
16
- # Check if the right-hand side is a distribution.
17
- function assert_dist (dist; msg)
18
- isa (dist, Distribution) || throw (ArgumentError (msg))
19
- end
20
- function assert_dist (dist:: AbstractVector ; msg)
21
- all (d -> isa (d, Distribution), dist) || throw (ArgumentError (msg))
22
- end
23
-
24
- function wrong_dist_errormsg (l)
25
- return " Right-hand side of a ~ must be subtype of Distribution or a vector of " *
26
- " Distributions on line $(l) ."
27
- end
14
+ const DISTMSG = " Right-hand side of a ~ must be subtype of Distribution or a vector of " *
15
+ " Distributions."
28
16
29
17
"""
30
- @isassumption(model, expr)
18
+ isassumption(model, expr)
19
+
20
+ Return an expression that can be evaluated to check if `expr` is an assumption in the
21
+ `model`.
31
22
32
- Let `expr` be `x[1]`. `vn` is an assumption in the following cases:
33
- 1. `x` was not among the input data to the model,
34
- 2. `x` was among the input data to the model but with a value `missing`, or
35
- 3. `x` was among the input data to the model with a value other than missing,
23
+ Let `expr` be `:( x[1]) `. It is an assumption in the following cases:
24
+ 1. `x` is not among the input data to the ` model` ,
25
+ 2. `x` is among the input data to the ` model` but with a value `missing`, or
26
+ 3. `x` is among the input data to the ` model` with a value other than missing,
36
27
but `x[1] === missing`.
28
+
37
29
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
38
30
"""
39
- macro isassumption (model, expr:: Union{Symbol, Expr} )
40
- # Note: never put a return in this... don't forget it's a macro!
31
+ function isassumption (model, expr:: Union{Symbol, Expr} )
41
32
vn = gensym (:vn )
42
-
33
+
43
34
return quote
44
- $ vn = @varname ($ expr)
45
-
46
- # This branch should compile nicely in all cases except for partial missing data
47
- # For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48
- if ! DynamicPPL. inargnames ($ vn, $ model) || DynamicPPL. inmissings ($ vn, $ model)
49
- true
50
- else
51
- if DynamicPPL. inargnames ($ vn, $ model)
52
- # Evaluate the lhs
53
- $ expr === missing
35
+ let $ vn = $ (varname (expr))
36
+ # This branch should compile nicely in all cases except for partial missing data
37
+ # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
38
+ if ! $ (DynamicPPL. inargnames)($ vn, $ model) || $ (DynamicPPL. inmissings)($ vn, $ model)
39
+ true
54
40
else
55
- throw (" This point should not be reached. Please report this error." )
41
+ # Evaluate the LHS
42
+ $ expr === missing
56
43
end
57
44
end
58
- end |> esc
59
- end
60
-
61
- macro isassumption (model, expr)
62
- # failsafe: a literal is never an assumption
63
- false
45
+ end
64
46
end
65
47
66
-
48
+ # failsafe: a literal is never an assumption
49
+ isassumption (model, expr) = :(false )
67
50
68
51
# ################
69
52
# Main Compiler #
@@ -128,7 +111,7 @@ function build_model_info(input_expr)
128
111
Expr (:tuple , QuoteNode .(arg_syms)... ),
129
112
Expr (:curly , :Tuple , [:(Core. Typeof ($ x)) for x in arg_syms]. .. )
130
113
)
131
- args_nt = Expr (:call , :(DynamicPPL . namedtuple), nt_type, Expr (:tuple , arg_syms... ))
114
+ args_nt = Expr (:call , :($ namedtuple), nt_type, Expr (:tuple , arg_syms... ))
132
115
end
133
116
args = map (modeldef[:args ]) do arg
134
117
if (arg isa Symbol)
@@ -217,7 +200,7 @@ function replace_logpdf!(model_info)
217
200
vi = model_info[:main_body_names ][:vi ]
218
201
ex = MacroTools. postwalk (ex) do x
219
202
if @capture (x, @logpdf ())
220
- :($ vi . logp[])
203
+ :($ (vi) . logp[])
221
204
else
222
205
x
223
206
end
@@ -261,14 +244,14 @@ function replace_tilde!(model_info)
261
244
dotargs = getargs_dottilde (x)
262
245
if dotargs != = nothing
263
246
L, R = dotargs
264
- return generate_dot_tilde (L, R, model_info)
247
+ return Base . remove_linenums! ( generate_dot_tilde (L, R, model_info) )
265
248
end
266
249
267
250
# Check tilde.
268
251
args = getargs_tilde (x)
269
252
if args != = nothing
270
253
L, R = args
271
- return generate_tilde (L, R, model_info)
254
+ return Base . remove_linenums! ( generate_tilde (L, R, model_info) )
272
255
end
273
256
274
257
return x
@@ -294,45 +277,55 @@ function generate_tilde(left, right, model_info)
294
277
vi = model_info[:main_body_names ][:vi ]
295
278
ctx = model_info[:main_body_names ][:ctx ]
296
279
sampler = model_info[:main_body_names ][:sampler ]
297
- temp_right = gensym (:temp_right )
298
- out = gensym (:out )
299
- lp = gensym (:lp )
300
- vn = gensym (:vn )
301
- inds = gensym (:inds )
302
- isassumption = gensym (:isassumption )
303
- assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
304
-
280
+
281
+ @gensym tmpright
282
+ top = [:($ tmpright = $ right),
283
+ :($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
284
+ || throw (ArgumentError ($ DISTMSG)))]
285
+
305
286
if left isa Symbol || left isa Expr
306
- ex = quote
307
- $ temp_right = $ right
308
- $ assert_ex
309
-
310
- $ vn, $ inds = $ (varname (left)), $ (vinds (left))
311
- $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
312
- if $ isassumption
313
- $ out = DynamicPPL. tilde_assume ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
314
- $ left = $ out[1 ]
315
- DynamicPPL. acclogp! ($ vi, $ out[2 ])
316
- else
317
- DynamicPPL. acclogp! (
318
- $ vi,
319
- DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi),
320
- )
287
+ @gensym out vn inds
288
+ push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
289
+
290
+ assumption = [
291
+ :($ out = $ (DynamicPPL. tilde_assume)($ ctx, $ sampler, $ tmpright, $ vn, $ inds,
292
+ $ vi)),
293
+ :($ left = $ out[1 ]),
294
+ :($ (DynamicPPL. acclogp!)($ vi, $ out[2 ]))
295
+ ]
296
+
297
+ # It can only be an observation if the LHS is an argument of the model
298
+ if vsym (left) in model_info[:args ]
299
+ @gensym isassumption
300
+ return quote
301
+ $ (top... )
302
+ $ isassumption = $ (DynamicPPL. isassumption (model, left))
303
+ if $ isassumption
304
+ $ (assumption... )
305
+ else
306
+ $ (DynamicPPL. acclogp!)(
307
+ $ vi,
308
+ $ (DynamicPPL. tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vn,
309
+ $ inds, $ vi)
310
+ )
311
+ end
321
312
end
322
313
end
323
- else
324
- # we have a literal, which is automatically an observation
325
- ex = quote
326
- $ temp_right = $ right
327
- $ assert_ex
328
-
329
- DynamicPPL. acclogp! (
330
- $ vi,
331
- DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
332
- )
314
+
315
+ return quote
316
+ $ (top... )
317
+ $ (assumption... )
333
318
end
334
319
end
335
- return ex
320
+
321
+ # If the LHS is a literal, it is always an observation
322
+ return quote
323
+ $ (top... )
324
+ $ (DynamicPPL. acclogp!)(
325
+ $ vi,
326
+ $ (DynamicPPL. tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vi)
327
+ )
328
+ end
336
329
end
337
330
338
331
"""
@@ -347,46 +340,55 @@ function generate_dot_tilde(left, right, model_info)
347
340
vi = model_info[:main_body_names ][:vi ]
348
341
ctx = model_info[:main_body_names ][:ctx ]
349
342
sampler = model_info[:main_body_names ][:sampler ]
350
- out = gensym (:out )
351
- temp_right = gensym (:temp_right )
352
- isassumption = gensym (:isassumption )
353
- lp = gensym (:lp )
354
- vn = gensym (:vn )
355
- inds = gensym (:inds )
356
- assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
357
-
343
+
344
+ @gensym tmpright
345
+ top = [:($ tmpright = $ right),
346
+ :($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
347
+ || throw (ArgumentError ($ DISTMSG)))]
348
+
358
349
if left isa Symbol || left isa Expr
359
- ex = quote
360
- $ temp_right = $ right
361
- $ assert_ex
362
-
363
- $ vn, $ inds = $ (varname (left)), $ (vinds (left))
364
- $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
365
-
366
- if $ isassumption
367
- $ out = DynamicPPL. dot_tilde_assume ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi)
368
- $ left .= $ out[1 ]
369
- DynamicPPL. acclogp! ($ vi, $ out[2 ])
370
- else
371
- DynamicPPL. acclogp! (
372
- $ vi,
373
- DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi),
374
- )
350
+ @gensym out vn inds
351
+ push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
352
+
353
+ assumption = [
354
+ :($ out = $ (DynamicPPL. dot_tilde_assume)($ ctx, $ sampler, $ tmpright, $ left,
355
+ $ vn, $ inds, $ vi)),
356
+ :($ left .= $ out[1 ]),
357
+ :($ (DynamicPPL. acclogp!)($ vi, $ out[2 ]))
358
+ ]
359
+
360
+ # It can only be an observation if the LHS is an argument of the model
361
+ if vsym (left) in model_info[:args ]
362
+ @gensym isassumption
363
+ return quote
364
+ $ (top... )
365
+ $ isassumption = $ (DynamicPPL. isassumption (model, left))
366
+ if $ isassumption
367
+ $ (assumption... )
368
+ else
369
+ $ (DynamicPPL. acclogp!)(
370
+ $ vi,
371
+ $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler, $ tmpright, $ left,
372
+ $ vn, $ inds, $ vi)
373
+ )
374
+ end
375
375
end
376
376
end
377
- else
378
- # we have a literal, which is automatically an observation
379
- ex = quote
380
- $ temp_right = $ right
381
- $ assert_ex
382
-
383
- DynamicPPL. acclogp! (
384
- $ vi,
385
- DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
386
- )
377
+
378
+ return quote
379
+ $ (top... )
380
+ $ (assumption... )
387
381
end
388
382
end
389
- return ex
383
+
384
+ # If the LHS is a literal, it is always an observation
385
+ return quote
386
+ $ (top... )
387
+ $ (DynamicPPL. acclogp!)(
388
+ $ vi,
389
+ $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vi)
390
+ )
391
+ end
390
392
end
391
393
392
394
const FloatOrArrayType = Type{<: Union{AbstractFloat, AbstractArray} }
@@ -425,42 +427,29 @@ function build_output(model_info)
425
427
426
428
unwrap_data_expr = Expr (:block )
427
429
for var in arg_syms
428
- temp_var = gensym (:temp_var )
429
- varT = gensym (:varT )
430
- push! (unwrap_data_expr. args, quote
431
- local $ var
432
- $ temp_var = $ model. args.$ var
433
- $ varT = typeof ($ temp_var)
434
- if $ temp_var isa DynamicPPL. FloatOrArrayType
435
- $ var = DynamicPPL. get_matching_type ($ sampler, $ vi, $ temp_var)
436
- elseif DynamicPPL. hasmissing ($ varT)
437
- $ var = DynamicPPL. get_matching_type ($ sampler, $ vi, $ varT)($ temp_var)
438
- else
439
- $ var = $ temp_var
440
- end
441
- end )
430
+ push! (unwrap_data_expr. args,
431
+ :($ var = $ (DynamicPPL. matchingvalue)($ sampler, $ vi, $ (model). args.$ var)))
442
432
end
443
433
444
434
@gensym (evaluator, generator)
445
435
generator_kw_form = isempty (args) ? () : (:($ generator (;$ (args... )) = $ generator ($ (arg_syms... ))),)
446
- model_gen_constructor = :(DynamicPPL. ModelGen {$(Tuple(arg_syms))} ($ generator, $ defaults_nt))
447
-
436
+ model_gen_constructor = :($ ( DynamicPPL. ModelGen) {$ (Tuple (arg_syms))}($ generator, $ defaults_nt))
437
+
448
438
ex = quote
449
439
function $evaluator (
450
- $ model:: Model ,
451
- $ vi:: DynamicPPL.VarInfo ,
452
- $ sampler:: DynamicPPL.AbstractSampler ,
453
- $ ctx:: DynamicPPL.AbstractContext ,
440
+ $ model:: $ (DynamicPPL . Model) ,
441
+ $ vi:: $ ( DynamicPPL. VarInfo) ,
442
+ $ sampler:: $ ( DynamicPPL. AbstractSampler) ,
443
+ $ ctx:: $ ( DynamicPPL. AbstractContext) ,
454
444
)
455
445
$ unwrap_data_expr
456
- DynamicPPL. resetlogp! ($ vi)
446
+ $ ( DynamicPPL. resetlogp!) ($ vi)
457
447
$ main_body
458
448
end
459
-
460
449
461
- $ generator ($ (args... )) = DynamicPPL. Model ($ evaluator, $ args_nt, $ model_gen_constructor)
450
+ $ generator ($ (args... )) = $ ( DynamicPPL. Model) ($ evaluator, $ args_nt, $ model_gen_constructor)
462
451
$ (generator_kw_form... )
463
-
452
+
464
453
$ model_gen = $ model_gen_constructor
465
454
end
466
455
@@ -475,6 +464,21 @@ function warn_empty(body)
475
464
return
476
465
end
477
466
467
+ """
468
+ matchingvalue(sampler, vi, value)
469
+
470
+ Convert the `value` to the correct type for the `sampler` and the `vi` object.
471
+ """
472
+ function matchingvalue (sampler, vi, value)
473
+ T = typeof (value)
474
+ if hasmissing (T)
475
+ return get_matching_type (sampler, vi, T)(value)
476
+ else
477
+ return value
478
+ end
479
+ end
480
+ matchingvalue (sampler, vi, value:: FloatOrArrayType ) = get_matching_type (sampler, vi, value)
481
+
478
482
"""
479
483
get_matching_type(spl, vi, ::Type{T}) where {T}
480
484
Get the specialized version of type `T` for sampler `spl`. For example,
0 commit comments