Skip to content

Commit a0e2635

Browse files
committed
Update things to make prob macro work
1 parent 0f2e114 commit a0e2635

File tree

3 files changed

+63
-44
lines changed

3 files changed

+63
-44
lines changed

src/compiler.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -451,20 +451,18 @@ function build_output(model_info)
451451
end)
452452
end
453453

454-
kw_form = isempty(args) ? () : (:($model_gen(;$(args...)) = $model_gen($(arg_syms...))),)
454+
modelgen_kw_form = isempty(args) ? () : (:($model_gen(;$(args...)) = $model_gen($(arg_syms...))),)
455455
missings = gensym(:missings)
456456
args_tuple = gensym(:args_tuple)
457457

458458
ex = quote
459459
function $model_gen($(args...))
460460
$args_tuple = $args_nt
461461
$missings = DynamicPPL.getmissing($args_tuple)
462-
return DynamicPPL.Model{typeof($model_gen),
463-
typeof($args_tuple),
464-
typeof($missings)}($args_tuple, $missings)
462+
return DynamicPPL.Model{typeof($model_gen)}($args_tuple, $missings)
465463
end
466464

467-
$(kw_form...)
465+
$(modelgen_kw_form...)
468466

469467
function ($model::DynamicPPL.Model{typeof($model_gen)})(
470468
$vi::DynamicPPL.VarInfo,
@@ -477,6 +475,9 @@ function build_output(model_info)
477475
end
478476

479477
DynamicPPL.getdefaults(::typeof($model_gen)) = $defaults_nt
478+
DynamicPPL.getargtypes(::typeof($model_gen)) = $(Tuple(arg_syms))
479+
DynamicPPL.getmodeltype(::typeof($model_gen)) = DynamicPPL.Model{typeof($model_gen)}
480+
DynamicPPL.getgenerator(::DynamicPPL.Model{typeof($model_gen)}) = $model_gen
480481
end
481482

482483
return esc(ex)

src/model.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,55 @@
11
"""
2-
struct Model{S, Targs<:NamedTuple, Tmissings <: Val}
2+
struct Model{G, Targs<:NamedTuple, Tmissings <: Val}
33
args::Targs
44
missings::Tmissings
55
end
66
7-
A `Model` struct with arguments `args`, model generator `modelgen` and
7+
A `Model` struct with arguments `args`, model generator type `G` and
88
missing data `missings`. `missings` is a `Val` instance, e.g. `Val{(:a, :b)}()`. An
99
argument in `args` with a value `missing` will be in `missings` by default. However, in
1010
non-traditional use-cases `missings` can be defined differently. All variables in
1111
`missings` are treated as random variables rather than observations.
1212
"""
13-
struct Model{S, Targs<:NamedTuple, Tmissings<:Val} <: AbstractModel
13+
struct Model{G, Targs<:NamedTuple, Tmissings<:Val} <: AbstractModel
1414
args::Targs
1515
missings::Tmissings
1616
end
1717

18+
Model{G}(args::Targs, missings::Tmissings) where {G, Targs, Tmissings} =
19+
Model{G, Targs, Tmissings}(args, missings)
20+
1821
(model::Model)(vi) = model(vi, SampleFromPrior())
1922
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
2023

2124

2225
"""
23-
getdefaults(model)
26+
getgenerator(model::Model)
27+
28+
Get the generator (the function defined by the `@model` macro) of a certain model instance.
29+
"""
30+
function getgenerator end
31+
32+
"""
33+
getmodeltype(::typeof(modelgen))
34+
35+
Get the associated model type for a model generating function.
36+
"""
37+
function getmodeltype end
38+
39+
"""
40+
getdefaults(::typeof(modelgen))
2441
25-
Get a named tuple of the default argument values defined in a `Model` type.
42+
Get a named tuple of the default argument values defined for a model defined by a generating function.
2643
"""
27-
getdefaults(model::Model) = getdefaults(typeof(model))
44+
function getdefaults end
2845

46+
"""
47+
getargtypes(::typeof(modelgen))
48+
49+
Get a named tuple of the argument
50+
"""
51+
function getargtypes end
2952

30-
getargtype(::Type{<:Model{S, Targs}}) where {S, Targs} = Targs
3153

3254
getmissing(model::Model) = model.missings
3355
@generated function getmissing(args::NamedTuple{names, ttuple}) where {names, ttuple}

src/prob_macro.jl

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ macro prob_str(str)
88
end
99

1010
function get_exprs(str::String)
11-
ind = findfirst(isequal('|'), str)
12-
ind === nothing && throw("Invalid expression.")
11+
ind = findfirst(isequal('|'), str)
12+
ind === nothing && throw("Invalid expression.")
1313

14-
str1 = str[1:(ind - 1)]
15-
str2 = str[(ind + 1):end]
14+
str1 = str[1:(ind - 1)]
15+
str2 = str[(ind + 1):end]
1616

17-
expr1 = Meta.parse("($str1,)")
18-
expr1 = Expr(:tuple, expr1.args...)
17+
expr1 = Meta.parse("($str1,)")
18+
expr1 = Expr(:tuple, expr1.args...)
1919

20-
expr2 = Meta.parse("($str2,)")
21-
expr2 = Expr(:tuple, expr2.args...)
20+
expr2 = Meta.parse("($str2,)")
21+
expr2 = Expr(:tuple, expr2.args...)
2222

23-
return expr1, expr2
23+
return expr1, expr2
2424
end
2525

2626
function logprob(ex1, ex2)
@@ -37,7 +37,7 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
3737
if isdefined(ntr.chain.info, :model)
3838
model = ntr.chain.info.model
3939
@assert model isa Model
40-
modelgen = model.modelgen
40+
modelgen = getgenerator(model)
4141
elseif isdefined(ntr, :model)
4242
modelgen = ntr.model
4343
else
@@ -54,10 +54,10 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
5454
else
5555
vi = nothing
5656
end
57-
defaults = modelgen.defaults
57+
defaults = getdefaults(modelgen)
5858
valid_arg(arg) = isdefined(ntl, arg) || isdefined(ntr, arg) ||
5959
isdefined(defaults, arg) && getfield(defaults, arg) !== missing
60-
@assert all(valid_arg.(modelgen.args))
60+
@assert all(valid_arg.(getargtypes(modelgen)))
6161
return Val(:likelihood), modelgen, vi
6262
else
6363
@assert isdefined(ntr, :model)
@@ -69,15 +69,16 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
6969
else
7070
vi = nothing
7171
end
72-
return probtype(ntl, ntr, modelgen, modelgen.defaults), modelgen, vi
72+
return probtype(ntl, ntr, modelgen, getdefaults(modelgen)), modelgen, vi
7373
end
7474
end
7575
function probtype(
76-
ntl::NamedTuple{namesl},
77-
ntr::NamedTuple{namesr},
78-
modelgen::ModelGen{args},
79-
defaults::NamedTuple{defs},
80-
) where {namesl, namesr, args, defs}
76+
ntl::NamedTuple{namesl},
77+
ntr::NamedTuple{namesr},
78+
modelgen,
79+
defaults::NamedTuple{defs}
80+
) where {namesl, namesr, defs}
81+
args = getargtypes(modelgen)
8182
prior_rhs = all(n -> n in (:model, :varinfo) ||
8283
n in args && getfield(ntr, n) !== missing, namesr)
8384
function get_arg(arg)
@@ -120,7 +121,7 @@ missing_arg_error_msg(arg, ::Nothing) = """Variable $arg is not defined and has
120121
function logprior(
121122
left::NamedTuple,
122123
right::NamedTuple,
123-
modelgen::ModelGen,
124+
modelgen,
124125
_vi::Union{Nothing, VarInfo},
125126
)
126127
# For model args on the LHS of |, use their passed value but add the symbol to
@@ -135,8 +136,8 @@ function logprior(
135136
# All `observe` and `dot_observe` calls are no-op in the PriorContext
136137

137138
# When all of model args are on the lhs of |, this is also equal to the logjoint.
138-
args, missing_vars = get_prior_model_args(left, right, modelgen, modelgen.defaults)
139-
model = get_model(modelgen, args, missing_vars)
139+
args, missing_vars = get_prior_model_args(left, right, Val{getargtypes(modelgen)}(), getdefaults(modelgen))
140+
model = Model{typeof(modelgen)}(args, missing_vars)
140141
vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi
141142
foreach(keys(vi.metadata)) do n
142143
@assert n in keys(left) "Variable $n is not defined."
@@ -147,8 +148,8 @@ end
147148
@generated function get_prior_model_args(
148149
left::NamedTuple{namesl},
149150
right::NamedTuple{namesr},
150-
modelgen::ModelGen{args},
151-
defaults::NamedTuple{default_args},
151+
::Val{args},
152+
defaults::NamedTuple{default_args}
152153
) where {namesl, namesr, args, default_args}
153154
exprs = []
154155
missing_args = []
@@ -182,20 +183,15 @@ end
182183

183184
warn_msg(arg) = "Argument $arg is not defined. A value of `nothing` is used."
184185

185-
function get_model(modelgen, args, missing_vars)
186-
_model = modelgen(; args...)
187-
return Model(_model.f, args, modelgen, missing_vars)
188-
end
189-
190186
function loglikelihood(
191187
left::NamedTuple,
192188
right::NamedTuple,
193-
modelgen::ModelGen,
189+
modelgen,
194190
_vi::Union{Nothing, VarInfo},
195191
)
196192
# Pass namesl to model constructor, remaining args are missing
197-
args, missing_vars = get_like_model_args(left, right, modelgen, modelgen.defaults)
198-
model = get_model(modelgen, args, missing_vars)
193+
args, missing_vars = get_like_model_args(left, right, Val{getargtypes(modelgen)}(), getdefaults(modelgen))
194+
model = Model{typeof(modelgen)}(args, missing_vars)
199195
vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi
200196
if isdefined(right, :chain)
201197
# Element-wise likelihood for each value in chain
@@ -218,7 +214,7 @@ end
218214
@generated function get_like_model_args(
219215
left::NamedTuple{namesl},
220216
right::NamedTuple{namesr},
221-
modelgen::ModelGen{args},
217+
::Val{args},
222218
defaults::NamedTuple{default_args},
223219
) where {namesl, namesr, args, default_args}
224220
exprs = []

0 commit comments

Comments
 (0)