@@ -8,19 +8,19 @@ macro prob_str(str)
8
8
end
9
9
10
10
function get_exprs (str:: String )
11
- ind = findfirst (isequal (' |' ), str)
12
- ind === nothing && throw (" Invalid expression." )
11
+ ind = findfirst (isequal (' |' ), str)
12
+ ind === nothing && throw (" Invalid expression." )
13
13
14
- str1 = str[1 : (ind - 1 )]
15
- str2 = str[(ind + 1 ): end ]
14
+ str1 = str[1 : (ind - 1 )]
15
+ str2 = str[(ind + 1 ): end ]
16
16
17
- expr1 = Meta. parse (" ($str1 ,)" )
18
- expr1 = Expr (:tuple , expr1. args... )
17
+ expr1 = Meta. parse (" ($str1 ,)" )
18
+ expr1 = Expr (:tuple , expr1. args... )
19
19
20
- expr2 = Meta. parse (" ($str2 ,)" )
21
- expr2 = Expr (:tuple , expr2. args... )
20
+ expr2 = Meta. parse (" ($str2 ,)" )
21
+ expr2 = Expr (:tuple , expr2. args... )
22
22
23
- return expr1, expr2
23
+ return expr1, expr2
24
24
end
25
25
26
26
function logprob (ex1, ex2)
@@ -37,7 +37,7 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
37
37
if isdefined (ntr. chain. info, :model )
38
38
model = ntr. chain. info. model
39
39
@assert model isa Model
40
- modelgen = model. modelgen
40
+ modelgen = getgenerator ( model)
41
41
elseif isdefined (ntr, :model )
42
42
modelgen = ntr. model
43
43
else
@@ -54,10 +54,10 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
54
54
else
55
55
vi = nothing
56
56
end
57
- defaults = modelgen. defaults
57
+ defaults = getdefaults ( modelgen)
58
58
valid_arg (arg) = isdefined (ntl, arg) || isdefined (ntr, arg) ||
59
59
isdefined (defaults, arg) && getfield (defaults, arg) != = missing
60
- @assert all (valid_arg .(modelgen. args ))
60
+ @assert all (valid_arg .(getargtypes ( modelgen) ))
61
61
return Val (:likelihood ), modelgen, vi
62
62
else
63
63
@assert isdefined (ntr, :model )
@@ -69,15 +69,16 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
69
69
else
70
70
vi = nothing
71
71
end
72
- return probtype (ntl, ntr, modelgen, modelgen. defaults ), modelgen, vi
72
+ return probtype (ntl, ntr, modelgen, getdefaults ( modelgen) ), modelgen, vi
73
73
end
74
74
end
75
75
function probtype (
76
- ntl:: NamedTuple{namesl} ,
77
- ntr:: NamedTuple{namesr} ,
78
- modelgen:: ModelGen{args} ,
79
- defaults:: NamedTuple{defs} ,
80
- ) where {namesl, namesr, args, defs}
76
+ ntl:: NamedTuple{namesl} ,
77
+ ntr:: NamedTuple{namesr} ,
78
+ modelgen,
79
+ defaults:: NamedTuple{defs}
80
+ ) where {namesl, namesr, defs}
81
+ args = getargtypes (modelgen)
81
82
prior_rhs = all (n -> n in (:model , :varinfo ) ||
82
83
n in args && getfield (ntr, n) != = missing , namesr)
83
84
function get_arg (arg)
@@ -120,7 +121,7 @@ missing_arg_error_msg(arg, ::Nothing) = """Variable $arg is not defined and has
120
121
function logprior (
121
122
left:: NamedTuple ,
122
123
right:: NamedTuple ,
123
- modelgen:: ModelGen ,
124
+ modelgen,
124
125
_vi:: Union{Nothing, VarInfo} ,
125
126
)
126
127
# For model args on the LHS of |, use their passed value but add the symbol to
@@ -135,8 +136,8 @@ function logprior(
135
136
# All `observe` and `dot_observe` calls are no-op in the PriorContext
136
137
137
138
# When all of model args are on the lhs of |, this is also equal to the logjoint.
138
- args, missing_vars = get_prior_model_args (left, right, modelgen, modelgen. defaults )
139
- model = get_model (modelgen, args, missing_vars)
139
+ args, missing_vars = get_prior_model_args (left, right, Val {getargtypes( modelgen)} (), getdefaults ( modelgen) )
140
+ model = Model {typeof (modelgen)} ( args, missing_vars)
140
141
vi = _vi === nothing ? VarInfo (deepcopy (model), PriorContext ()) : _vi
141
142
foreach (keys (vi. metadata)) do n
142
143
@assert n in keys (left) " Variable $n is not defined."
147
148
@generated function get_prior_model_args (
148
149
left:: NamedTuple{namesl} ,
149
150
right:: NamedTuple{namesr} ,
150
- modelgen :: ModelGen {args} ,
151
- defaults:: NamedTuple{default_args} ,
151
+ :: Val {args} ,
152
+ defaults:: NamedTuple{default_args}
152
153
) where {namesl, namesr, args, default_args}
153
154
exprs = []
154
155
missing_args = []
@@ -182,20 +183,15 @@ end
182
183
183
184
warn_msg (arg) = " Argument $arg is not defined. A value of `nothing` is used."
184
185
185
- function get_model (modelgen, args, missing_vars)
186
- _model = modelgen (; args... )
187
- return Model (_model. f, args, modelgen, missing_vars)
188
- end
189
-
190
186
function loglikelihood (
191
187
left:: NamedTuple ,
192
188
right:: NamedTuple ,
193
- modelgen:: ModelGen ,
189
+ modelgen,
194
190
_vi:: Union{Nothing, VarInfo} ,
195
191
)
196
192
# Pass namesl to model constructor, remaining args are missing
197
- args, missing_vars = get_like_model_args (left, right, modelgen, modelgen. defaults )
198
- model = get_model (modelgen, args, missing_vars)
193
+ args, missing_vars = get_like_model_args (left, right, Val {getargtypes( modelgen)} (), getdefaults ( modelgen) )
194
+ model = Model {typeof (modelgen)} ( args, missing_vars)
199
195
vi = _vi === nothing ? VarInfo (deepcopy (model)) : _vi
200
196
if isdefined (right, :chain )
201
197
# Element-wise likelihood for each value in chain
218
214
@generated function get_like_model_args (
219
215
left:: NamedTuple{namesl} ,
220
216
right:: NamedTuple{namesr} ,
221
- modelgen :: ModelGen {args} ,
217
+ :: Val {args} ,
222
218
defaults:: NamedTuple{default_args} ,
223
219
) where {namesl, namesr, args, default_args}
224
220
exprs = []
0 commit comments