Skip to content

Commit 534826c

Browse files
authored
Merge pull request #51 from phipsgabler/phg/abstractmodelfunction
Refactor ModelGen and nested closures
2 parents 08561a6 + fab187b commit 534826c

File tree

8 files changed

+315
-224
lines changed

8 files changed

+315
-224
lines changed

src/DynamicPPL.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,12 @@ export VarName,
6161
vectorize,
6262
set_resume!,
6363
# Model
64+
ModelGen,
6465
Model,
65-
getmissing,
66+
getmissings,
67+
getargnames,
68+
getdefaults,
69+
getgenerator,
6670
runmodel!,
6771
# Samplers
6872
Sampler,
@@ -91,6 +95,11 @@ const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_DYNAMICPPL", "0")))
9195
# Used here and overloaded in Turing
9296
function getspace end
9397

98+
# Necessary forward declarations
99+
abstract type AbstractVarInfo end
100+
abstract type AbstractContext end
101+
102+
94103
include("utils.jl")
95104
include("selector.jl")
96105
include("model.jl")

src/compiler.jl

Lines changed: 30 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,3 @@
1-
"""
2-
struct ModelGen{Targs, F, Tdefaults} <: Function
3-
f::F
4-
defaults::Tdefaults
5-
end
6-
7-
A `Model` generator. This is the output of the `@model` macro. `Targs` is the tuple
8-
of the symbols of the model's arguments. `defaults` is the `NamedTuple` of default values
9-
of the arguments, if any. Every `ModelGen` is callable with the arguments `Targs`,
10-
returning an instance of `Model`.
11-
"""
12-
struct ModelGen{Targs, F, Tdefaults} <: Function
13-
f::F
14-
defaults::Tdefaults
15-
end
16-
ModelGen{Targs}(args...) where {Targs} = ModelGen{Targs, typeof.(args)...}(args...)
17-
(m::ModelGen)(args...; kwargs...) = m.f(args...; kwargs...)
18-
function Base.getproperty(m::ModelGen{Targs}, f::Symbol) where {Targs}
19-
f === :args && return Targs
20-
return Base.getfield(m, f)
21-
end
22-
231
macro varinfo()
242
:(throw(_error_msg()))
253
end
@@ -61,18 +39,18 @@ Otherwise, the value of `x[1]` is returned.
6139
macro preprocess(data_vars, missing_vars, ex)
6240
ex
6341
end
64-
macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
42+
macro preprocess(model, ex::Union{Symbol, Expr})
6543
sym = gensym(:sym)
6644
lhs = gensym(:lhs)
6745
return esc(quote
6846
# Extract symbol
6947
$sym = Val($(vsym(ex)))
7048
# This branch should compile nicely in all cases except for partial missing data
7149
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
72-
if !DynamicPPL.inparams($sym, $data_vars) || DynamicPPL.inparams($sym, $missing_vars)
50+
if !DynamicPPL.inargnames($sym, $model) || DynamicPPL.inmissings($sym, $model)
7351
$(varname(ex)), $(vinds(ex))
7452
else
75-
if DynamicPPL.inparams($sym, $data_vars)
53+
if DynamicPPL.inargnames($sym, $model)
7654
# Evaluate the lhs
7755
$lhs = $ex
7856
if $lhs === missing
@@ -86,9 +64,7 @@ macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
8664
end
8765
end)
8866
end
89-
@generated function inparams(::Val{s}, ::Val{t}) where {s, t}
90-
return (s in t) ? :(true) : :(false)
91-
end
67+
9268

9369
#################
9470
# Main Compiler #
@@ -151,7 +127,7 @@ function build_model_info(input_expr)
151127
else
152128
nt_type = Expr(:curly, :NamedTuple,
153129
Expr(:tuple, QuoteNode.(arg_syms)...),
154-
Expr(:curly, :Tuple, [:(DynamicPPL.get_type($x)) for x in arg_syms]...)
130+
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)
155131
)
156132
args_nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, arg_syms...))
157133
end
@@ -205,27 +181,13 @@ function build_model_info(input_expr)
205181
:ctx => gensym(:ctx),
206182
:vi => gensym(:vi),
207183
:sampler => gensym(:sampler),
208-
:model => gensym(:model),
209-
:inner_function => gensym(:inner_function),
210-
:defaults => gensym(:defaults)
184+
:model => gensym(:model)
211185
)
212186
)
213187

214188
return model_info
215189
end
216190

217-
function to_namedtuple_expr(syms::Vector, vals = syms)
218-
if length(syms) == 0
219-
nt = :(NamedTuple())
220-
else
221-
nt_type = Expr(:curly, :NamedTuple,
222-
Expr(:tuple, QuoteNode.(syms)...),
223-
Expr(:curly, :Tuple, [:(DynamicPPL.get_type($x)) for x in vals]...)
224-
)
225-
nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, vals...))
226-
end
227-
return nt
228-
end
229191

230192
"""
231193
replace_vi!(model_info)
@@ -319,14 +281,16 @@ function replace_tilde!(model_info)
319281
end
320282
""" |> Meta.parse |> eval
321283

284+
# """ Unbreak code highlighting in Emacs julia-mode
285+
286+
322287
"""
323288
generate_tilde(left, right, model_info)
324289
325290
The `tilde` function generates `observe` expression for data variables and `assume`
326291
expressions for parameter variables, updating `model_info` in the process.
327292
"""
328293
function generate_tilde(left, right, model_info)
329-
arg_syms = Val((model_info[:arg_syms]...,))
330294
model = model_info[:main_body_names][:model]
331295
vi = model_info[:main_body_names][:vi]
332296
ctx = model_info[:main_body_names][:ctx]
@@ -342,7 +306,7 @@ function generate_tilde(left, right, model_info)
342306
ex = quote
343307
$temp_right = $right
344308
$assert_ex
345-
$preprocessed = DynamicPPL.@preprocess($arg_syms, DynamicPPL.getmissing($model), $left)
309+
$preprocessed = DynamicPPL.@preprocess($model, $left)
346310
if $preprocessed isa Tuple
347311
$vn, $inds = $preprocessed
348312
$out = DynamicPPL.tilde($ctx, $sampler, $temp_right, $vn, $inds, $vi)
@@ -374,7 +338,6 @@ end
374338
This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
375339
"""
376340
function generate_dot_tilde(left, right, model_info)
377-
arg_syms = Val((model_info[:arg_syms]...,))
378341
model = model_info[:main_body_names][:model]
379342
vi = model_info[:main_body_names][:vi]
380343
ctx = model_info[:main_body_names][:ctx]
@@ -391,7 +354,7 @@ function generate_dot_tilde(left, right, model_info)
391354
ex = quote
392355
$temp_right = $right
393356
$assert_ex
394-
$preprocessed = DynamicPPL.@preprocess($arg_syms, DynamicPPL.getmissing($model), $left)
357+
$preprocessed = DynamicPPL.@preprocess($model, $left)
395358
if $preprocessed isa Tuple
396359
$vn, $inds = $preprocessed
397360
$temp_left = $left
@@ -437,7 +400,6 @@ function build_output(model_info)
437400
vi = main_body_names[:vi]
438401
model = main_body_names[:model]
439402
sampler = main_body_names[:sampler]
440-
inner_function = main_body_names[:inner_function]
441403

442404
# Arguments with default values
443405
args = model_info[:args]
@@ -452,16 +414,9 @@ function build_output(model_info)
452414
whereparams = model_info[:whereparams]
453415
# Model generator name
454416
model_gen = model_info[:name]
455-
# Outer function name
456-
outer_function = gensym(model_info[:name])
457417
# Main body of the model
458418
main_body = model_info[:main_body]
459-
model_gen_constructor = quote
460-
DynamicPPL.ModelGen{$(Tuple(arg_syms))}(
461-
$outer_function,
462-
$defaults_nt,
463-
)
464-
end
419+
465420
unwrap_data_expr = Expr(:block)
466421
for var in arg_syms
467422
temp_var = gensym(:temp_var)
@@ -480,40 +435,32 @@ function build_output(model_info)
480435
end)
481436
end
482437

438+
@gensym(evaluator, generator)
439+
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
440+
model_gen_constructor = :(DynamicPPL.ModelGen{$(Tuple(arg_syms))}($generator, $defaults_nt))
441+
483442
ex = quote
484-
function $outer_function($(args...))
485-
function $inner_function(
486-
$vi::DynamicPPL.VarInfo,
487-
$sampler::DynamicPPL.AbstractSampler,
488-
$ctx::DynamicPPL.AbstractContext,
489-
$model
490-
)
491-
$unwrap_data_expr
492-
DynamicPPL.resetlogp!($vi)
493-
$main_body
494-
end
495-
return DynamicPPL.Model($inner_function, $args_nt, $model_gen_constructor)
443+
function $evaluator(
444+
$model::Model,
445+
$vi::DynamicPPL.VarInfo,
446+
$sampler::DynamicPPL.AbstractSampler,
447+
$ctx::DynamicPPL.AbstractContext,
448+
)
449+
$unwrap_data_expr
450+
DynamicPPL.resetlogp!($vi)
451+
$main_body
496452
end
497-
$model_gen = $model_gen_constructor
498-
end
453+
499454

500-
if !isempty(args)
501-
ex = quote
502-
$ex
503-
# Allows passing arguments as kwargs
504-
$outer_function(;$(args...)) = $outer_function($(arg_syms...))
505-
end
455+
$generator($(args...)) = DynamicPPL.Model($evaluator, $args_nt, $model_gen_constructor)
456+
$(generator_kw_form...)
457+
458+
$model_gen = $model_gen_constructor
506459
end
507460

508461
return esc(ex)
509462
end
510463

511-
# A hack for NamedTuple type specialization
512-
# (T = Int,) has type NamedTuple{(:T,), Tuple{DataType}} by default
513-
# With this function, we can make it NamedTuple{(:T,), Tuple{Type{Int}}}
514-
# Both are correct, but the latter is what we want for type stability
515-
get_type(::Type{T}) where {T} = Type{T}
516-
get_type(t) = typeof(t)
517464

518465
function warn_empty(body)
519466
if all(l -> isa(l, LineNumberNode), body.args)

src/contexts.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
abstract type AbstractContext end
2-
31
"""
42
struct DefaultContext <: AbstractContext end
53

0 commit comments

Comments
 (0)