Commit d473998

Add DynamicPPL doc to here (#528)
See TuringLang/DynamicPPL.jl#675. Some minor changes were made, namely importing Turing rather than DynamicPPL, since that is what most people reading this will be doing.
1 parent 5255f44 commit d473998

4 files changed: +186 -1 lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 julia_version = "1.10.5"
 manifest_format = "2.0"
-project_hash = "e0661388214f9e03749b3fccc86939dfd3853246"
+project_hash = "24c19d9af05129bd51c95e6af97ff0385a185787"

 [[deps.ADTypes]]
 git-tree-sha1 = "016833eb52ba2d6bea9fcb50ca295980e728ee24"

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -9,10 +9,12 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
+DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

_quarto.yml

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ website:
       contents:
         - tutorials/docs-10-using-turing-autodiff/index.qmd
         - tutorials/usage-custom-distribution/index.qmd
+        - tutorials/usage-probability-interface/index.qmd
         - tutorials/usage-modifying-logprob/index.qmd
         - tutorials/usage-generated-quantities/index.qmd
         - tutorials/docs-17-mode-estimation/index.qmd
tutorials/usage-probability-interface/index.qmd

Lines changed: 182 additions & 0 deletions

@@ -0,0 +1,182 @@
---
title: Querying Model Probabilities
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

The easiest way to manipulate and query Turing models is via the DynamicPPL probability interface.

Let's use a simple model of normally-distributed data as an example.

```{julia}
using Turing
using LinearAlgebra: I
using Random

@model function gdemo(n)
    μ ~ Normal(0, 1)
    x ~ MvNormal(fill(μ, n), I)
end
```

We generate some data using `μ = 0`:

```{julia}
Random.seed!(1776)
dataset = randn(100)
dataset[1:5]
```

## Conditioning and Deconditioning

Bayesian models can be transformed with two main operations, conditioning and deconditioning (also known as marginalization).
Conditioning takes a variable and fixes its value as known.
We do this by passing a model and a collection of conditioned variables to `|`, or its alias, `condition`:

```{julia}
# (equivalently)
# conditioned_model = condition(gdemo(length(dataset)), (x=dataset, μ=0))
conditioned_model = gdemo(length(dataset)) | (x=dataset, μ=0)
```

This operation can be reversed by applying `decondition`:

```{julia}
original_model = decondition(conditioned_model)
```
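
Since `x` is no longer conditioned in `original_model`, it behaves as a random variable again, so a draw from the deconditioned model contains values for both `μ` and `x` (a quick illustrative draw):

```{julia}
# After deconditioning, `x` is treated as a random variable again,
# so it shows up alongside `μ` in samples from the model:
keys(rand(original_model))
```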

We can also decondition only some of the variables:

```{julia}
partially_conditioned = decondition(conditioned_model, :μ)
```

We can see which of the variables in a model have been conditioned with `DynamicPPL.conditioned`:

```{julia}
DynamicPPL.conditioned(partially_conditioned)
```

::: {.callout-note}
Sometimes it is helpful to define convenience functions for conditioning on some variable(s).
For instance, in this example we might want to define a version of `gdemo` that conditions on some observations of `x`:

```julia
gdemo(x::AbstractVector{<:Real}) = gdemo(length(x)) | (; x)
```

For illustrative purposes, however, we do not use this function in the examples below.
:::
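
If we did define such a method, conditioning on the observations would reduce to a single call. The snippet below is only a sketch of that hypothetical usage and, like the definition in the note above, is not evaluated here:

```julia
# Hypothetical usage of the convenience method from the note above (not run here):
model_conditioned_on_x = gdemo(dataset)
```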

## Probabilities and Densities

We often want to calculate the (unnormalized) probability density for an event.
This probability might be a prior, a likelihood, or a posterior (joint) density.
DynamicPPL provides convenient functions for this.
To begin, let's define a model `gdemo`, condition it on a dataset, and draw a sample.
The returned sample only contains `μ`, since the value of `x` has already been fixed:

```{julia}
model = gdemo(length(dataset)) | (x=dataset,)

Random.seed!(124)
sample = rand(model)
```

We can then calculate the joint probability of a set of samples (here drawn from the prior) with `logjoint`:

```{julia}
logjoint(model, sample)
```
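
The argument does not have to be a draw from the prior: `logjoint` accepts any value for `μ`, which is what makes the interface useful for querying the model directly. As a small illustration, we can evaluate the joint density at `μ = 0`, the value used to generate the data:

```{julia}
# Evaluate the joint log-density at a value of μ chosen by hand:
logjoint(model, (μ = 0.0,))
```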

For models with many variables, `rand(model)` can be prohibitively slow since it returns a `NamedTuple` of samples from the prior distribution of the unconditioned variables.
We recommend working with samples of type `DataStructures.OrderedDict` in this case:

```{julia}
using DataStructures: OrderedDict

Random.seed!(124)
sample_dict = rand(OrderedDict, model)
```

`logjoint` can also be used on this sample:

```{julia}
logjoint(model, sample_dict)
```

The prior probability and the likelihood of a set of samples can be calculated with the functions `logprior` and `loglikelihood`, respectively.
The log joint probability is the sum of these two quantities:

```{julia}
logjoint(model, sample) ≈ loglikelihood(model, sample) + logprior(model, sample)
```

```{julia}
logjoint(model, sample_dict) ≈ loglikelihood(model, sample_dict) + logprior(model, sample_dict)
```
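
For this particular model we can also check the two components by hand: the prior term is the log-density of `μ` under `Normal(0, 1)`, and the likelihood term is the log-density of the data under `MvNormal(fill(μ, n), I)`. Below is a minimal sanity check along those lines, using the same distributions the model is built from:

```{julia}
using Distributions: logpdf

# Compare `logprior` and `loglikelihood` against log-densities computed
# directly from the distributions that appear in the model definition.
manual_logprior = logpdf(Normal(0, 1), sample.μ)
manual_loglikelihood = logpdf(MvNormal(fill(sample.μ, length(dataset)), I), dataset)

(logprior(model, sample) ≈ manual_logprior, loglikelihood(model, sample) ≈ manual_loglikelihood)
```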

## Example: Cross-validation

To give an example of the probability interface in use, we can use it to estimate the performance of our model using cross-validation.
In cross-validation, we split the dataset into several equal parts.
Then, each part takes a turn as the validation set: the model is fitted to the remaining data, evaluated on the held-out part, and the resulting losses are accumulated across folds.
Here, we measure fit using the cross entropy (Bayes loss).[^1]
(For the sake of simplicity, in the following code, we enforce that `nfolds` must divide the number of data points.
For a more competent implementation, see [MLUtils.jl](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.kfolds).)

```{julia}
# Calculate the train/validation splits across `nfolds` partitions,
# assuming `nfolds` divides `length(dataset)`.
function kfolds(dataset::Array{<:Real}, nfolds::Int)
    fold_size, remaining = divrem(length(dataset), nfolds)
    if remaining != 0
        error("The number of folds must divide the number of data points.")
    end
    first_idx = firstindex(dataset)
    last_idx = lastindex(dataset)
    splits = map(0:(nfolds - 1)) do i
        start_idx = first_idx + i * fold_size
        end_idx = start_idx + fold_size
        train_set_indices = [first_idx:(start_idx - 1); end_idx:last_idx]
        return (view(dataset, train_set_indices), view(dataset, start_idx:(end_idx - 1)))
    end
    return splits
end

function cross_val(
    dataset::Vector{<:Real};
    nfolds::Int=5,
    nsamples::Int=1_000,
    rng::Random.AbstractRNG=Random.default_rng(),
)
    # Initialize `loss` in a way such that the loop below does not change its type
    model = gdemo(1) | (x=[first(dataset)],)
    loss = zero(logjoint(model, rand(rng, model)))

    for (train, validation) in kfolds(dataset, nfolds)
        # First, we train the model on the training set, i.e., we obtain samples from the posterior.
        # For normally-distributed data, the posterior can be computed in closed form.
        # For general models, however, typically samples will be generated using MCMC with Turing.
        posterior = Normal(mean(train), 1)
        samples = rand(rng, posterior, nsamples)

        # Evaluation on the validation set.
        validation_model = gdemo(length(validation)) | (x=validation,)
        loss += sum(samples) do sample
            logjoint(validation_model, (μ=sample,))
        end
    end

    return loss
end

cross_val(dataset)
```

[^1]: See [ParetoSmooth.jl](https://github.com/TuringLang/ParetoSmooth.jl) for a faster and more accurate implementation of cross-validation than the one provided here.
