Skip to content

Commit 64de286

Browse files
committed
Rename getargtypes -> getargnames
1 parent cf6c542 commit 64de286

File tree

3 files changed

+41
-35
lines changed

3 files changed

+41
-35
lines changed

src/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ function build_output(model_info)
475475
end
476476

477477
DynamicPPL.getdefaults(::typeof($model_gen)) = $defaults_nt
478-
DynamicPPL.getargtypes(::typeof($model_gen)) = $(Tuple(arg_syms))
478+
DynamicPPL.getargnames(::typeof($model_gen)) = $(Tuple(arg_syms))
479479
DynamicPPL.getmodeltype(::typeof($model_gen)) = DynamicPPL.Model{typeof($model_gen)}
480480
DynamicPPL.getgenerator(::DynamicPPL.Model{typeof($model_gen)}) = $model_gen
481481
end

src/model.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ Get a named tuple of the default argument values defined for a model defined by
4444
function getdefaults end
4545

4646
"""
47-
getargtypes(::typeof(modelgen))
47+
getargnames(::typeof(modelgen))
4848
49-
Get a named tuple of the argument
49+
Get a tuple of the argument names of the model defined by a generating function.
5050
"""
51-
function getargtypes end
51+
function getargnames end
5252

5353

5454
getmissing(model::Model) = model.missings

src/prob_macro.jl

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
5555
defaults = getdefaults(modelgen)
5656
valid_arg(arg) = isdefined(ntl, arg) || isdefined(ntr, arg) ||
5757
isdefined(defaults, arg) && getfield(defaults, arg) !== missing
58-
@assert all(valid_arg, getargtypes(modelgen))
58+
@assert all(valid_arg, getargnames(modelgen))
5959
return Val(:likelihood), modelgen, vi
6060
else
6161
@assert isdefined(ntr, :model)
@@ -76,9 +76,9 @@ function probtype(
7676
modelgen,
7777
defaults::NamedTuple{defs}
7878
) where {namesl, namesr, defs}
79-
args = getargtypes(modelgen)
79+
argnames = getargnames(modelgen)
8080
prior_rhs = all(n -> n in (:model, :varinfo) ||
81-
n in args && getfield(ntr, n) !== missing, namesr)
81+
n in argnames && getfield(ntr, n) !== missing, namesr)
8282
function get_arg(arg)
8383
if arg in namesl
8484
return getfield(ntl, arg)
@@ -94,7 +94,7 @@ function probtype(
9494
a = get_arg(arg)
9595
return a !== nothing && a !== missing
9696
end
97-
valid_args = all(valid_arg, args)
97+
valid_args = all(valid_arg, argnames)
9898

9999
# Uses the default values for model arguments not provided.
100100
# If no default value exists, use `nothing`.
@@ -105,9 +105,9 @@ function probtype(
105105
elseif valid_args
106106
return Val(:likelihood)
107107
else
108-
for arg in args
109-
if !valid_arg(args)
110-
throw(ArgumentError(missing_arg_error_msg(arg, get_arg(arg))))
108+
for argname in argnames
109+
if !valid_arg(argname)
110+
throw(ArgumentError(missing_arg_error_msg(argname, get_arg(argname))))
111111
end
112112
end
113113
end
@@ -134,7 +134,10 @@ function logprior(
134134
# All `observe` and `dot_observe` calls are no-op in the PriorContext
135135

136136
# When all of model args are on the lhs of |, this is also equal to the logjoint.
137-
args, missing_vars = get_prior_model_args(left, right, Val{getargtypes(modelgen)}(), getdefaults(modelgen))
137+
args, missing_vars = get_prior_model_args(left,
138+
right,
139+
Val{getargnames(modelgen)}(),
140+
getdefaults(modelgen))
138141
model = Model{typeof(modelgen)}(args, missing_vars)
139142
vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi
140143
foreach(keys(vi.metadata)) do n
@@ -146,23 +149,23 @@ end
146149
@generated function get_prior_model_args(
147150
left::NamedTuple{namesl},
148151
right::NamedTuple{namesr},
149-
::Val{args},
152+
_args::Val{argnames},
150153
defaults::NamedTuple{default_args}
151-
) where {namesl, namesr, args, default_args}
154+
) where {namesl, namesr, argnames, default_args}
152155
exprs = []
153156
missing_args = []
154157
warn_expr = Expr(:block)
155-
foreach(args) do arg
156-
if arg in namesl
157-
push!(exprs, :($arg = deepcopy(left.$arg)))
158-
push!(missing_args, arg)
159-
elseif arg in namesr
160-
push!(exprs, :($arg = right.$arg))
161-
elseif arg in default_args
162-
push!(exprs, :($arg = defaults.$arg))
158+
foreach(argnames) do argname
159+
if argname in namesl
160+
push!(exprs, :($argname = deepcopy(left.$argname)))
161+
push!(missing_args, argname)
162+
elseif argname in namesr
163+
push!(exprs, :($argname = right.$argname))
164+
elseif argname in default_args
165+
push!(exprs, :($argname = defaults.$argname))
163166
else
164-
push!(warn_expr.args, :(@warn(warn_msg($(QuoteNode(arg))))))
165-
push!(exprs, :($arg = nothing))
167+
push!(warn_expr.args, :(@warn(warn_msg($(QuoteNode(argname))))))
168+
push!(exprs, :($argname = nothing))
166169
end
167170
end
168171
missing_vars = :(Val{($missing_args...,)}())
@@ -188,7 +191,10 @@ function loglikelihood(
188191
_vi::Union{Nothing, VarInfo},
189192
)
190193
# Pass namesl to model constructor, remaining args are missing
191-
args, missing_vars = get_like_model_args(left, right, Val{getargtypes(modelgen)}(), getdefaults(modelgen))
194+
args, missing_vars = get_like_model_args(left,
195+
right,
196+
Val{getargnames(modelgen)}(),
197+
getdefaults(modelgen))
192198
model = Model{typeof(modelgen)}(args, missing_vars)
193199
vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi
194200
if isdefined(right, :chain)
@@ -212,19 +218,19 @@ end
212218
@generated function get_like_model_args(
213219
left::NamedTuple{namesl},
214220
right::NamedTuple{namesr},
215-
::Val{args},
221+
_args::Val{argnames},
216222
defaults::NamedTuple{default_args},
217-
) where {namesl, namesr, args, default_args}
223+
) where {namesl, namesr, argnames, default_args}
218224
exprs = []
219225
missing_args = []
220-
foreach(args) do arg
221-
if arg in namesl
222-
push!(exprs, :($arg = left.$arg))
223-
elseif arg in namesr
224-
push!(exprs, :($arg = right.$arg))
225-
push!(missing_args, arg)
226-
elseif arg in default_args
227-
push!(exprs, :($arg = defaults.$arg))
226+
foreach(argnames) do argname
227+
if argname in namesl
228+
push!(exprs, :($argname = left.$argname))
229+
elseif argname in namesr
230+
push!(exprs, :($argname = right.$argname))
231+
push!(missing_args, argname)
232+
elseif argname in default_args
233+
push!(exprs, :($argname = defaults.$argname))
228234
else
229235
throw("This point should not be reached. Please open an issue in the DynamicPPL.jl repository.")
230236
end

0 commit comments

Comments
 (0)