Skip to content

Commit 558409b

Browse files
committed
Simplify defaults
1 parent ebf7f34 commit 558409b

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

src/compiler.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ function build_model_info(input_expr)
173173

174174
model_info = Dict(
175175
:name => modeldef[:name],
176+
:model_tag => gensym(modeldef[:name]),
176177
:main_body => modeldef[:body],
177178
:arg_syms => arg_syms,
178179
:args_nt => args_nt,
@@ -183,9 +184,7 @@ function build_model_info(input_expr)
183184
:ctx => gensym(:ctx),
184185
:vi => gensym(:vi),
185186
:sampler => gensym(:sampler),
186-
:model => gensym(:model),
187-
:model_function => gensym(modeldef[:name]),
188-
:defaults => gensym(:defaults)
187+
:model => gensym(:model)
189188
)
190189
)
191190

@@ -432,6 +431,8 @@ function build_output(model_info)
432431
whereparams = model_info[:whereparams]
433432
# Model generator name
434433
model_gen = model_info[:name]
434+
# Tag used for the model type
435+
model_tag = model_info[:model_tag]
435436
# Main body of the model
436437
main_body = model_info[:main_body]
437438

@@ -453,8 +454,7 @@ function build_output(model_info)
453454
end)
454455
end
455456

456-
model_type = :(DynamicPPL.Model{$(QuoteNode(model_function))})
457-
defaults = gensym(:defaults)
457+
model_type = :(DynamicPPL.Model{$(QuoteNode(model_tag))})
458458

459459
ex = quote
460460
function ($model::$model_type)(
@@ -467,9 +467,8 @@ function build_output(model_info)
467467
$main_body
468468
end
469469

470-
$defaults = $defaults_nt
471-
getdefaults(::$model_type) = $defaults
472-
$model_gen($(args...)) = $model_type($args_nt, $defaults)
470+
getdefaults(::$model_type) = $defaults_nt
471+
$model_gen($(args...)) = $model_type($args_nt)
473472
end
474473

475474
if !isempty(args)

src/model.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
2-
struct Model{S, Targs<:NamedTuple, Tdefaults<:NamedTuple, Tmissings <: Val}
2+
struct Model{S, Targs<:NamedTuple, Tmissings <: Val}
33
args::Targs
4-
defaults::Tdefaults
54
missings::Tmissings
65
end
76
@@ -11,27 +10,26 @@ argument in `args` with a value `missing` will be in `missings` by default. Howe
1110
non-traditional use-cases `missings` can be defined differently. All variables in
1211
`missings` are treated as random variables rather than observations.
1312
"""
14-
struct Model{S, Targs<:NamedTuple, Tdefaults<:NamedTuple, Tmissings<:Val} <: AbstractModel
13+
struct Model{S, Targs<:NamedTuple, Tmissings<:Val} <: AbstractModel
1514
args::Targs
16-
defaults::Tdefaults
1715
missings::Tmissings
1816
end
1917

20-
function Model{S}(args::NamedTuple, defaults::NamedTuple) where {S}
18+
function Model{S}(args::NamedTuple) where {S}
2119
missings = getmissing(args)
22-
Model{S, typeof(args), typeof(defaults), typeof(missings)}(args, defaults, missings)
20+
Model{S, typeof(args), typeof(missings)}(args, missings)
2321
end
2422

2523
(model::Model)(vi) = model(vi, SampleFromPrior())
2624
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
2725

2826

2927
"""
30-
getdefaults(::Type{<:Model})
28+
getdefaults(model)
3129
3230
Get a named tuple of the default argument values defined in a `Model` type.
3331
"""
34-
function getdefaults end
32+
getdefaults(model::Model) = getdefaults(typeof(model))
3533

3634

3735
getargtype(::Type{<:Model{S, Targs}}) where {S, Targs} = Targs

0 commit comments

Comments
 (0)