1
- """
2
- struct ModelGen{Targs, F, Tdefaults} <: Function
3
- f::F
4
- defaults::Tdefaults
5
- end
6
-
7
- A `Model` generator. This is the output of the `@model` macro. `Targs` is the tuple
8
- of the symbols of the model's arguments. `defaults` is the `NamedTuple` of default values
9
- of the arguments, if any. Every `ModelGen` is callable with the arguments `Targs`,
10
- returning an instance of `Model`.
11
- """
12
- struct ModelGen{Targs, F, Tdefaults} <: Function
13
- f:: F
14
- defaults:: Tdefaults
15
- end
16
- ModelGen {Targs} (args... ) where {Targs} = ModelGen {Targs, typeof.(args)...} (args... )
17
- (m:: ModelGen )(args... ; kwargs... ) = m. f (args... ; kwargs... )
18
- function Base. getproperty (m:: ModelGen{Targs} , f:: Symbol ) where {Targs}
19
- f === :args && return Targs
20
- return Base. getfield (m, f)
21
- end
22
-
23
1
macro varinfo ()
24
2
:(throw (_error_msg ()))
25
3
end
@@ -151,7 +129,7 @@ function build_model_info(input_expr)
151
129
else
152
130
nt_type = Expr (:curly , :NamedTuple ,
153
131
Expr (:tuple , QuoteNode .(arg_syms)... ),
154
- Expr (:curly , :Tuple , [:(DynamicPPL . get_type ($ x)) for x in arg_syms]. .. )
132
+ Expr (:curly , :Tuple , [:(Core . Typeof ($ x)) for x in arg_syms]. .. )
155
133
)
156
134
args_nt = Expr (:call , :(DynamicPPL. namedtuple), nt_type, Expr (:tuple , arg_syms... ))
157
135
end
@@ -220,7 +198,7 @@ function to_namedtuple_expr(syms::Vector, vals = syms)
220
198
else
221
199
nt_type = Expr (:curly , :NamedTuple ,
222
200
Expr (:tuple , QuoteNode .(syms)... ),
223
- Expr (:curly , :Tuple , [:(DynamicPPL . get_type ($ x)) for x in vals]. .. )
201
+ Expr (:curly , :Tuple , [:(Core . Typeof ($ x)) for x in vals]. .. )
224
202
)
225
203
nt = Expr (:call , :(DynamicPPL. namedtuple), nt_type, Expr (:tuple , vals... ))
226
204
end
@@ -440,7 +418,6 @@ function build_output(model_info)
440
418
vi = main_body_names[:vi ]
441
419
model = main_body_names[:model ]
442
420
sampler = main_body_names[:sampler ]
443
- model_function = main_body_names[:model_function ]
444
421
445
422
# Arguments with default values
446
423
args = model_info[:args ]
@@ -455,16 +432,9 @@ function build_output(model_info)
455
432
whereparams = model_info[:whereparams ]
456
433
# Model generator name
457
434
model_gen = model_info[:name ]
458
- # Outer function name
459
- outer_function = gensym (model_info[:name ])
460
435
# Main body of the model
461
436
main_body = model_info[:main_body ]
462
- model_gen_constructor = quote
463
- DynamicPPL. ModelGen {$(Tuple(arg_syms))} (
464
- $ outer_function,
465
- $ defaults_nt,
466
- )
467
- end
437
+
468
438
unwrap_data_expr = Expr (:block )
469
439
for var in arg_syms
470
440
temp_var = gensym (:temp_var )
@@ -483,44 +453,35 @@ function build_output(model_info)
483
453
end )
484
454
end
485
455
456
+ model_type = :(DynamicPPL. Model{$ (QuoteNode (model_function))})
457
+ defaults = gensym (:defaults )
458
+
486
459
ex = quote
487
- function (:: DynamicPPL.ModelFunction{$(QuoteNode(model_function))} )(
460
+ function ($ model :: $model_type )(
488
461
$ vi:: DynamicPPL.VarInfo ,
489
462
$ sampler:: DynamicPPL.AbstractSampler ,
490
463
$ ctx:: DynamicPPL.AbstractContext ,
491
- $ model
492
464
)
493
465
$ unwrap_data_expr
494
466
DynamicPPL. resetlogp! ($ vi)
495
467
$ main_body
496
468
end
497
-
498
- $ model_function = DynamicPPL. ModelFunction {$(QuoteNode(model_function))} ()
499
-
500
- function $outer_function ($ (args... ))
501
- return DynamicPPL. Model ($ model_function, $ args_nt, $ model_gen_constructor)
502
- end
503
-
504
- $ model_gen = $ model_gen_constructor
469
+
470
+ $ defaults = $ defaults_nt
471
+ getdefaults (:: $model_type ) = $ defaults
472
+ $ model_gen ($ (args... )) = $ model_type ($ args_nt, $ defaults)
505
473
end
506
474
507
475
if ! isempty (args)
508
476
# Allows passing arguments as kwargs
509
- kwform = quote
510
- $ outer_function (;$ (args... )) = $ outer_function ($ (arg_syms... ))
511
- end
477
+ kwform = :($ model_gen (;$ (args... )) = $ model_gen ($ (arg_syms... )))
512
478
push! (ex. args, kwform)
513
479
end
514
480
481
+
515
482
return esc (ex)
516
483
end
517
484
518
- # A hack for NamedTuple type specialization
519
- # (T = Int,) has type NamedTuple{(:T,), Tuple{DataType}} by default
520
- # With this function, we can make it NamedTuple{(:T,), Tuple{Type{Int}}}
521
- # Both are correct, but the latter is what we want for type stability
522
- get_type (:: Type{T} ) where {T} = Type{T}
523
- get_type (t) = typeof (t)
524
485
525
486
function warn_empty (body)
526
487
if all (l -> isa (l, LineNumberNode), body. args)
0 commit comments