Skip to content

Commit aba5e9f

Browse files
committed
Implement Mohamed`s suggestions
1 parent b177aa3 commit aba5e9f

File tree

3 files changed

+37
-38
lines changed

3 files changed

+37
-38
lines changed

src/compiler.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,6 @@ function build_model_info(input_expr)
188188
return model_info
189189
end
190190

191-
function to_namedtuple_expr(syms::Vector, vals = syms)
192-
if length(syms) == 0
193-
nt = :(NamedTuple())
194-
else
195-
nt_type = Expr(:curly, :NamedTuple,
196-
Expr(:tuple, QuoteNode.(syms)...),
197-
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...)
198-
)
199-
nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, vals...))
200-
end
201-
return nt
202-
end
203191

204192
"""
205193
replace_vi!(model_info)

src/prob_macro.jl

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,7 @@ function logprior(
133133
# All `observe` and `dot_observe` calls are no-op in the PriorContext
134134

135135
# When all of model args are on the lhs of |, this is also equal to the logjoint.
136-
args, missings = get_prior_model_args(left, right, modeltype)
137-
model = Model{G, missings}(args)
136+
model = make_prior_model(left, right, modeltype)
138137
vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi
139138
foreach(keys(vi.metadata)) do n
140139
@assert n in keys(left) "Variable $n is not defined."
@@ -143,35 +142,35 @@ function logprior(
143142
return getlogp(vi)
144143
end
145144

146-
@generated function get_prior_model_args(
145+
@generated function make_prior_model(
147146
left::NamedTuple{leftnames},
148147
right::NamedTuple{rightnames},
149-
modeltype::Type{<:Model{_G, argnames}},
148+
modeltype::Type{<:Model{G, argnames}},
150149
defaults::NamedTuple{defaultnames}=getdefaults(modeltype)
151-
) where {leftnames, rightnames, argnames, defaultnames, _G}
152-
args = []
150+
) where {leftnames, rightnames, G, argnames, defaultnames}
151+
argvals = []
153152
missings = []
154153
warnings = []
155154

156155
for argname in argnames
157156
if argname in leftnames
158-
push!(args, :($argname = deepcopy(left.$argname)))
157+
push!(argvals, :(deepcopy(left.$argname)))
159158
push!(missings, argname)
160159
elseif argname in rightnames
161-
push!(args, :($argname = right.$argname))
160+
push!(argvals, :(right.$argname))
162161
elseif argname in defaultnames
163-
push!(args, :($argname = defaults.$argname))
162+
push!(argvals, :(defaults.$argname))
164163
else
165164
push!(warnings, :(@warn($(warn_msg(argname)))))
166-
push!(args, :($argname = nothing))
165+
push!(argvals, :(nothing))
167166
end
168167
end
169168

170-
# `args` is spatted as a NamedTuple expression; `missings` is splatted into a tuple and
171-
# inserted as literal
169+
# `args` is inserted as properly typed NamedTuple expression;
170+
# `missings` is splatted into a tuple at compile time and inserted as literal
172171
return quote
173172
$(warnings...)
174-
((;$(args...))), $((missings...,))
173+
DynamicPPL.Model{G, $((missings...,))}($(to_namedtuple_expr(argnames, argvals)))
175174
end
176175
end
177176

@@ -183,9 +182,7 @@ function loglikelihood(
183182
modeltype::Type{<:Model{G}},
184183
_vi::Union{Nothing, VarInfo},
185184
) where {G}
186-
# Pass namesl to model constructor, remaining args are missing
187-
args, missings = get_like_model_args(left, right, modeltype)
188-
model = Model{G, missings}(args)
185+
model = make_likelihood_model(left, right, modeltype)
189186
vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi
190187
if isdefined(right, :chain)
191188
# Element-wise likelihood for each value in chain
@@ -206,31 +203,31 @@ function loglikelihood(
206203
end
207204
end
208205

209-
@generated function get_like_model_args(
206+
@generated function make_likelihood_model(
210207
left::NamedTuple{leftnames},
211208
right::NamedTuple{rightnames},
212-
modeltype::Type{<:Model{_G, argnames}},
209+
modeltype::Type{<:Model{G, argnames}},
213210
defaults::NamedTuple{defaultnames}=getdefaults(modeltype)
214-
) where {leftnames, rightnames, argnames, defaultnames, _G}
215-
args = []
211+
) where {leftnames, rightnames, G, argnames, defaultnames}
212+
argvals = []
216213
missings = []
217214

218215
for argname in argnames
219216
if argname in leftnames
220-
push!(args, :($argname = left.$argname))
217+
push!(argvals, :(left.$argname))
221218
elseif argname in rightnames
222-
push!(args, :($argname = right.$argname))
219+
push!(argvals, :(right.$argname))
223220
push!(missings, argname)
224221
elseif argname in defaultnames
225-
push!(args, :($argname = defaults.$argname))
222+
push!(argvals, :(defaults.$argname))
226223
else
227224
throw("This point should not be reached. Please open an issue in the DynamicPPL.jl repository.")
228225
end
229226
end
230227

231-
# `args` is spatted as a NamedTuple expression; `missings` is splatted into a tuple and
232-
# inserted as literal
233-
return :((;$(args...)), $((missings...,)))
228+
# `args` is inserted as properly typed NamedTuple expression;
229+
# `missings` is splatted into a tuple at compile time and inserted as literal
230+
return :(DynamicPPL.Model{G, $((missings...,))}($(to_namedtuple_expr(argnames, argvals))))
234231
end
235232

236233
_setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)

src/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
############################################
22
# Julia 1.2 temporary fix - Julia PR 33303 #
33
############################################
4+
function to_namedtuple_expr(syms, vals=syms)
5+
if length(syms) == 0
6+
nt = :(NamedTuple())
7+
else
8+
nt_type = Expr(:curly, :NamedTuple,
9+
Expr(:tuple, QuoteNode.(syms)...),
10+
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...)
11+
)
12+
nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, vals...))
13+
end
14+
return nt
15+
end
16+
17+
418
if VERSION == v"1.2"
519
@eval function namedtuple(::Type{NamedTuple{names, T}}, args::Tuple) where {names, T <: Tuple}
620
if length(args) != length(names)

0 commit comments

Comments
 (0)