Skip to content

Commit c9394fd

Browse files
committed
Restructure Model type and usage of its type arguments, esp. missings
1 parent 64de286 commit c9394fd

File tree

4 files changed

+141
-133
lines changed

4 files changed

+141
-133
lines changed

src/DynamicPPL.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ export VarName,
6262
set_resume!,
6363
# Model
6464
Model,
65-
getmissing,
65+
getmissings,
66+
getargnames,
67+
getdefaults,
68+
getgenerator,
69+
getmodeltype,
6670
runmodel!,
6771
# Samplers
6872
Sampler,

src/compiler.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ function generate_tilde(left, right, model_info)
321321
ex = quote
322322
$temp_right = $right
323323
$assert_ex
324-
$preprocessed = DynamicPPL.@preprocess($arg_syms, DynamicPPL.getmissing($model), $left)
324+
$preprocessed = DynamicPPL.@preprocess($arg_syms, Val{DynamicPPL.getmissings($model)}(), $left)
325325
if $preprocessed isa Tuple
326326
$vn, $inds = $preprocessed
327327
$out = DynamicPPL.tilde($ctx, $sampler, $temp_right, $vn, $inds, $vi)
@@ -370,7 +370,7 @@ function generate_dot_tilde(left, right, model_info)
370370
ex = quote
371371
$temp_right = $right
372372
$assert_ex
373-
$preprocessed = DynamicPPL.@preprocess($arg_syms, DynamicPPL.getmissing($model), $left)
373+
$preprocessed = DynamicPPL.@preprocess($arg_syms, Val{DynamicPPL.getmissings($model)}(), $left)
374374
if $preprocessed isa Tuple
375375
$vn, $inds = $preprocessed
376376
$temp_left = $left
@@ -452,19 +452,13 @@ function build_output(model_info)
452452
end
453453

454454
modelgen_kw_form = isempty(args) ? () : (:($model_gen(;$(args...)) = $model_gen($(arg_syms...))),)
455-
missings = gensym(:missings)
456-
args_tuple = gensym(:args_tuple)
455+
model_type = :(DynamicPPL.Model{typeof($model_gen), $(Tuple(arg_syms))})
457456

458457
ex = quote
459-
function $model_gen($(args...))
460-
$args_tuple = $args_nt
461-
$missings = DynamicPPL.getmissing($args_tuple)
462-
return DynamicPPL.Model{typeof($model_gen)}($args_tuple, $missings)
463-
end
464-
458+
$model_gen($(args...)) = DynamicPPL.Model{typeof($model_gen)}($args_nt)
465459
$(modelgen_kw_form...)
466460

467-
function ($model::DynamicPPL.Model{typeof($model_gen)})(
461+
function ($model::$model_type)(
468462
$vi::DynamicPPL.VarInfo,
469463
$sampler::DynamicPPL.AbstractSampler,
470464
$ctx::DynamicPPL.AbstractContext,
@@ -474,10 +468,9 @@ function build_output(model_info)
474468
$main_body
475469
end
476470

477-
DynamicPPL.getdefaults(::typeof($model_gen)) = $defaults_nt
478-
DynamicPPL.getargnames(::typeof($model_gen)) = $(Tuple(arg_syms))
479-
DynamicPPL.getmodeltype(::typeof($model_gen)) = DynamicPPL.Model{typeof($model_gen)}
480-
DynamicPPL.getgenerator(::DynamicPPL.Model{typeof($model_gen)}) = $model_gen
471+
DynamicPPL.getmodeltype(::typeof($model_gen)) = $model_type
472+
DynamicPPL.getgenerator(::Type{<:$model_type}) = $model_gen
473+
DynamicPPL.getdefaults(::Type{<:$model_type}) = $defaults_nt
481474
end
482475

483476
return esc(ex)

src/model.jl

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,80 @@
11
"""
2-
struct Model{G, Targs<:NamedTuple, Tmissings <: Val}
3-
args::Targs
4-
missings::Tmissings
2+
struct Model{G, argnames, missings, Targs}
3+
args::NamedTuple{argnames, Targs}
54
end
65
7-
A `Model` struct with arguments `args`, model generator type `G` and
8-
missing data `missings`. `missings` is a `Val` instance, e.g. `Val{(:a, :b)}()`. An
9-
argument in `args` with a value `missing` will be in `missings` by default. However, in
10-
non-traditional use-cases `missings` can be defined differently. All variables in
11-
`missings` are treated as random variables rather than observations.
6+
A `Model` struct with model generator type `G`, arguments names `argnames`, arguments types `Targs`,
7+
and missing arguments `missings`. `argnames` and `missings` are tuples of symbols, e.g. `(:a,
8+
:b)`. An argument with a type of `Missing` will be in `missings` by default. However, in
9+
non-traditional use-cases `missings` can be defined differently. All variables in `missings` are
10+
treated as random variables rather than observations.
11+
12+
# Example
13+
14+
```julia
15+
julia> Model{typeof(gdemo)}((x = 1.0, y = 2.0))
16+
Model{typeof(gdemo),(),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0))
17+
18+
julia> Model{typeof(gdemo), (:y,)}((x = 1.0, y = 2.0))
19+
Model{typeof(gdemo),(:y,),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0))
20+
```
1221
"""
13-
struct Model{G, Targs<:NamedTuple, Tmissings<:Val} <: AbstractModel
14-
args::Targs
15-
missings::Tmissings
22+
struct Model{G, argnames, missings, Targs} <: AbstractModel
23+
args::NamedTuple{argnames, Targs}
24+
25+
Model{G, missings}(args::NamedTuple{argnames, Targs}) where {G, argnames, missings, Targs} =
26+
new{G, argnames, missings, Targs}(args)
27+
end
28+
29+
@generated function Model{G}(args::NamedTuple{argnames, Targs}) where {G, argnames, Targs}
30+
missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
31+
return :(Model{G, $missings}(args))
1632
end
1733

18-
Model{G}(args::Targs, missings::Tmissings) where {G, Targs, Tmissings} =
19-
Model{G, Targs, Tmissings}(args, missings)
2034

2135
(model::Model)(vi) = model(vi, SampleFromPrior())
2236
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
2337

2438

2539
"""
26-
getgenerator(model::Model)
40+
getargnames(model::Model)
2741
28-
Get the generator (the function defined by the `@model` macro) of a certain model instance.
42+
Get a tuple of the argument names of the `model`.
2943
"""
30-
function getgenerator end
44+
getargnames(model::Model) = getargnames(typeof(model))
45+
getargnames(::Type{<:Model{_G, argnames} where {_G}}) where {argnames} = argnames
46+
3147

3248
"""
33-
getmodeltype(::typeof(modelgen))
49+
getmissings(model::Model)
3450
35-
Get the associated model type for a model generating function.
51+
Get a tuple of the names of the missing arguments of the `model`.
3652
"""
37-
function getmodeltype end
53+
getmissings(model::Model{_G, _a, missings}) where {missings, _G, _a} = missings
54+
55+
getmissing(model::Model) = getmissings(model)
56+
@deprecate getmissing(model) getmissings(model)
57+
3858

3959
"""
40-
getdefaults(::typeof(modelgen))
60+
getgenerator(model::Model)
4161
42-
Get a named tuple of the default argument values defined for a model defined by a generating function.
62+
Get the generator (the function defined by the `@model` macro) of a certain model instance.
4363
"""
44-
function getdefaults end
64+
getgenerator(model::Model) = getgenerator(typeof(model))
65+
4566

4667
"""
47-
getargnames(::typeof(modelgen))
68+
getdefaults(model::Model)
4869
49-
Get a tuple of the argument names of the model defined by a generating function.
70+
Get a named tuple of the default argument values defined for a model defined by a generating function.
5071
"""
51-
function getargnames end
72+
getdefaults(model::Model) = getdefaults(typeof(model))
5273

5374

54-
getmissing(model::Model) = model.missings
55-
@generated function getmissing(args::NamedTuple{names, ttuple}) where {names, ttuple}
56-
length(names) == 0 && return :(Val{()}())
57-
minds = filter(1:length(names)) do i
58-
ttuple.types[i] == Missing
59-
end
60-
mnames = names[minds]
61-
return :(Val{$mnames}())
62-
end
75+
"""
76+
getmodeltype(::typeof(modelgen))
77+
78+
Get the associated model type for a model generator (the function defined by the `@model` macro).
79+
"""
80+
getmodeltype(model::Model) = typeof(model)

0 commit comments

Comments
 (0)