Skip to content

Commit d914dbe

Browse files
committed
Introduce ModelFunction type to dispatch on
1 parent 08561a6 commit d914dbe

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

src/compiler.jl

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ function build_model_info(input_expr)
206206
:vi => gensym(:vi),
207207
:sampler => gensym(:sampler),
208208
:model => gensym(:model),
209-
:inner_function => gensym(:inner_function),
209+
:model_function => gensym(modeldef[:name]),
210210
:defaults => gensym(:defaults)
211211
)
212212
)
@@ -319,6 +319,9 @@ function replace_tilde!(model_info)
319319
end
320320
""" |> Meta.parse |> eval
321321

322+
# """ Unbreak code highlighting in Emacs julia-mode
323+
324+
322325
"""
323326
generate_tilde(left, right, model_info)
324327
@@ -437,7 +440,7 @@ function build_output(model_info)
437440
vi = main_body_names[:vi]
438441
model = main_body_names[:model]
439442
sampler = main_body_names[:sampler]
440-
inner_function = main_body_names[:inner_function]
443+
model_function = main_body_names[:model_function]
441444

442445
# Arguments with default values
443446
args = model_info[:args]
@@ -481,28 +484,32 @@ function build_output(model_info)
481484
end
482485

483486
ex = quote
487+
function (::DynamicPPL.ModelFunction{$(QuoteNode(model_function))})(
488+
$vi::DynamicPPL.VarInfo,
489+
$sampler::DynamicPPL.AbstractSampler,
490+
$ctx::DynamicPPL.AbstractContext,
491+
$model
492+
)
493+
$unwrap_data_expr
494+
DynamicPPL.resetlogp!($vi)
495+
$main_body
496+
end
497+
498+
$model_function = DynamicPPL.ModelFunction{$(QuoteNode(model_function))}()
499+
484500
function $outer_function($(args...))
485-
function $inner_function(
486-
$vi::DynamicPPL.VarInfo,
487-
$sampler::DynamicPPL.AbstractSampler,
488-
$ctx::DynamicPPL.AbstractContext,
489-
$model
490-
)
491-
$unwrap_data_expr
492-
DynamicPPL.resetlogp!($vi)
493-
$main_body
494-
end
495-
return DynamicPPL.Model($inner_function, $args_nt, $model_gen_constructor)
501+
return DynamicPPL.Model($model_function, $args_nt, $model_gen_constructor)
496502
end
503+
497504
$model_gen = $model_gen_constructor
498505
end
499506

500507
if !isempty(args)
501-
ex = quote
502-
$ex
503-
# Allows passing arguments as kwargs
508+
# Allows passing arguments as kwargs
509+
kwform = quote
504510
$outer_function(;$(args...)) = $outer_function($(arg_syms...))
505511
end
512+
push!(ex.args, kwform)
506513
end
507514

508515
return esc(ex)

src/model.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
"""Tagged type for implementation of inner model functions that allow dispatch on their type."""
2+
struct ModelFunction{S} end
3+
4+
15
"""
2-
struct Model{F, Targs <: NamedTuple, Tmodelgen, Tmissings <: Val}
3-
f::F
6+
struct Model{S, Targs <: NamedTuple, Tmodelgen, Tmissings <: Val}
7+
f::ModelFunction{S}
48
args::Targs
59
modelgen::Tmodelgen
610
missings::Tmissings
@@ -12,13 +16,13 @@ argument in `args` with a value `missing` will be in `missings` by default. Howe
1216
non-traditional use-cases `missings` can be defined differently. All variables in
1317
`missings` are treated as random variables rather than observations.
1418
"""
15-
struct Model{F, Targs <: NamedTuple, Tmodelgen, Tmissings <: Val} <: AbstractModel
16-
f::F
19+
struct Model{S, Targs <: NamedTuple, Tmodelgen, Tmissings <: Val} <: AbstractModel
20+
f::ModelFunction{S}
1721
args::Targs
1822
modelgen::Tmodelgen
1923
missings::Tmissings
2024
end
21-
Model(f, args::NamedTuple, modelgen) = Model(f, args, modelgen, getmissing(args))
25+
Model(f::ModelFunction, args::NamedTuple, modelgen) = Model(f, args, modelgen, getmissing(args))
2226
(model::Model)(vi) = model(vi, SampleFromPrior())
2327
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
2428
(model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)

0 commit comments

Comments
 (0)