Commit eee8e1c

Merge pull request #14 from TuringLang/phg/model_abstraction

First draft of model abstraction

2 parents 7ec39ab + f64b0a9

File tree

4 files changed: +295 −3 lines changed

Project.toml

Lines changed: 2 additions & 0 deletions

```diff
@@ -7,7 +7,9 @@ version = "0.1.4"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [compat]
 AbstractMCMC = "2, 3"
+StatsBase = "0.33.4"
 julia = "1"
```

README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -20,3 +20,5 @@ simplified models such as GPs, GLMs, or plain log-density problems.
 
 A more short term goal is to start a process of cleanly refactoring and justifying parts of
 AbstractPPL.jl’s design, and hopefully to get on closer terms with Soss.jl.
+
+See [interface draft](interface.md).
```

interface.md

Lines changed: 288 additions & 0 deletions (new file)

## `AbstractProbabilisticProgram` interface

There are at least two somewhat incompatible conventions used for the term “model”. None of this is
particularly exact, but:

- In Turing.jl, if you write down a `@model` function and call it on arguments, you get a model
  object paired with (a possibly empty set of) observations. This can be treated as an instantiated
  “conditioned” object with fixed values for parameters and observations.
- In Soss.jl, “model” is used for a symbolic “generative” object from which concrete functions, such
  as densities and sampling functions, can be derived, _and_ which you can later condition on (and
  in turn get a conditional density etc.).

Relevant discussions:
[1](https://julialang.zulipchat.com/#narrow/stream/234072-probprog/topic/Naming.20the.20.22likelihood.22.20thingy),
[2](https://github.com/TuringLang/AbstractPPL.jl/discussions/10).
### TL/DR:

There are three interrelated aspects that this interface intends to standardize:

- Density calculation
- Sampling
- “Conversions” between different conditionings of models

Therefore, the interface consists of:

- `condition(::Model, ::Trace) -> ConditionedModel`
- `decondition(::ConditionedModel) -> GenerativeModel`
- `sample(::Model, ::Sampler = Exact(), [Int])` (from `AbstractMCMC.sample`)
- `logdensity(::Model, ::Trace)`
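As a concrete illustration of how these four functions could interact, here is a minimal
self-contained sketch for a toy Gaussian model. Everything here (`ToyGenerative`,
`ToyConditioned`, `normlogpdf`, traces as plain named tuples) is a hypothetical stand-in for the
interface above, not part of any existing package:

```julia
# Hypothetical sketch only: none of these names exist in AbstractPPL.
# Model: X ~ Normal(0, 1), Y ~ Normal(X, 1); traces are named tuples.

# Log-density of Normal(μ, σ) at x, to keep the example dependency-free.
normlogpdf(μ, σ, x) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

struct ToyGenerative end    # distribution over the full trace (X, Y)

struct ToyConditioned       # the same model with Y observed
    Y::Float64
end

condition(::ToyGenerative, t::NamedTuple) = ToyConditioned(t.Y)
decondition(::ToyConditioned) = ToyGenerative()

# Joint log-density over the full trace:
logdensity(::ToyGenerative, t::NamedTuple) =
    normlogpdf(0.0, 1.0, t.X) + normlogpdf(t.X, 1.0, t.Y)

# Unnormalized posterior log-density of X, with Y fixed at the observation:
logdensity(m::ToyConditioned, t::NamedTuple) =
    logdensity(ToyGenerative(), (X = t.X, Y = m.Y))

m = condition(ToyGenerative(), (Y = 0.5,))
logdensity(m, (X = 0.0,))   # unnormalized posterior log-density at X = 0
```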
### Traces & probability expressions

First, an infrastructural requirement which we will need below to write things out.

The kinds of models we consider are, at least in a theoretical sense, distributions over *traces* –
types which carry collections of values together with their names. Existing realizations of these
are `VarInfo` in Turing.jl, choice maps in Gen.jl, and the usage of named tuples in Soss.jl.

Traces solve the problem of having to name random variables in function calls, and in samples from
models. In essence, every concrete trace type will just be a fancy kind of dictionary from variable
names (ideally, `VarName`s) to values.

Since we have to use this kind of mapping a lot in the specification of the interface, let’s for now
just choose some arbitrary macro-like syntax like the following:

```julia
@T(Y[1] = …, Z = …)
```

Some more ideas for this kind of object can be found at the end.
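To make the “fancy dictionary” intuition concrete, here is a minimal sketch of what a trace type
could look like, using plain `Symbol` keys as a stand-in for proper `VarName`s (so indexed
variables like `Y[1]` are out of scope here); `SimpleTrace` is purely illustrative:

```julia
# Illustrative only: a trace as a thin wrapper around a Dict from variable
# names to values. Plain Symbols stand in for proper `VarName`s.
struct SimpleTrace
    values::Dict{Symbol,Any}
end

SimpleTrace(; kwargs...) = SimpleTrace(Dict{Symbol,Any}(kwargs...))

Base.getindex(t::SimpleTrace, name::Symbol) = t.values[name]
Base.keys(t::SimpleTrace) = keys(t.values)

# Roughly what `@T(X = 0.3, Z = [1, 2])` might construct:
t = SimpleTrace(X = 0.3, Z = [1, 2])
t[:X]  # 0.3
```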
### “Conversions”

The purpose of this part is to provide common names for how we want a model instance to be
understood. As we have seen, in some modelling languages, model instances are primarily generative,
with some parameters fixed, while other instance types pair a model instance with the observations
it is conditioned on. What I call “conversions” here is just an interface to transform between
these two views and unify the involved objects under one language.

Let’s start from a generative model with parameter `μ`:

```julia
# (hypothetical) generative spec a la Soss
@generative_model function foo_gen(μ)
    X ~ Normal(0, μ)
    Y[1] ~ Normal(X)
    Y[2] ~ Normal(X + 1)
end
```

Applying the “constructor” `foo_gen` now means to fix the parameter, and should return a concrete
object of the generative type:

```julia
g = foo_gen(μ = …)::SomeGenerativeModel
```

With this kind of object, we should be able to sample from and calculate joint log-densities over
the combined trace space of `X`, `Y[1]`, and `Y[2]` – either directly, or by deriving the
respective functions (e.g., by converting from a symbolic representation).

For model types that contain enough structural information, it should then be possible to condition
on observed values and obtain a conditioned model:

```julia
condition(g, @T(Y = …))::SomeConditionedModel
```

For this operation, there will probably exist syntactic sugar in the form of

```julia
g | @T(Y = …)
```

Now, if we start from a Turing.jl-like model instead, with the “observation part” already specified,
we have a situation like this, with the observations `Y` fixed in the instantiation:

```julia
# conditioned spec a la DPPL
@model function foo(Y, μ)
    X ~ Normal(0, μ)
    Y[1] ~ Normal(X)
    Y[2] ~ Normal(X + 1)
end

m = foo(Y = …, μ = …)::SomeConditionedModel
```

From this we can, if supported, go back to the generative form via `decondition`, and back via
`condition`:

```julia
decondition(m) == g::SomeGenerativeModel
m == condition(g, @T(Y = …))
```

(with equality in distribution).

In the case of Turing.jl, the object `m` would at the same time contain the information about the
generative and posterior distribution; `condition` and `decondition` can simply return different
kinds of “tagged” model types which put the model specification into a certain context.

Soss.jl pretty much already works like the examples above, with one model object being either a
`JointModel` or a `ConditionedModel`, and the `|` syntax just being sugar for the latter.

A hypothetical `DensityModel`, or something like the types from LogDensityProblems.jl, would be a
case for a model type that does not support the structural operations `condition` and
`decondition`.

The invariants between these operations should follow the normal rules of probability theory. Not
all methods or directions need to be supported for every modelling language; in such cases, a
`MethodError` or some other runtime error should be raised.

There is no strict requirement for generative models and conditioned models to have different types
or be tagged with variable names etc. This is a choice to be made by the concrete implementation.

Decomposing models into prior and observation distributions is not yet specified; the former is
rather easy, since it is only a marginal of the generative distribution, while the latter requires
more structural information. Perhaps both can be generalized under the `query` function I discuss
at the end.
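The “tagged model types” idea mentioned above can be sketched as follows; `ModelSpec`,
`Generative`, and `Conditioned` are hypothetical names, and the `|` overload mirrors the proposed
sugar (with a named tuple standing in for `@T`):

```julia
# Hypothetical "tagged model" sketch: one underlying specification, wrapped as
# either generative or conditioned; `condition`/`decondition` just re-wrap it.
struct ModelSpec
    f::Function              # stand-in for the raw model definition
end

struct Generative
    spec::ModelSpec
end

struct Conditioned
    spec::ModelSpec
    obs::NamedTuple          # the fixed observations
end

condition(g::Generative, obs::NamedTuple) = Conditioned(g.spec, obs)
decondition(c::Conditioned) = Generative(c.spec)

# The proposed sugar `g | @T(Y = …)`:
Base.:|(g::Generative, obs::NamedTuple) = condition(g, obs)

g = Generative(ModelSpec(identity))
c = g | (Y = [1.0, 2.0],)
decondition(c) == g          # the round trip recovers the generative model
```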
### Sampling

Sampling in this case refers to producing values from the distribution specified in a model
instance, either following the distribution exactly, or approximating it through a Monte Carlo
algorithm.

All sampleable model instances are assumed to implement the `AbstractMCMC` interface – i.e., at
least [`step`](https://github.com/TuringLang/AbstractMCMC.jl#sampling-step), and accordingly
`sample`, `steps`, `Samples`. The most important aspect is `sample`, though, which plays the role
of `rand` for distributions.

The results of `sample` generalize `rand` – while `rand(d, N)` is assumed to give you iid samples,
`sample(m, sampler, N)` returns a sample from a sequence (known as a chain in the case of MCMC) of
length `N` approximating `m`’s distribution by a specific sampling algorithm (which of course
subsumes the case that `m` can be sampled from exactly, in which case the “chain” actually is iid).

Depending on which kinds of sampling are possible, several methods may be supported. In the case of
a (posterior) conditioned model with no known exact sampling procedure, we just have what is given
through `AbstractMCMC`:

```julia
sample([rng], m, N, sampler; [args])  # chain of length N using `sampler`
```

In the case of a generative model, or a posterior model with an exact solution, we can have some
more methods without the need to specify a sampler:

```julia
sample([rng], m; [args])     # one random sample
sample([rng], m, N; [args])  # N iid samples; equivalent to `rand` in certain cases
```

It should be possible to implement this by a special sampler, say, `Exact` (name still to be
discussed), that can then also be reused for generative sampling:

```julia
step(g, spl = Exact(), state = nothing)  # iid sample from the exact distribution, trivial state
sample(g, Exact(), [N])
```

with dispatch failing for model types for which exact sampling is not possible (or not
implemented).

This could even be useful for Monte Carlo methods not based on Markov chains, e.g., particle-based
sampling using a return type with weights, or rejection sampling.

Not all variants need to be supported – for example, a posterior model might not support
`sample(m)` when exact sampling is not possible, only `sample(m, N, alg)` for Markov chains.

`rand` is then just a special case when “trivial” exact sampling works for a model, e.g. a joint
model.
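A standalone sketch of this `Exact` idea follows. It is deliberately simplified: the real
`AbstractMCMC.step` also threads an RNG and keyword arguments, which are omitted here, as are the
package’s actual types; `ToyUniform` is a made-up exactly-sampleable model:

```julia
# Simplified sketch of an `Exact` sampler, not the actual AbstractMCMC API.
struct Exact end

struct ToyUniform end        # a model we can sample from exactly: U(0, 1)

# IID sampling: every step draws fresh from the exact distribution and
# carries only a trivial (`nothing`) state.
step(m::ToyUniform, spl::Exact, state = nothing) = (rand(), nothing)

function sample(m, spl::Exact, N::Int)
    xs = Vector{Float64}(undef, N)
    state = nothing
    for i in 1:N
        xs[i], state = step(m, spl, state)
    end
    return xs
end

xs = sample(ToyUniform(), Exact(), 100)  # 100 iid draws
```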
### Density Calculation

Since the different “versions” of how a model is to be understood as generative or conditioned are
to be expressed in the type or dispatch they support, there should be no need for separate
functions `logjoint`, `loglikelihood`, etc., which force these semantic distinctions on the
implementor; one `logdensity` should suffice for all, with the distinction being made by the
capabilities of the concrete model instance.

Note that this generalizes `logpdf`, too, since the posterior density will of course in general be
unnormalized and hence not a probability density.

The evaluation will usually work with the internal, concrete trace type, like `VarInfo` in
Turing.jl:

```julia
logdensity(m, vi)
```

But the user will more likely work on the interface using probability expressions:

```julia
logdensity(m, @T(X = ...))
```

(Note that this would replace the current `prob` string macro in Turing.jl.)

Densities need not be normalized.

#### Implementation notes

It should be possible to make this fall back on the internal method with the right definition and
implementation of `maketrace`:

```julia
logdensity(m, t::ProbabilityExpression) = logdensity(m, maketrace(m, t))
```

There is one open question – should normalized and unnormalized densities be distinguishable? This
could be done by dispatch as well, e.g., if the caller wants to ensure normalization:

```julia
logdensity(g, @T(X = ..., Y = ..., Z = ...); normalized = Val(true))
```

Although there is probably a better way through traits; maybe like for arrays, with
`NormalizationStyle(g, t) = IsNormalized()`?
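A toy end-to-end version of this fallback, assuming a model whose internal trace is simply a vector
of values in a fixed variable order; all names besides the `logdensity` fallback line itself are
illustrative:

```julia
# Illustrative fallback: `logdensity` on a probability expression delegates to
# `maketrace`, which applies the model's "schema" to produce a concrete trace.
struct ProbabilityExpression
    assignments::NamedTuple
end

struct ToyModel
    order::Vector{Symbol}    # internal traces are vectors in this fixed order
end

# "Schema application": look each variable up in the expression.
maketrace(m::ToyModel, t::ProbabilityExpression) =
    Float64[t.assignments[name] for name in m.order]

# Internal evaluation on the concrete trace type (a standard normal
# log-density per value, purely for illustration):
logdensity(m::ToyModel, trace::Vector{Float64}) =
    sum(-0.5 * x^2 - 0.5 * log(2π) for x in trace)

# The single generic fallback from above:
logdensity(m::ToyModel, t::ProbabilityExpression) = logdensity(m, maketrace(m, t))

m = ToyModel([:X, :Y])
logdensity(m, ProbabilityExpression((X = 0.0, Y = 1.0)))
```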
## More on probability expressions

Note that this needs to be a macro, if written this way, since the keys may themselves be more
complex than just symbols (e.g., indexed variables). (Don’t get hung up on the `@T` name, though;
this is just a working draft.)

The idea here is to standardize the construction (and manipulation) of *abstract probability
expressions*, plus the interface for turning them into concrete traces for a specific model – like
[`@formula`](https://juliastats.org/StatsModels.jl/stable/formula/#Modeling-tabular-data) and
[`apply_schema`](https://juliastats.org/StatsModels.jl/stable/internals/#Semantics-time-(apply_schema))
from StatsModels.jl are doing.

Maybe the following would suffice to do that:

```julia
maketrace(m, t)::tracetype(m, t)
```

where `maketrace` produces a concrete trace corresponding to `t` for the model `m`, and `tracetype`
is the corresponding `eltype`-like function giving you the concrete trace type for a certain model
and probability expression combination.

Possible extensions of this idea:

- Pearl-style do-notation: `@T(Y = y | do(X = x))`
- Allowing free variables, to specify model transformations: `query(m, @T(X | Y))`
- “Graph queries”: `@T(X | Parents(X))`, `@T(Y | Not(X))` (a nice way to express Gibbs conditionals!)
- Predicate style for “measure queries”: `@T(X < Y + Z)`

The latter applications are the reason I originally liked the idea of the macro being called `@P`
(or even `@𝓅` or `@ℙ`), since then it would look like a “Bayesian probability expression”:
`@P(X < Y + Z)`. But this would not be so meaningful in the case of representing a trace instance.

Perhaps both `@T` and `@P` can coexist, and both produce different kinds of `ProbabilityExpression`
objects?

NB: the exact details of this kind of “schema application”, and what results from it, will need to
be specified in the interface of `AbstractModelTrace`, aka “the new `VarInfo`”.
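For intuition, here is one possible (much simplified) implementation of such a macro, collecting
`name = value` pairs into a `NamedTuple`-backed expression object. The names `@T` and
`ProbabilityExpression` are placeholders from the draft above; indexed variables and the
schema-application step are deliberately left out:

```julia
# A much-simplified possible desugaring for `@T`: collect `name = value`
# pairs into a NamedTuple-backed expression object. Indexed variables
# (`Y[1]`) and proper `VarName` handling are deliberately omitted.
struct ProbabilityExpression
    assignments::NamedTuple
end

macro T(args...)
    names = Symbol[]
    vals = Any[]
    for a in args
        Meta.isexpr(a, :(=)) || error("expected `name = value`, got $a")
        push!(names, a.args[1])
        push!(vals, esc(a.args[2]))
    end
    # Build `NamedTuple{(:names...,)}((values...,))` at expansion time.
    nt = :(NamedTuple{($(QuoteNode.(names)...),)}(($(vals...),)))
    return :(ProbabilityExpression($nt))
end

t = @T(X = 0.3, Z = 1 + 1)
t.assignments  # (X = 0.3, Z = 2)
```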

src/varname.jl

Lines changed: 3 additions & 3 deletions

```diff
@@ -255,7 +255,7 @@ function varname(expr::Expr)
         sym, inds = vsym(expr), vinds(expr)
         return :($(AbstractPPL.VarName){$(QuoteNode(sym))}($inds))
     else
-        throw("Malformed variable name $(expr)!")
+        error("Malformed variable name $(expr)!")
     end
 end
 
@@ -295,7 +295,7 @@ function vsym(expr::Expr)
     if Meta.isexpr(expr, :ref)
         return vsym(expr.args[1])
     else
-        throw("Malformed variable name $(expr)!")
+        error("Malformed variable name $(expr)!")
     end
 end
 
@@ -363,6 +363,6 @@ function vinds(expr::Expr)
         init = vinds(ex.args[1]).args
         return Expr(:tuple, init..., last)
     else
-        throw("VarName: Mis-formed variable name $(expr)!")
+        error("Mis-formed variable name $(expr)!")
     end
 end
```
