
Commit f45b252

Carlos Parada (with yebai and devmotion) authored and committed
Probability interface tutorial (#404)
First addition to the DynamicPPL tutorials; breaking this up as Hong suggested. Goes over how to use the basic interfaces (e.g. logjoint, loglikelihood, logdensityof). Co-authored-by: Hong Ge <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent b7fbb5d commit f45b252

File tree

4 files changed: +105 −1 lines changed


docs/Project.toml

Lines changed: 2 additions & 0 deletions

```diff
@@ -1,11 +1,13 @@
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

 [compat]
 Distributions = "0.25"
 Documenter = "0.27"
+FillArrays = "0.13"
 Setfield = "0.7.1, 0.8, 1"
 StableRNGs = "1"
```

docs/make.jl

Lines changed: 5 additions & 1 deletion

```diff
@@ -11,7 +11,11 @@ makedocs(;
     sitename="DynamicPPL",
     format=Documenter.HTML(),
     modules=[DynamicPPL],
-    pages=["Home" => "index.md", "API" => "api.md"],
+    pages=[
+        "Home" => "index.md",
+        "API" => "api.md",
+        "Tutorials" => ["tutorials/prob-interface.md"],
+    ],
     strict=true,
     checkdocs=:exports,
)
```
File renamed without changes.
Lines changed: 98 additions & 0 deletions

The new file (hunk `@@ -0,0 +1,98 @@`) contains the tutorial:
# The Probability Interface

The easiest way to manipulate and query DynamicPPL models is via the DynamicPPL probability interface.

Let's use a simple model of normally-distributed data as an example.
```@example probinterface
using DynamicPPL
using Distributions
using FillArrays
using LinearAlgebra
using Random

Random.seed!(1776) # Set seed for reproducibility

@model function gdemo(n)
    μ ~ Normal(0, 1)
    x ~ MvNormal(Fill(μ, n), I)
    return nothing
end
nothing # hide
```

We generate some data using `μ = 0` and `σ = 1`:

```@example probinterface
dataset = randn(100)
nothing # hide
```

## Conditioning and Deconditioning

Bayesian models can be transformed with two main operations: conditioning and deconditioning.
Conditioning fixes a variable's value, treating it as known; deconditioning reverses this, treating the variable as random again.
We condition by passing a model and a named tuple of conditioned variables to `|`:
```@example probinterface
model = gdemo(length(dataset)) | (x=dataset, μ=0)
nothing # hide
```

This operation can be reversed by applying `decondition`:
```@example probinterface
decondition(model)
nothing # hide
```

We can also decondition only some of the variables:
```@example probinterface
decondition(model, :μ)
nothing # hide
```

## Probabilities and Densities

We often want to calculate the (unnormalized) probability density of an event, such as a prior, likelihood, or posterior (joint) density.
DynamicPPL provides convenient functions for this.
For example, to evaluate the log joint density at parameter values drawn from the prior, together with the conditioned data:
```@example probinterface
model = gdemo(length(dataset)) | (x=dataset,)
x1 = rand(model)
logjoint(model, x1)
```

For convenience, the functions `logprior`, `loglikelihood`, and `logjoint` each calculate a log density for a named tuple of values, given a model; the joint is the sum of the other two:
```@example probinterface
@assert logjoint(model, x1) ≈ loglikelihood(model, x1) + logprior(model, x1)
```

## Example: Cross-validation

To give an example of the probability interface in use, we can use it to estimate our model's predictive performance with cross-validation. In cross-validation, we split the dataset into several equal parts; each part takes a turn as the validation set while the model is fitted to the remaining data. Here, we measure fit using the cross entropy (Bayes loss).¹
```@example probinterface
function cross_val(model, dataset)
    loss = zero(logjoint(model, rand(model)))

    # Partition our dataset into 5 folds of 20 observations each:
    test_folds = collect(Iterators.partition(dataset, 20))
    train_folds = setdiff.((dataset,), test_folds)

    for (train, test) in zip(train_folds, test_folds)
        # First, we train the model on the training set.
        # For this conjugate normal model, the posterior has a closed form:
        # μ | train ~ Normal(sum(train) / (n + 1), 1 / sqrt(n + 1)), n = length(train)
        n = length(train)
        posterior = Normal(sum(train) / (n + 1), 1 / sqrt(n + 1))
        # Sample from the posterior:
        samples = [(μ=m,) for m in rand(posterior, 1000)]
        # Evaluate on the held-out fold:
        testing_model = gdemo(length(test)) | (x=test,)
        loss += sum(samples) do sample
            logjoint(testing_model, sample)
        end
    end
    return loss
end
cross_val(model, dataset)
```

¹See [ParetoSmooth.jl](https://github.com/TuringLang/ParetoSmooth.jl) for a faster and more accurate implementation of cross-validation than the one provided here.
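The fold construction in `cross_val` uses only Base Julia, so it can be checked in isolation. A standalone sketch, substituting a toy `1:100` range for the real dataset, shows how `Iterators.partition` and a broadcast `setdiff` produce complementary train/test splits:

```julia
# Toy stand-in for `dataset`; 100 distinct values, so `setdiff` removes exactly one fold.
data = collect(1:100)

# Five test folds of 20 observations each:
test_folds = collect(Iterators.partition(data, 20))

# Broadcasting setdiff over the folds yields the matching training sets:
train_folds = setdiff.((data,), test_folds)

length(test_folds)    # 5 folds
length.(train_folds)  # each training set keeps the other 80 observations
```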
