# The Probability Interface

The easiest way to manipulate and query DynamicPPL models is via the DynamicPPL probability
interface.

Let's use a simple model of normally-distributed data as an example.
```@example probinterface
using DynamicPPL
using Distributions
using FillArrays
using LinearAlgebra
using Random

Random.seed!(1776) # Set seed for reproducibility

@model function gdemo(n)
    μ ~ Normal(0, 1)
    x ~ MvNormal(Fill(μ, n), I)
    return nothing
end
nothing # hide
```

We generate some data from a standard normal distribution, corresponding to `μ = 0` in our model:

```@example probinterface
dataset = randn(100)
nothing # hide
```

## Conditioning and Deconditioning

Bayesian models can be transformed with two main operations: conditioning and deconditioning.
Conditioning takes a variable and fixes its value as known, while deconditioning reverses this and treats the variable as random once more.
We condition by passing a model and a named tuple of conditioned variables to `|`:
```@example probinterface
model = gdemo(length(dataset)) | (x=dataset, μ=0.0)
nothing # hide
```
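
The `|` syntax is shorthand for DynamicPPL's `condition` function; as a minimal sketch (assuming the keyword form of `condition`), the following is equivalent:
```@example probinterface
# Conditioning via the `condition` function rather than `|`.
conditioned_model = condition(gdemo(length(dataset)); x=dataset, μ=0.0)
nothing # hide
```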

This operation can be reversed by applying `decondition`:
```@example probinterface
decondition(model)
nothing # hide
```
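
After deconditioning, `x` is treated as random again, so it shows up in draws from the model; a quick check:
```@example probinterface
# Both μ and x should appear in a draw from the fully deconditioned model.
keys(rand(decondition(model)))
```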

We can also decondition only some of the variables:
```@example probinterface
decondition(model, :μ)
nothing # hide
```
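
To see which variables remain conditioned, a sketch using `DynamicPPL.conditioned` (after deconditioning `μ`, only `x` should remain):
```@example probinterface
# Inspect which values the partially deconditioned model is still conditioned on.
keys(DynamicPPL.conditioned(decondition(model, :μ)))
```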

## Probabilities and Densities

We often want to calculate the (unnormalized) probability density of an event.
This might be a prior, a likelihood, or a posterior density; since the posterior is proportional to the joint, the log joint serves as an unnormalized log posterior.
DynamicPPL provides convenient functions for this.
For example, we can draw the unconditioned variable `μ` from its prior and evaluate the log joint density of that draw together with the conditioned data:
```@example probinterface
model = gdemo(length(dataset)) | (x=dataset,)
x1 = rand(model)
logjoint(model, x1)
```

The log joint decomposes into the log prior plus the log likelihood; the functions `logprior` and `loglikelihood` compute these components for a named tuple of values, given a model:
```@example probinterface
@assert logjoint(model, x1) ≈ loglikelihood(model, x1) + logprior(model, x1)
```
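
As a sanity check, these quantities agree with densities computed directly with Distributions.jl; a small sketch evaluating both sides at `μ = 0`:
```@example probinterface
# The log prior of μ = 0 is the standard-normal log density, and the log
# likelihood is the MvNormal log density of the conditioned data.
@assert logprior(model, (μ=0.0,)) ≈ logpdf(Normal(0, 1), 0.0)
@assert loglikelihood(model, (μ=0.0,)) ≈ logpdf(MvNormal(Fill(0.0, length(dataset)), I), dataset)
```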

## Example: Cross-validation

To give an example of the probability interface in use, we can estimate the predictive performance of our model with cross-validation.
In cross-validation, we split the dataset into several equal parts; each part in turn serves as the validation set, while the remainder is used for training.
We measure fit using the cross entropy (Bayes loss).¹
For this model the posterior for `μ` is available in closed form: with a `Normal(0, 1)` prior and `n` unit-variance training observations, the posterior is `Normal(sum(train) / (n + 1), sqrt(1 / (n + 1)))` (mean and standard deviation), which the code below uses.
```@example probinterface
function cross_val(model, dataset)
    # Initialise the loss so that the loop below does not change its type.
    loss = zero(logjoint(model, rand(model)))

    # Partition our 100-observation dataset into 5 folds of 20 observations each:
    test_folds = collect(Iterators.partition(dataset, 20))
    train_folds = setdiff.((dataset,), test_folds)

    for (train, test) in zip(train_folds, test_folds)
        # First, we train the model on the training set.
        # For normally-distributed data, the posterior can be computed in closed form:
        n = length(train)
        posterior = Normal(sum(train) / (n + 1), sqrt(1 / (n + 1)))
        # Sample from the posterior:
        samples = map(μ -> (μ=μ,), rand(posterior, 1000))
        # Evaluate the log joint of each posterior sample on the held-out fold:
        testing_model = gdemo(length(test)) | (x=test,)
        loss += sum(samples) do sample
            logjoint(testing_model, sample)
        end
    end
    return loss
end
cross_val(model, dataset)
```
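
Note that the returned value sums the log joint over all folds and posterior samples; dividing by the number of samples (1000) gives a Monte Carlo estimate of the expected log joint on each held-out fold, and higher values indicate better predictive fit.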

¹See [ParetoSmooth.jl](https://github.com/TuringLang/ParetoSmooth.jl) for a faster and more accurate implementation of cross-validation than the one provided here.