1
- """
2
- struct ModelGen{Targs, F, Tdefaults} <: Function
3
- f::F
4
- defaults::Tdefaults
5
- end
6
-
7
- A `Model` generator. This is the output of the `@model` macro. `Targs` is the tuple
8
- of the symbols of the model's arguments. `defaults` is the `NamedTuple` of default values
9
- of the arguments, if any. Every `ModelGen` is callable with the arguments `Targs`,
10
- returning an instance of `Model`.
11
- """
12
- struct ModelGen{Targs, F, Tdefaults} <: Function
13
- f:: F
14
- defaults:: Tdefaults
15
- end
16
- ModelGen {Targs} (args... ) where {Targs} = ModelGen {Targs, typeof.(args)...} (args... )
17
- (m:: ModelGen )(args... ; kwargs... ) = m. f (args... ; kwargs... )
18
- function Base. getproperty (m:: ModelGen{Targs} , f:: Symbol ) where {Targs}
19
- f === :args && return Targs
20
- return Base. getfield (m, f)
21
- end
22
-
23
1
macro varinfo ()
24
2
:(throw (_error_msg ()))
25
3
end
@@ -61,18 +39,18 @@ Otherwise, the value of `x[1]` is returned.
61
39
macro preprocess (data_vars, missing_vars, ex)
62
40
ex
63
41
end
64
- macro preprocess (data_vars, missing_vars , ex:: Union{Symbol, Expr} )
42
+ macro preprocess (model , ex:: Union{Symbol, Expr} )
65
43
sym = gensym (:sym )
66
44
lhs = gensym (:lhs )
67
45
return esc (quote
68
46
# Extract symbol
69
47
$ sym = Val ($ (vsym (ex)))
70
48
# This branch should compile nicely in all cases except for partial missing data
71
49
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
72
- if ! DynamicPPL. inparams ($ sym, $ data_vars ) || DynamicPPL. inparams ($ sym, $ missing_vars )
50
+ if ! DynamicPPL. inargnames ($ sym, $ model ) || DynamicPPL. inmissings ($ sym, $ model )
73
51
$ (varname (ex)), $ (vinds (ex))
74
52
else
75
- if DynamicPPL. inparams ($ sym, $ data_vars )
53
+ if DynamicPPL. inargnames ($ sym, $ model )
76
54
# Evaluate the lhs
77
55
$ lhs = $ ex
78
56
if $ lhs === missing
@@ -86,9 +64,7 @@ macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
86
64
end
87
65
end )
88
66
end
89
- @generated function inparams (:: Val{s} , :: Val{t} ) where {s, t}
90
- return (s in t) ? :(true ) : :(false )
91
- end
67
+
92
68
93
69
# ################
94
70
# Main Compiler #
@@ -151,7 +127,7 @@ function build_model_info(input_expr)
151
127
else
152
128
nt_type = Expr (:curly , :NamedTuple ,
153
129
Expr (:tuple , QuoteNode .(arg_syms)... ),
154
- Expr (:curly , :Tuple , [:(DynamicPPL . get_type ($ x)) for x in arg_syms]. .. )
130
+ Expr (:curly , :Tuple , [:(Core . Typeof ($ x)) for x in arg_syms]. .. )
155
131
)
156
132
args_nt = Expr (:call , :(DynamicPPL. namedtuple), nt_type, Expr (:tuple , arg_syms... ))
157
133
end
@@ -205,27 +181,13 @@ function build_model_info(input_expr)
205
181
:ctx => gensym (:ctx ),
206
182
:vi => gensym (:vi ),
207
183
:sampler => gensym (:sampler ),
208
- :model => gensym (:model ),
209
- :inner_function => gensym (:inner_function ),
210
- :defaults => gensym (:defaults )
184
+ :model => gensym (:model )
211
185
)
212
186
)
213
187
214
188
return model_info
215
189
end
216
190
217
- function to_namedtuple_expr (syms:: Vector , vals = syms)
218
- if length (syms) == 0
219
- nt = :(NamedTuple ())
220
- else
221
- nt_type = Expr (:curly , :NamedTuple ,
222
- Expr (:tuple , QuoteNode .(syms)... ),
223
- Expr (:curly , :Tuple , [:(DynamicPPL. get_type ($ x)) for x in vals]. .. )
224
- )
225
- nt = Expr (:call , :(DynamicPPL. namedtuple), nt_type, Expr (:tuple , vals... ))
226
- end
227
- return nt
228
- end
229
191
230
192
"""
231
193
replace_vi!(model_info)
@@ -319,14 +281,16 @@ function replace_tilde!(model_info)
319
281
end
320
282
""" |> Meta. parse |> eval
321
283
284
+ # """ Unbreak code highlighting in Emacs julia-mode
285
+
286
+
322
287
"""
323
288
generate_tilde(left, right, model_info)
324
289
325
290
The `tilde` function generates `observe` expression for data variables and `assume`
326
291
expressions for parameter variables, updating `model_info` in the process.
327
292
"""
328
293
function generate_tilde (left, right, model_info)
329
- arg_syms = Val ((model_info[:arg_syms ]. .. ,))
330
294
model = model_info[:main_body_names ][:model ]
331
295
vi = model_info[:main_body_names ][:vi ]
332
296
ctx = model_info[:main_body_names ][:ctx ]
@@ -342,7 +306,7 @@ function generate_tilde(left, right, model_info)
342
306
ex = quote
343
307
$ temp_right = $ right
344
308
$ assert_ex
345
- $ preprocessed = DynamicPPL. @preprocess ($ arg_syms, DynamicPPL . getmissing ( $ model) , $ left)
309
+ $ preprocessed = DynamicPPL. @preprocess ($ model, $ left)
346
310
if $ preprocessed isa Tuple
347
311
$ vn, $ inds = $ preprocessed
348
312
$ out = DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
374
338
This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
375
339
"""
376
340
function generate_dot_tilde (left, right, model_info)
377
- arg_syms = Val ((model_info[:arg_syms ]. .. ,))
378
341
model = model_info[:main_body_names ][:model ]
379
342
vi = model_info[:main_body_names ][:vi ]
380
343
ctx = model_info[:main_body_names ][:ctx ]
@@ -391,7 +354,7 @@ function generate_dot_tilde(left, right, model_info)
391
354
ex = quote
392
355
$ temp_right = $ right
393
356
$ assert_ex
394
- $ preprocessed = DynamicPPL. @preprocess ($ arg_syms, DynamicPPL . getmissing ( $ model) , $ left)
357
+ $ preprocessed = DynamicPPL. @preprocess ($ model, $ left)
395
358
if $ preprocessed isa Tuple
396
359
$ vn, $ inds = $ preprocessed
397
360
$ temp_left = $ left
@@ -437,7 +400,6 @@ function build_output(model_info)
437
400
vi = main_body_names[:vi ]
438
401
model = main_body_names[:model ]
439
402
sampler = main_body_names[:sampler ]
440
- inner_function = main_body_names[:inner_function ]
441
403
442
404
# Arguments with default values
443
405
args = model_info[:args ]
@@ -452,16 +414,9 @@ function build_output(model_info)
452
414
whereparams = model_info[:whereparams ]
453
415
# Model generator name
454
416
model_gen = model_info[:name ]
455
- # Outer function name
456
- outer_function = gensym (model_info[:name ])
457
417
# Main body of the model
458
418
main_body = model_info[:main_body ]
459
- model_gen_constructor = quote
460
- DynamicPPL. ModelGen {$(Tuple(arg_syms))} (
461
- $ outer_function,
462
- $ defaults_nt,
463
- )
464
- end
419
+
465
420
unwrap_data_expr = Expr (:block )
466
421
for var in arg_syms
467
422
temp_var = gensym (:temp_var )
@@ -480,40 +435,32 @@ function build_output(model_info)
480
435
end )
481
436
end
482
437
438
+ @gensym (evaluator, generator)
439
+ generator_kw_form = isempty (args) ? () : (:($ generator (;$ (args... )) = $ generator ($ (arg_syms... ))),)
440
+ model_gen_constructor = :(DynamicPPL. ModelGen {$(Tuple(arg_syms))} ($ generator, $ defaults_nt))
441
+
483
442
ex = quote
484
- function $outer_function ($ (args... ))
485
- function $inner_function (
486
- $ vi:: DynamicPPL.VarInfo ,
487
- $ sampler:: DynamicPPL.AbstractSampler ,
488
- $ ctx:: DynamicPPL.AbstractContext ,
489
- $ model
490
- )
491
- $ unwrap_data_expr
492
- DynamicPPL. resetlogp! ($ vi)
493
- $ main_body
494
- end
495
- return DynamicPPL. Model ($ inner_function, $ args_nt, $ model_gen_constructor)
443
+ function $evaluator (
444
+ $ model:: Model ,
445
+ $ vi:: DynamicPPL.VarInfo ,
446
+ $ sampler:: DynamicPPL.AbstractSampler ,
447
+ $ ctx:: DynamicPPL.AbstractContext ,
448
+ )
449
+ $ unwrap_data_expr
450
+ DynamicPPL. resetlogp! ($ vi)
451
+ $ main_body
496
452
end
497
- $ model_gen = $ model_gen_constructor
498
- end
453
+
499
454
500
- if ! isempty (args)
501
- ex = quote
502
- $ ex
503
- # Allows passing arguments as kwargs
504
- $ outer_function (;$ (args... )) = $ outer_function ($ (arg_syms... ))
505
- end
455
+ $ generator ($ (args... )) = DynamicPPL. Model ($ evaluator, $ args_nt, $ model_gen_constructor)
456
+ $ (generator_kw_form... )
457
+
458
+ $ model_gen = $ model_gen_constructor
506
459
end
507
460
508
461
return esc (ex)
509
462
end
510
463
511
- # A hack for NamedTuple type specialization
512
- # (T = Int,) has type NamedTuple{(:T,), Tuple{DataType}} by default
513
- # With this function, we can make it NamedTuple{(:T,), Tuple{Type{Int}}}
514
- # Both are correct, but the latter is what we want for type stability
515
- get_type (:: Type{T} ) where {T} = Type{T}
516
- get_type (t) = typeof (t)
517
464
518
465
function warn_empty (body)
519
466
if all (l -> isa (l, LineNumberNode), body. args)
0 commit comments