---
title: Querying Model Probabilities
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

The easiest way to manipulate and query Turing models is via the DynamicPPL probability interface.

Let's use a simple model of normally-distributed data as an example.

```{julia}
using Turing
using LinearAlgebra: I
using Random

@model function gdemo(n)
    μ ~ Normal(0, 1)
    x ~ MvNormal(fill(μ, n), I)
end
```

We generate some data using `μ = 0`:

```{julia}
Random.seed!(1776)
dataset = randn(100)
dataset[1:5]
```

## Conditioning and Deconditioning

Bayesian models can be transformed with two main operations, conditioning and deconditioning.
Conditioning takes a variable and fixes its value as known; deconditioning reverses this, treating the variable as random again.
We condition by passing a model and a collection of conditioned variables to `|`, or its alias, `condition`:

```{julia}
# (equivalently)
# conditioned_model = condition(gdemo(length(dataset)), (x=dataset, μ=0))
conditioned_model = gdemo(length(dataset)) | (x=dataset, μ=0)
```

This operation can be reversed by applying `decondition`:

```{julia}
original_model = decondition(conditioned_model)
```

We can also decondition only some of the variables:

```{julia}
partially_conditioned = decondition(conditioned_model, :μ)
```
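
For instance, sampling from the partially deconditioned model should draw a fresh value of `μ` from its prior, while `x` remains fixed at the observed data:

```julia
# Only μ is free again; x is still conditioned on `dataset`.
rand(partially_conditioned)
```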

We can see which of the variables in a model have been conditioned with `DynamicPPL.conditioned`:

```{julia}
DynamicPPL.conditioned(partially_conditioned)
```
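
For comparison, querying the fully conditioned model should report both variables:

```julia
# Both x and μ were conditioned above, so both should appear here.
DynamicPPL.conditioned(conditioned_model)
```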

::: {.callout-note}
Sometimes it is helpful to define convenience functions for conditioning on some variable(s).
For instance, in this example we might want to define a version of `gdemo` that conditions on some observations of `x`:

```julia
gdemo(x::AbstractVector{<:Real}) = gdemo(length(x)) | (; x)
```
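
With this method defined, conditioning becomes a one-liner; for instance, the following would be equivalent to `gdemo(length(dataset)) | (x=dataset,)` (the variable name is just illustrative):

```julia
conditioned_on_x = gdemo(dataset)
```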

For illustrative purposes, however, we do not use this function in the examples below.
:::

## Probabilities and Densities

We often want to calculate the (unnormalized) probability density for an event.
This probability might be a prior, a likelihood, or a posterior (joint) density.
DynamicPPL provides convenient functions for this.
To begin, let's condition `gdemo` on the dataset and draw a sample from the resulting model.
The returned sample only contains `μ`, since the value of `x` has already been fixed:

```{julia}
model = gdemo(length(dataset)) | (x=dataset,)

Random.seed!(124)
sample = rand(model)
```

We can then calculate the log joint probability of this sample (here drawn from the prior) with `logjoint`.

```{julia}
logjoint(model, sample)
```
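
To see what `logjoint` computes here, we can reproduce it by hand for this particular model: the log prior of `μ` plus the log likelihood of the conditioned data (a sketch using the distributions from the model definition):

```julia
# Log prior of μ under Normal(0, 1) plus log likelihood of the observed x
# under MvNormal(fill(μ, n), I); this should match logjoint.
logjoint(model, sample) ≈
    logpdf(Normal(0, 1), sample.μ) +
    logpdf(MvNormal(fill(sample.μ, length(dataset)), I), dataset)
```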

For models with many variables, `rand(model)` can be prohibitively slow, since it returns a `NamedTuple` of samples from the prior distributions of the unconditioned variables.
In this case, we recommend working with samples of type `DataStructures.OrderedDict`:

```{julia}
using DataStructures: OrderedDict

Random.seed!(124)
sample_dict = rand(OrderedDict, model)
```

`logjoint` can also be used on this sample:

```{julia}
logjoint(model, sample_dict)
```

The log prior probability and the log likelihood of a sample can be calculated with the functions `logprior` and `loglikelihood`, respectively.
The log joint probability is the sum of these two quantities:

```{julia}
logjoint(model, sample) ≈ loglikelihood(model, sample) + logprior(model, sample)
```

```{julia}
logjoint(model, sample_dict) ≈ loglikelihood(model, sample_dict) + logprior(model, sample_dict)
```
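
These functions are not restricted to values sampled from the model; they can be evaluated at any assignment of the random variables, for example at `μ = 0`:

```julia
# Evaluate the densities at a fixed value of μ rather than at a sample.
logprior(model, (μ=0.0,)) + loglikelihood(model, (μ=0.0,)) ≈ logjoint(model, (μ=0.0,))
```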

## Example: Cross-validation

To give an example of the probability interface in use, we can use it to estimate the performance of our model with cross-validation.
In cross-validation, we split the dataset into several equal parts.
Each part then serves in turn as the validation set, while the model is trained on the remaining parts and evaluated on the held-out one.
Here, we measure fit using the cross entropy (Bayes loss).[^1]
(For the sake of simplicity, the following code requires that `nfolds` divides the number of data points evenly.
For a more robust implementation, see [MLUtils.jl](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.kfolds).)

```{julia}
# Calculate the train/validation splits across `nfolds` partitions,
# assuming `nfolds` evenly divides `length(dataset)`.
function kfolds(dataset::Array{<:Real}, nfolds::Int)
    fold_size, remaining = divrem(length(dataset), nfolds)
    if remaining != 0
        error("The number of folds must divide the number of data points.")
    end
    first_idx = firstindex(dataset)
    last_idx = lastindex(dataset)
    splits = map(0:(nfolds - 1)) do i
        start_idx = first_idx + i * fold_size
        end_idx = start_idx + fold_size
        train_set_indices = [first_idx:(start_idx - 1); end_idx:last_idx]
        return (view(dataset, train_set_indices), view(dataset, start_idx:(end_idx - 1)))
    end
    return splits
end

function cross_val(
    dataset::Vector{<:Real};
    nfolds::Int=5,
    nsamples::Int=1_000,
    rng::Random.AbstractRNG=Random.default_rng(),
)
    # Initialize `loss` in a way such that the loop below does not change its type
    model = gdemo(1) | (x=[first(dataset)],)
    loss = zero(logjoint(model, rand(rng, model)))

    for (train, validation) in kfolds(dataset, nfolds)
        # First, we train the model on the training set, i.e., we obtain samples from the posterior.
        # For this conjugate model, the posterior is available in closed form: with m unit-variance
        # observations and a Normal(0, 1) prior, μ | x ~ Normal(sum(x) / (m + 1), sqrt(1 / (m + 1))).
        # For general models, however, samples would typically be generated using MCMC with Turing.
        m = length(train)
        posterior = Normal(sum(train) / (m + 1), sqrt(1 / (m + 1)))
        samples = rand(rng, posterior, nsamples)

        # Evaluation on the validation set.
        validation_model = gdemo(length(validation)) | (x=validation,)
        loss += sum(samples) do sample
            logjoint(validation_model, (μ=sample,))
        end
    end

    return loss
end

cross_val(dataset)
```
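
The keyword arguments allow experimenting with the number of folds and posterior samples; a usage sketch (the seed is arbitrary):

```julia
# Ten folds (the 100 data points divide evenly), more posterior samples,
# and an explicit RNG for reproducibility.
cross_val(dataset; nfolds=10, nsamples=10_000, rng=Random.Xoshiro(1))
```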

[^1]: See [ParetoSmooth.jl](https://github.com/TuringLang/ParetoSmooth.jl) for a faster and more accurate implementation of cross-validation than the one provided here.