1
- """
2
- struct ModelGen{Targs, F, Tdefaults} <: Function
3
- f::F
4
- defaults::Tdefaults
5
- end
6
-
7
- A `Model` generator. This is the output of the `@model` macro. `Targs` is the tuple
8
- of the symbols of the model's arguments. `defaults` is the `NamedTuple` of default values
9
- of the arguments, if any. Every `ModelGen` is callable with the arguments `Targs`,
10
- returning an instance of `Model`.
11
- """
12
- struct ModelGen{Targs, F, Tdefaults} <: Function
13
- f:: F
14
- defaults:: Tdefaults
15
- end
16
- ModelGen {Targs} (args... ) where {Targs} = ModelGen {Targs, typeof.(args)...} (args... )
17
- (m:: ModelGen )(args... ; kwargs... ) = m. f (args... ; kwargs... )
18
- function Base. getproperty (m:: ModelGen{Targs} , f:: Symbol ) where {Targs}
19
- f === :args && return Targs
20
- return Base. getfield (m, f)
21
- end
22
-
23
1
macro varinfo ()
24
2
:(throw (_error_msg ()))
25
3
end
@@ -61,18 +39,18 @@ When `ex` is not a variable (e.g., a literal), the function returns `false` as w
61
39
macro isassumption (data_vars, missing_vars, ex)
62
40
:false
63
41
end
64
- macro isassumption (data_vars, missing_vars , ex:: Union{Symbol, Expr} )
42
+ macro isassumption (model , ex:: Union{Symbol, Expr} )
65
43
sym = gensym (:sym )
66
44
lhs = gensym (:lhs )
67
45
return esc (quote
68
46
# Extract symbol
69
47
$ sym = Val ($ (vsym (ex)))
70
48
# This branch should compile nicely in all cases except for partial missing data
71
49
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
72
- if ! DynamicPPL. inparams ($ sym, $ data_vars ) || DynamicPPL. inparams ($ sym, $ missing_vars )
50
+ if ! DynamicPPL. inargnames ($ sym, $ model ) || DynamicPPL. inmissings ($ sym, $ model )
73
51
true
74
52
else
75
- if DynamicPPL. inparams ($ sym, $ data_vars )
53
+ if DynamicPPL. inargnames ($ sym, $ model )
76
54
# Evaluate the lhs
77
55
$ lhs = $ ex
78
56
if $ lhs === missing
@@ -87,10 +65,6 @@ macro isassumption(data_vars, missing_vars, ex::Union{Symbol, Expr})
87
65
end )
88
66
end
89
67
90
- @generated function inparams (:: Val{s} , :: Val{t} ) where {s, t}
91
- return (s in t) ? :(true ) : :(false )
92
- end
93
-
94
68
# ################
95
69
# Main Compiler #
96
70
# ################
@@ -152,7 +126,7 @@ function build_model_info(input_expr)
152
126
else
153
127
nt_type = Expr (:curly , :NamedTuple ,
154
128
Expr (:tuple , QuoteNode .(arg_syms)... ),
155
- Expr (:curly , :Tuple , [:(DynamicPPL . get_type ($ x)) for x in arg_syms]. .. )
129
+ Expr (:curly , :Tuple , [:(Core . Typeof ($ x)) for x in arg_syms]. .. )
156
130
)
157
131
args_nt = Expr (:call , :(DynamicPPL. namedtuple), nt_type, Expr (:tuple , arg_syms... ))
158
132
end
@@ -206,27 +180,13 @@ function build_model_info(input_expr)
206
180
:ctx => gensym (:ctx ),
207
181
:vi => gensym (:vi ),
208
182
:sampler => gensym (:sampler ),
209
- :model => gensym (:model ),
210
- :inner_function => gensym (:inner_function ),
211
- :defaults => gensym (:defaults )
183
+ :model => gensym (:model )
212
184
)
213
185
)
214
186
215
187
return model_info
216
188
end
217
189
218
- function to_namedtuple_expr (syms:: Vector , vals = syms)
219
- if length (syms) == 0
220
- nt = :(NamedTuple ())
221
- else
222
- nt_type = Expr (:curly , :NamedTuple ,
223
- Expr (:tuple , QuoteNode .(syms)... ),
224
- Expr (:curly , :Tuple , [:(DynamicPPL. get_type ($ x)) for x in vals]. .. )
225
- )
226
- nt = Expr (:call , :(DynamicPPL. namedtuple), nt_type, Expr (:tuple , vals... ))
227
- end
228
- return nt
229
- end
230
190
231
191
"""
232
192
replace_vi!(model_info)
@@ -330,19 +290,16 @@ The `tilde` function generates `observe` expression for data variables and `assu
330
290
expressions for parameter variables, updating `model_info` in the process.
331
291
"""
332
292
function generate_tilde (left, right, model_info)
333
- arg_syms = Val ((model_info[:arg_syms ]. .. ,))
334
293
model = model_info[:main_body_names ][:model ]
335
294
vi = model_info[:main_body_names ][:vi ]
336
295
ctx = model_info[:main_body_names ][:ctx ]
337
296
sampler = model_info[:main_body_names ][:sampler ]
338
-
339
- @gensym (out,
340
- lp,
341
- vn,
342
- inds,
343
- isassumption,
344
- temp_right)
345
-
297
+ temp_right = gensym (:temp_right )
298
+ out = gensym (:out )
299
+ lp = gensym (:lp )
300
+ vn = gensym (:vn )
301
+ inds = gensym (:inds )
302
+ isassumption = gensym (:isassumption )
346
303
assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
347
304
348
305
if left isa Symbol || left isa Expr
@@ -351,7 +308,7 @@ function generate_tilde(left, right, model_info)
351
308
$ assert_ex
352
309
353
310
$ vn, $ inds = $ (varname (left)), $ (vinds (left))
354
- $ isassumption = DynamicPPL. @isassumption ($ arg_syms, DynamicPPL . getmissing ( $ model) , $ left)
311
+ $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
355
312
if $ isassumption
356
313
$ out = DynamicPPL. tilde_assume ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
357
314
$ left = $ out[1 ]
@@ -386,20 +343,16 @@ This function returns the expression that replaces `left .~ right` in the model
386
343
will be run.
387
344
"""
388
345
function generate_dot_tilde (left, right, model_info)
389
- arg_syms = Val ((model_info[:arg_syms ]. .. ,))
390
346
model = model_info[:main_body_names ][:model ]
391
347
vi = model_info[:main_body_names ][:vi ]
392
348
ctx = model_info[:main_body_names ][:ctx ]
393
349
sampler = model_info[:main_body_names ][:sampler ]
394
-
395
- @gensym (out,
396
- preprocessed,
397
- lp,
398
- vn,
399
- inds,
400
- isassumption,
401
- temp_right)
402
-
350
+ out = gensym (:out )
351
+ temp_right = gensym (:temp_right )
352
+ isassumption = gensym (:isassumption )
353
+ lp = gensym (:lp )
354
+ vn = gensym (:vn )
355
+ inds = gensym (:inds )
403
356
assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
404
357
405
358
if left isa Symbol || left isa Expr
@@ -408,7 +361,7 @@ function generate_dot_tilde(left, right, model_info)
408
361
$ assert_ex
409
362
410
363
$ vn, $ inds = $ (varname (left)), $ (vinds (left))
411
- $ isassumption = DynamicPPL. @isassumption ($ arg_syms, DynamicPPL . getmissing ( $ model) , $ left)
364
+ $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
412
365
413
366
if $ isassumption
414
367
$ out = DynamicPPL. dot_tilde_assume ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi)
@@ -453,7 +406,6 @@ function build_output(model_info)
453
406
vi = main_body_names[:vi ]
454
407
model = main_body_names[:model ]
455
408
sampler = main_body_names[:sampler ]
456
- inner_function = main_body_names[:inner_function ]
457
409
458
410
# Arguments with default values
459
411
args = model_info[:args ]
@@ -468,16 +420,9 @@ function build_output(model_info)
468
420
whereparams = model_info[:whereparams ]
469
421
# Model generator name
470
422
model_gen = model_info[:name ]
471
- # Outer function name
472
- outer_function = gensym (model_info[:name ])
473
423
# Main body of the model
474
424
main_body = model_info[:main_body ]
475
- model_gen_constructor = quote
476
- DynamicPPL. ModelGen {$(Tuple(arg_syms))} (
477
- $ outer_function,
478
- $ defaults_nt,
479
- )
480
- end
425
+
481
426
unwrap_data_expr = Expr (:block )
482
427
for var in arg_syms
483
428
temp_var = gensym (:temp_var )
@@ -496,40 +441,32 @@ function build_output(model_info)
496
441
end )
497
442
end
498
443
444
+ @gensym (evaluator, generator)
445
+ generator_kw_form = isempty (args) ? () : (:($ generator (;$ (args... )) = $ generator ($ (arg_syms... ))),)
446
+ model_gen_constructor = :(DynamicPPL. ModelGen {$(Tuple(arg_syms))} ($ generator, $ defaults_nt))
447
+
499
448
ex = quote
500
- function $outer_function ($ (args... ))
501
- function $inner_function (
502
- $ vi:: DynamicPPL.VarInfo ,
503
- $ sampler:: DynamicPPL.AbstractSampler ,
504
- $ ctx:: DynamicPPL.AbstractContext ,
505
- $ model
506
- )
507
- $ unwrap_data_expr
508
- DynamicPPL. resetlogp! ($ vi)
509
- $ main_body
510
- end
511
- return DynamicPPL. Model ($ inner_function, $ args_nt, $ model_gen_constructor)
449
+ function $evaluator (
450
+ $ model:: Model ,
451
+ $ vi:: DynamicPPL.VarInfo ,
452
+ $ sampler:: DynamicPPL.AbstractSampler ,
453
+ $ ctx:: DynamicPPL.AbstractContext ,
454
+ )
455
+ $ unwrap_data_expr
456
+ DynamicPPL. resetlogp! ($ vi)
457
+ $ main_body
512
458
end
513
- $ model_gen = $ model_gen_constructor
514
- end
459
+
515
460
516
- if ! isempty (args)
517
- ex = quote
518
- $ ex
519
- # Allows passing arguments as kwargs
520
- $ outer_function (;$ (args... )) = $ outer_function ($ (arg_syms... ))
521
- end
461
+ $ generator ($ (args... )) = DynamicPPL. Model ($ evaluator, $ args_nt, $ model_gen_constructor)
462
+ $ (generator_kw_form... )
463
+
464
+ $ model_gen = $ model_gen_constructor
522
465
end
523
466
524
467
return esc (ex)
525
468
end
526
469
527
- # A hack for NamedTuple type specialization
528
- # (T = Int,) has type NamedTuple{(:T,), Tuple{DataType}} by default
529
- # With this function, we can make it NamedTuple{(:T,), Tuple{Type{Int}}}
530
- # Both are correct, but the latter is what we want for type stability
531
- get_type (:: Type{T} ) where {T} = Type{T}
532
- get_type (t) = typeof (t)
533
470
534
471
function warn_empty (body)
535
472
if all (l -> isa (l, LineNumberNode), body. args)
0 commit comments