Skip to content

Commit fab187b

Browse files
committed
All the beauty gone (reintroduce ModelGen and coupling)
1 parent 0ec725d commit fab187b

File tree

6 files changed

+172
-105
lines changed

6 files changed

+172
-105
lines changed

src/DynamicPPL.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ export VarName,
6161
vectorize,
6262
set_resume!,
6363
# Model
64+
ModelGen,
6465
Model,
6566
getmissings,
6667
getargnames,
6768
getdefaults,
6869
getgenerator,
69-
getmodeltype,
7070
runmodel!,
7171
# Samplers
7272
Sampler,
@@ -95,6 +95,11 @@ const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_DYNAMICPPL", "0")))
9595
# Used here and overloaded in Turing
9696
function getspace end
9797

98+
# Necessary forward declarations
99+
abstract type AbstractVarInfo end
100+
abstract type AbstractContext end
101+
102+
98103
include("utils.jl")
99104
include("selector.jl")
100105
include("model.jl")

src/compiler.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -435,14 +435,13 @@ function build_output(model_info)
435435
end)
436436
end
437437

438-
modelgen_kw_form = isempty(args) ? () : (:($model_gen(;$(args...)) = $model_gen($(arg_syms...))),)
439-
model_type = :(DynamicPPL.Model{typeof($model_gen), $(Tuple(arg_syms))})
438+
@gensym(evaluator, generator)
439+
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
440+
model_gen_constructor = :(DynamicPPL.ModelGen{$(Tuple(arg_syms))}($generator, $defaults_nt))
440441

441442
ex = quote
442-
$model_gen($(args...)) = DynamicPPL.Model{typeof($model_gen)}($args_nt)
443-
$(modelgen_kw_form...)
444-
445-
function ($model::$model_type)(
443+
function $evaluator(
444+
$model::Model,
446445
$vi::DynamicPPL.VarInfo,
447446
$sampler::DynamicPPL.AbstractSampler,
448447
$ctx::DynamicPPL.AbstractContext,
@@ -452,9 +451,11 @@ function build_output(model_info)
452451
$main_body
453452
end
454453

455-
DynamicPPL.getmodeltype(::typeof($model_gen)) = $model_type
456-
DynamicPPL.getgenerator(::Type{<:$model_type}) = $model_gen
457-
DynamicPPL.getdefaults(::Type{<:$model_type}) = $defaults_nt
454+
455+
$generator($(args...)) = DynamicPPL.Model($evaluator, $args_nt, $model_gen_constructor)
456+
$(generator_kw_form...)
457+
458+
$model_gen = $model_gen_constructor
458459
end
459460

460461
return esc(ex)

src/contexts.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
abstract type AbstractContext end
2-
31
"""
42
struct DefaultContext <: AbstractContext end
53

src/model.jl

Lines changed: 128 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,153 @@
11
"""
2-
struct Model{G, argnames, missings, Targs}
2+
struct ModelGen{G, defaultnames, Tdefaults}
3+
generator::G
4+
defaults::Tdefaults
5+
end
6+
7+
A `ModelGen` struct with model generator function of type `G`, and default arguments `defaultnames`
8+
with values `Tdefaults`.
9+
"""
10+
struct ModelGen{G, argnames, defaultnames, Tdefaults}
11+
generator::G
12+
defaults::NamedTuple{defaultnames, Tdefaults}
13+
14+
function ModelGen{argnames}(
15+
generator::G,
16+
defaults::NamedTuple{defaultnames, Tdefaults}
17+
) where {G, argnames, defaultnames, Tdefaults}
18+
return new{G, argnames, defaultnames, Tdefaults}(generator, defaults)
19+
end
20+
end
21+
22+
(m::ModelGen)(args...; kwargs...) = m.generator(args...; kwargs...)
23+
24+
25+
"""
26+
getdefaults(modelgen::ModelGen)
27+
28+
Get a named tuple of the default argument values defined for a model defined by a generating function.
29+
"""
30+
getdefaults(modelgen::ModelGen) = modelgen.defaults
31+
32+
"""
33+
getargnames(modelgen::ModelGen)
34+
35+
Get a tuple of the argument names of the `modelgen`.
36+
"""
37+
getargnames(model::ModelGen{_G, argnames}) where {argnames, _G} = argnames
38+
39+
40+
41+
"""
42+
struct Model{F, argnames, Targs, missings}
43+
f::F
344
args::NamedTuple{argnames, Targs}
45+
modelgen::Tgen
446
end
547
6-
A `Model` struct with model generator type `G`, arguments names `argnames`, arguments types `Targs`,
7-
and missing arguments `missings`. `argnames` and `missings` are tuples of symbols, e.g. `(:a,
8-
:b)`. An argument with a type of `Missing` will be in `missings` by default. However, in
9-
non-traditional use-cases `missings` can be defined differently. All variables in `missings` are
10-
treated as random variables rather than observations.
48+
A `Model` struct with model evaluation function of type `F`, arguments names `argnames`, arguments
49+
types `Targs`, missing arguments `missings`, and corresponding model generator. `argnames` and
50+
`missings` are tuples of symbols, e.g. `(:a, :b)`. An argument with a type of `Missing` will be in
51+
`missings` by default. However, in non-traditional use-cases `missings` can be defined differently.
52+
All variables in `missings` are treated as random variables rather than observations.
1153
1254
# Example
1355
1456
```julia
15-
julia> Model{typeof(gdemo)}((x = 1.0, y = 2.0))
16-
Model{typeof(gdemo),(),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0))
57+
julia> Model(f, (x = 1.0, y = 2.0))
58+
Model{typeof(f),(),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0))
1759
18-
julia> Model{typeof(gdemo), (:y,)}((x = 1.0, y = 2.0))
19-
Model{typeof(gdemo),(:y,),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0))
60+
julia> Model{(:y,)}(f, (x = 1.0, y = 2.0))
61+
Model{typeof(f),(:y,),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0))
2062
```
2163
"""
22-
struct Model{G, argnames, missings, Targs} <: AbstractModel
64+
struct Model{F, argnames, Targs, missings, Tgen} <: AbstractModel
65+
f::F
2366
args::NamedTuple{argnames, Targs}
24-
25-
Model{G, missings}(args::NamedTuple{argnames, Targs}) where {G, argnames, missings, Targs} =
26-
new{G, argnames, missings, Targs}(args)
67+
modelgen::Tgen
68+
69+
"""
70+
Model{missings}(f, args::NamedTuple, modelgen::ModelGen)
71+
72+
Create a model with evalutation function `f` and missing arguments overwritten by `missings`.
73+
"""
74+
function Model{missings}(
75+
f::F,
76+
args::NamedTuple{argnames, Targs},
77+
modelgen::Tgen
78+
) where {missings, F, argnames, Targs, Tgen<:ModelGen}
79+
return new{F, argnames, Targs, missings, Tgen}(f, args, modelgen)
80+
end
2781
end
2882

29-
@generated function Model{G}(args::NamedTuple{argnames, Targs}) where {G, argnames, Targs}
83+
"""
84+
Model(f, args::NamedTuple, modelgen::ModelGen)
85+
86+
Create a model with evalutation function `f` and missing arguments deduced from `args`.
87+
"""
88+
@generated function Model(
89+
f::F,
90+
args::NamedTuple{argnames, Targs},
91+
modelgen::ModelGen{_G, argnames}
92+
) where {F, argnames, Targs, _G}
3093
missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
31-
return :(Model{G, $missings}(args))
94+
return :(Model{$missings}(f, args, modelgen))
3295
end
3396

3497

35-
(model::Model)(vi) = model(vi, SampleFromPrior())
36-
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
98+
"""
99+
Model{missings}(modelgen::ModelGen, args::NamedTuple)
100+
101+
Create a copy of the model described by `modelgen(args...)`, with missing arguments
102+
overwritten by `missings`.
103+
"""
104+
function Model{missings}(
105+
modelgen::ModelGen,
106+
args::NamedTuple{argnames, Targs}
107+
) where {missings, argnames, Targs}
108+
model = modelgen(args...)
109+
return Model{missings}(model.f, args, modelgen)
110+
end
111+
112+
113+
function (model::Model)(
114+
vi::AbstractVarInfo=VarInfo(),
115+
spl::AbstractSampler=SampleFromPrior(),
116+
ctx::AbstractContext=DefaultContext()
117+
)
118+
return model.f(model, vi, spl, ctx)
119+
end
120+
121+
122+
"""
123+
runmodel!(model::Model, vi::AbstractVarInfo[, spl::AbstractSampler, ctx::AbstractContext])
124+
125+
Sample from `model` using the sampler `spl` storing the sample and log joint probability in `vi`.
126+
Resets the `vi` and increases `spl`s `state.eval_num`.
127+
"""
128+
function runmodel!(
129+
model::Model,
130+
vi::AbstractVarInfo,
131+
spl::AbstractSampler=SampleFromPrior(),
132+
ctx::AbstractContext=DefaultContext()
133+
)
134+
setlogp!(vi, 0)
135+
if has_eval_num(spl)
136+
spl.state.eval_num += 1
137+
end
138+
model(vi, spl, ctx)
139+
return vi
140+
end
37141

38142

39143
"""
40144
getargnames(model::Model)
41145
42146
Get a tuple of the argument names of the `model`.
43147
"""
44-
getargnames(model::Model) = getargnames(typeof(model))
45-
getargnames(::Type{<:Model{_G, argnames} where {_G}}) where {argnames} = argnames
148+
getargnames(model::Model{_F, argnames}) where {argnames, _F} = argnames
46149

47-
@generated function inargnames(::Val{s}, ::Model{_G, argnames}) where {s, _G, argnames}
150+
@generated function inargnames(::Val{s}, ::Model{_F, argnames}) where {s, argnames, _F}
48151
return s in argnames
49152
end
50153

@@ -54,35 +157,19 @@ end
54157
55158
Get a tuple of the names of the missing arguments of the `model`.
56159
"""
57-
getmissings(model::Model{_G, _a, missings}) where {missings, _G, _a} = missings
160+
getmissings(model::Model{_F, _a, _T, missings}) where {missings, _F, _a, _T} = missings
58161

59162
getmissing(model::Model) = getmissings(model)
60163
@deprecate getmissing(model) getmissings(model)
61164

62-
@generated function inmissings(::Val{s}, ::Model{_G, _a, missings}) where {s, missings, _G, _a}
165+
@generated function inmissings(::Val{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
63166
return s in missings
64167
end
65168

66169

67170
"""
68171
getgenerator(model::Model)
69172
70-
Get the generator (the function defined by the `@model` macro) of a certain model instance.
71-
"""
72-
getgenerator(model::Model) = getgenerator(typeof(model))
73-
74-
75-
"""
76-
getdefaults(model::Model)
77-
78-
Get a named tuple of the default argument values defined for a model defined by a generating function.
79-
"""
80-
getdefaults(model::Model) = getdefaults(typeof(model))
81-
82-
83-
"""
84-
getmodeltype(::typeof(modelgen))
85-
86-
Get the associated model type for a model generator (the function defined by the `@model` macro).
173+
Get the model generator associated with `model`.
87174
"""
88-
getmodeltype(model::Model) = typeof(model)
175+
getgenerator(model::Model) = model.modelgen

0 commit comments

Comments
 (0)