Skip to content

Commit 06d3828

Browse files
committed
Get rid of ModelGen and nested functions
1 parent d914dbe commit 06d3828

File tree

2 files changed

+35
-64
lines changed

2 files changed

+35
-64
lines changed

src/compiler.jl

Lines changed: 13 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,3 @@
1-
"""
2-
struct ModelGen{Targs, F, Tdefaults} <: Function
3-
f::F
4-
defaults::Tdefaults
5-
end
6-
7-
A `Model` generator. This is the output of the `@model` macro. `Targs` is the tuple
8-
of the symbols of the model's arguments. `defaults` is the `NamedTuple` of default values
9-
of the arguments, if any. Every `ModelGen` is callable with the arguments `Targs`,
10-
returning an instance of `Model`.
11-
"""
12-
struct ModelGen{Targs, F, Tdefaults} <: Function
13-
f::F
14-
defaults::Tdefaults
15-
end
16-
ModelGen{Targs}(args...) where {Targs} = ModelGen{Targs, typeof.(args)...}(args...)
17-
(m::ModelGen)(args...; kwargs...) = m.f(args...; kwargs...)
18-
function Base.getproperty(m::ModelGen{Targs}, f::Symbol) where {Targs}
19-
f === :args && return Targs
20-
return Base.getfield(m, f)
21-
end
22-
231
macro varinfo()
242
:(throw(_error_msg()))
253
end
@@ -151,7 +129,7 @@ function build_model_info(input_expr)
151129
else
152130
nt_type = Expr(:curly, :NamedTuple,
153131
Expr(:tuple, QuoteNode.(arg_syms)...),
154-
Expr(:curly, :Tuple, [:(DynamicPPL.get_type($x)) for x in arg_syms]...)
132+
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)
155133
)
156134
args_nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, arg_syms...))
157135
end
@@ -220,7 +198,7 @@ function to_namedtuple_expr(syms::Vector, vals = syms)
220198
else
221199
nt_type = Expr(:curly, :NamedTuple,
222200
Expr(:tuple, QuoteNode.(syms)...),
223-
Expr(:curly, :Tuple, [:(DynamicPPL.get_type($x)) for x in vals]...)
201+
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...)
224202
)
225203
nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, vals...))
226204
end
@@ -440,7 +418,6 @@ function build_output(model_info)
440418
vi = main_body_names[:vi]
441419
model = main_body_names[:model]
442420
sampler = main_body_names[:sampler]
443-
model_function = main_body_names[:model_function]
444421

445422
# Arguments with default values
446423
args = model_info[:args]
@@ -455,16 +432,9 @@ function build_output(model_info)
455432
whereparams = model_info[:whereparams]
456433
# Model generator name
457434
model_gen = model_info[:name]
458-
# Outer function name
459-
outer_function = gensym(model_info[:name])
460435
# Main body of the model
461436
main_body = model_info[:main_body]
462-
model_gen_constructor = quote
463-
DynamicPPL.ModelGen{$(Tuple(arg_syms))}(
464-
$outer_function,
465-
$defaults_nt,
466-
)
467-
end
437+
468438
unwrap_data_expr = Expr(:block)
469439
for var in arg_syms
470440
temp_var = gensym(:temp_var)
@@ -483,44 +453,35 @@ function build_output(model_info)
483453
end)
484454
end
485455

456+
model_type = :(DynamicPPL.Model{$(QuoteNode(model_function))})
457+
defaults = gensym(:defaults)
458+
486459
ex = quote
487-
function (::DynamicPPL.ModelFunction{$(QuoteNode(model_function))})(
460+
function ($model::$model_type)(
488461
$vi::DynamicPPL.VarInfo,
489462
$sampler::DynamicPPL.AbstractSampler,
490463
$ctx::DynamicPPL.AbstractContext,
491-
$model
492464
)
493465
$unwrap_data_expr
494466
DynamicPPL.resetlogp!($vi)
495467
$main_body
496468
end
497-
498-
$model_function = DynamicPPL.ModelFunction{$(QuoteNode(model_function))}()
499-
500-
function $outer_function($(args...))
501-
return DynamicPPL.Model($model_function, $args_nt, $model_gen_constructor)
502-
end
503-
504-
$model_gen = $model_gen_constructor
469+
470+
$defaults = $defaults_nt
471+
getdefaults(::$model_type) = $defaults
472+
$model_gen($(args...)) = $model_type($args_nt, $defaults)
505473
end
506474

507475
if !isempty(args)
508476
# Allows passing arguments as kwargs
509-
kwform = quote
510-
$outer_function(;$(args...)) = $outer_function($(arg_syms...))
511-
end
477+
kwform = :($model_gen(;$(args...)) = $model_gen($(arg_syms...)))
512478
push!(ex.args, kwform)
513479
end
514480

481+
515482
return esc(ex)
516483
end
517484

518-
# A hack for NamedTuple type specialization
519-
# (T = Int,) has type NamedTuple{(:T,), Tuple{DataType}} by default
520-
# With this function, we can make it NamedTuple{(:T,), Tuple{Type{Int}}}
521-
# Both are correct, but the latter is what we want for type stability
522-
get_type(::Type{T}) where {T} = Type{T}
523-
get_type(t) = typeof(t)
524485

525486
function warn_empty(body)
526487
if all(l -> isa(l, LineNumberNode), body.args)

src/model.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,41 @@
1-
"""Tagged type for implementation of inner model functions that allow dispatch on their type."""
2-
struct ModelFunction{S} end
3-
4-
51
"""
6-
struct Model{S, Targs <: NamedTuple, Tmodelgen, Tmissings <: Val}
2+
struct Model{S, Targs<:NamedTuple, Tdefaults<:NamedTuple, Tmissings <: Val}
73
f::ModelFunction{S}
84
args::Targs
9-
modelgen::Tmodelgen
5+
defaults::Tmodelgen
106
missings::Tmissings
117
end
128
13-
A `Model` struct with arguments `args`, inner function `f`, model generator `modelgen` and
9+
A `Model` struct with arguments `args`, model generator `modelgen` and
1410
missing data `missings`. `missings` is a `Val` instance, e.g. `Val{(:a, :b)}()`. An
1511
argument in `args` with a value `missing` will be in `missings` by default. However, in
1612
non-traditional use-cases `missings` can be defined differently. All variables in
1713
`missings` are treated as random variables rather than observations.
1814
"""
19-
struct Model{S, Targs <: NamedTuple, Tmodelgen, Tmissings <: Val} <: AbstractModel
20-
f::ModelFunction{S}
15+
struct Model{S, Targs<:NamedTuple, Tdefaults<:NamedTuple, Tmissings<:Val} <: AbstractModel
2116
args::Targs
22-
modelgen::Tmodelgen
17+
defaults::Tdefaults
2318
missings::Tmissings
2419
end
25-
Model(f::ModelFunction, args::NamedTuple, modelgen) = Model(f, args, modelgen, getmissing(args))
20+
21+
function Model{S}(args::NamedTuple, defaults::NamedTuple) where {S}
22+
missings = getmissing(args)
23+
Model{S, typeof(args), typeof(defaults), typeof(missings)}(args, defaults, missings)
24+
end
25+
2626
(model::Model)(vi) = model(vi, SampleFromPrior())
2727
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
28-
(model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)
28+
29+
30+
"""
31+
getdefaults(::Type{<:Model})
32+
33+
Get a named tuple of the default argument values defined in a `Model` type.
34+
"""
35+
function getdefaults end
36+
37+
38+
getargtype(::Type{<:Model{S, Targs}}) where {S, Targs} = Targs
2939

3040
getmissing(model::Model) = model.missings
3141
@generated function getmissing(args::NamedTuple{names, ttuple}) where {names, ttuple}

0 commit comments

Comments
 (0)