Skip to content

Commit 590780e

Browse files
committed
Merge branch 'master' into phg/split_tilde
2 parents 54b2977 + 534826c commit 590780e

File tree

8 files changed

+323
-242
lines changed

8 files changed

+323
-242
lines changed

src/DynamicPPL.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,12 @@ export VarName,
6161
vectorize,
6262
set_resume!,
6363
# Model
64+
ModelGen,
6465
Model,
65-
getmissing,
66+
getmissings,
67+
getargnames,
68+
getdefaults,
69+
getgenerator,
6670
runmodel!,
6771
# Samplers
6872
Sampler,
@@ -91,6 +95,11 @@ const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_DYNAMICPPL", "0")))
9195
# Used here and overloaded in Turing
9296
function getspace end
9397

98+
# Necessary forward declarations
99+
abstract type AbstractVarInfo end
100+
abstract type AbstractContext end
101+
102+
94103
include("utils.jl")
95104
include("selector.jl")
96105
include("model.jl")

src/compiler.jl

Lines changed: 38 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,3 @@
1-
"""
2-
struct ModelGen{Targs, F, Tdefaults} <: Function
3-
f::F
4-
defaults::Tdefaults
5-
end
6-
7-
A `Model` generator. This is the output of the `@model` macro. `Targs` is the tuple
8-
of the symbols of the model's arguments. `defaults` is the `NamedTuple` of default values
9-
of the arguments, if any. Every `ModelGen` is callable with the arguments `Targs`,
10-
returning an instance of `Model`.
11-
"""
12-
struct ModelGen{Targs, F, Tdefaults} <: Function
13-
f::F
14-
defaults::Tdefaults
15-
end
16-
ModelGen{Targs}(args...) where {Targs} = ModelGen{Targs, typeof.(args)...}(args...)
17-
(m::ModelGen)(args...; kwargs...) = m.f(args...; kwargs...)
18-
function Base.getproperty(m::ModelGen{Targs}, f::Symbol) where {Targs}
19-
f === :args && return Targs
20-
return Base.getfield(m, f)
21-
end
22-
231
macro varinfo()
242
:(throw(_error_msg()))
253
end
@@ -61,18 +39,18 @@ When `ex` is not a variable (e.g., a literal), the function returns `false` as w
6139
macro isassumption(data_vars, missing_vars, ex)
6240
:false
6341
end
64-
macro isassumption(data_vars, missing_vars, ex::Union{Symbol, Expr})
42+
macro isassumption(model, ex::Union{Symbol, Expr})
6543
sym = gensym(:sym)
6644
lhs = gensym(:lhs)
6745
return esc(quote
6846
# Extract symbol
6947
$sym = Val($(vsym(ex)))
7048
# This branch should compile nicely in all cases except for partial missing data
7149
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
72-
if !DynamicPPL.inparams($sym, $data_vars) || DynamicPPL.inparams($sym, $missing_vars)
50+
if !DynamicPPL.inargnames($sym, $model) || DynamicPPL.inmissings($sym, $model)
7351
true
7452
else
75-
if DynamicPPL.inparams($sym, $data_vars)
53+
if DynamicPPL.inargnames($sym, $model)
7654
# Evaluate the lhs
7755
$lhs = $ex
7856
if $lhs === missing
@@ -87,10 +65,6 @@ macro isassumption(data_vars, missing_vars, ex::Union{Symbol, Expr})
8765
end)
8866
end
8967

90-
@generated function inparams(::Val{s}, ::Val{t}) where {s, t}
91-
return (s in t) ? :(true) : :(false)
92-
end
93-
9468
#################
9569
# Main Compiler #
9670
#################
@@ -152,7 +126,7 @@ function build_model_info(input_expr)
152126
else
153127
nt_type = Expr(:curly, :NamedTuple,
154128
Expr(:tuple, QuoteNode.(arg_syms)...),
155-
Expr(:curly, :Tuple, [:(DynamicPPL.get_type($x)) for x in arg_syms]...)
129+
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)
156130
)
157131
args_nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, arg_syms...))
158132
end
@@ -206,27 +180,13 @@ function build_model_info(input_expr)
206180
:ctx => gensym(:ctx),
207181
:vi => gensym(:vi),
208182
:sampler => gensym(:sampler),
209-
:model => gensym(:model),
210-
:inner_function => gensym(:inner_function),
211-
:defaults => gensym(:defaults)
183+
:model => gensym(:model)
212184
)
213185
)
214186

215187
return model_info
216188
end
217189

218-
function to_namedtuple_expr(syms::Vector, vals = syms)
219-
if length(syms) == 0
220-
nt = :(NamedTuple())
221-
else
222-
nt_type = Expr(:curly, :NamedTuple,
223-
Expr(:tuple, QuoteNode.(syms)...),
224-
Expr(:curly, :Tuple, [:(DynamicPPL.get_type($x)) for x in vals]...)
225-
)
226-
nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, vals...))
227-
end
228-
return nt
229-
end
230190

231191
"""
232192
replace_vi!(model_info)
@@ -330,19 +290,16 @@ The `tilde` function generates `observe` expression for data variables and `assu
330290
expressions for parameter variables, updating `model_info` in the process.
331291
"""
332292
function generate_tilde(left, right, model_info)
333-
arg_syms = Val((model_info[:arg_syms]...,))
334293
model = model_info[:main_body_names][:model]
335294
vi = model_info[:main_body_names][:vi]
336295
ctx = model_info[:main_body_names][:ctx]
337296
sampler = model_info[:main_body_names][:sampler]
338-
339-
@gensym(out,
340-
lp,
341-
vn,
342-
inds,
343-
isassumption,
344-
temp_right)
345-
297+
temp_right = gensym(:temp_right)
298+
out = gensym(:out)
299+
lp = gensym(:lp)
300+
vn = gensym(:vn)
301+
inds = gensym(:inds)
302+
isassumption = gensym(:isassumption)
346303
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
347304

348305
if left isa Symbol || left isa Expr
@@ -351,7 +308,7 @@ function generate_tilde(left, right, model_info)
351308
$assert_ex
352309

353310
$vn, $inds = $(varname(left)), $(vinds(left))
354-
$isassumption = DynamicPPL.@isassumption($arg_syms, DynamicPPL.getmissing($model), $left)
311+
$isassumption = DynamicPPL.@isassumption($model, $left)
355312
if $isassumption
356313
$out = DynamicPPL.tilde_assume($ctx, $sampler, $temp_right, $vn, $inds, $vi)
357314
$left = $out[1]
@@ -386,20 +343,16 @@ This function returns the expression that replaces `left .~ right` in the model
386343
will be run.
387344
"""
388345
function generate_dot_tilde(left, right, model_info)
389-
arg_syms = Val((model_info[:arg_syms]...,))
390346
model = model_info[:main_body_names][:model]
391347
vi = model_info[:main_body_names][:vi]
392348
ctx = model_info[:main_body_names][:ctx]
393349
sampler = model_info[:main_body_names][:sampler]
394-
395-
@gensym(out,
396-
preprocessed,
397-
lp,
398-
vn,
399-
inds,
400-
isassumption,
401-
temp_right)
402-
350+
out = gensym(:out)
351+
temp_right = gensym(:temp_right)
352+
isassumption = gensym(:isassumption)
353+
lp = gensym(:lp)
354+
vn = gensym(:vn)
355+
inds = gensym(:inds)
403356
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
404357

405358
if left isa Symbol || left isa Expr
@@ -408,7 +361,7 @@ function generate_dot_tilde(left, right, model_info)
408361
$assert_ex
409362

410363
$vn, $inds = $(varname(left)), $(vinds(left))
411-
$isassumption = DynamicPPL.@isassumption($arg_syms, DynamicPPL.getmissing($model), $left)
364+
$isassumption = DynamicPPL.@isassumption($model, $left)
412365

413366
if $isassumption
414367
$out = DynamicPPL.dot_tilde_assume($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi)
@@ -453,7 +406,6 @@ function build_output(model_info)
453406
vi = main_body_names[:vi]
454407
model = main_body_names[:model]
455408
sampler = main_body_names[:sampler]
456-
inner_function = main_body_names[:inner_function]
457409

458410
# Arguments with default values
459411
args = model_info[:args]
@@ -468,16 +420,9 @@ function build_output(model_info)
468420
whereparams = model_info[:whereparams]
469421
# Model generator name
470422
model_gen = model_info[:name]
471-
# Outer function name
472-
outer_function = gensym(model_info[:name])
473423
# Main body of the model
474424
main_body = model_info[:main_body]
475-
model_gen_constructor = quote
476-
DynamicPPL.ModelGen{$(Tuple(arg_syms))}(
477-
$outer_function,
478-
$defaults_nt,
479-
)
480-
end
425+
481426
unwrap_data_expr = Expr(:block)
482427
for var in arg_syms
483428
temp_var = gensym(:temp_var)
@@ -496,40 +441,32 @@ function build_output(model_info)
496441
end)
497442
end
498443

444+
@gensym(evaluator, generator)
445+
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
446+
model_gen_constructor = :(DynamicPPL.ModelGen{$(Tuple(arg_syms))}($generator, $defaults_nt))
447+
499448
ex = quote
500-
function $outer_function($(args...))
501-
function $inner_function(
502-
$vi::DynamicPPL.VarInfo,
503-
$sampler::DynamicPPL.AbstractSampler,
504-
$ctx::DynamicPPL.AbstractContext,
505-
$model
506-
)
507-
$unwrap_data_expr
508-
DynamicPPL.resetlogp!($vi)
509-
$main_body
510-
end
511-
return DynamicPPL.Model($inner_function, $args_nt, $model_gen_constructor)
449+
function $evaluator(
450+
$model::Model,
451+
$vi::DynamicPPL.VarInfo,
452+
$sampler::DynamicPPL.AbstractSampler,
453+
$ctx::DynamicPPL.AbstractContext,
454+
)
455+
$unwrap_data_expr
456+
DynamicPPL.resetlogp!($vi)
457+
$main_body
512458
end
513-
$model_gen = $model_gen_constructor
514-
end
459+
515460

516-
if !isempty(args)
517-
ex = quote
518-
$ex
519-
# Allows passing arguments as kwargs
520-
$outer_function(;$(args...)) = $outer_function($(arg_syms...))
521-
end
461+
$generator($(args...)) = DynamicPPL.Model($evaluator, $args_nt, $model_gen_constructor)
462+
$(generator_kw_form...)
463+
464+
$model_gen = $model_gen_constructor
522465
end
523466

524467
return esc(ex)
525468
end
526469

527-
# A hack for NamedTuple type specialization
528-
# (T = Int,) has type NamedTuple{(:T,), Tuple{DataType}} by default
529-
# With this function, we can make it NamedTuple{(:T,), Tuple{Type{Int}}}
530-
# Both are correct, but the latter is what we want for type stability
531-
get_type(::Type{T}) where {T} = Type{T}
532-
get_type(t) = typeof(t)
533470

534471
function warn_empty(body)
535472
if all(l -> isa(l, LineNumberNode), body.args)

src/contexts.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
abstract type AbstractContext end
2-
31
"""
42
struct DefaultContext <: AbstractContext end
53

0 commit comments

Comments
 (0)