|
| 1 | +using MLDatasets: MNIST |
| 2 | +using MultivariateStats: fit, PCA, transform |
| 3 | + |
# Load MNIST images and labels.
# The training split is constructed once and both the images and the targets
# are read from that single dataset object, so the data is not loaded twice.
mnist_train = MNIST(split=:train)
features = mnist_train.features
nrows, ncols, nimages = size(features)
# Flatten every image into one column of length nrows * ncols.
image_raw = Float64.(reshape(features, (nrows * ncols, nimages)))
# Shift the targets up by one; the model uses `labels` as row indices into `m`
# (presumably the raw targets are 0-based class ids — confirm against MLDatasets).
labels = mnist_train.targets .+ 1
C = 10 # Number of labels

# Preprocess the images by reducing dimensionality
D = 40 # Passed to PCA as `maxoutdim` and to the model as the width of `m`
pca = fit(PCA, image_raw; maxoutdim=D)
image = transform(pca, image_raw)

# Take only the first N images and vectorise
N = 1000
# Transposed so the subset is indexed (image, component); presumably this is
# meant to match the layout of `m[labels, :]` in the model — confirm.
image_subset = image[:, 1:N]'
# Materialise the lazy adjoint as a dense matrix before flattening, so that
# `image_vec` is a plain vector in column-major order.
image_vec = vec(Matrix(image_subset))
labels = labels[1:N]
| 21 | + |
# Naive Bayes model with Gaussian class-conditional likelihoods.
#
# - `image_vec`: the observed data, flattened into a single vector.
# - `labels`: one class index per observation, used to select rows of `m`.
# - `C`, `D`: number of rows and columns of the mean matrix `m`.
@model function dppl_naive_bayes(image_vec, labels, C, D)
    # C-by-D matrix of means, every entry with an independent Normal(0, 10) prior.
    m ~ filldist(Normal(0, 10), C, D)

    # Select the row of means belonging to each observation's label and flatten
    # the result; presumably this ordering matches how `image_vec` was built — confirm.
    class_means = vec(m[labels, :])

    # NOTE(review): `I` is assumed to be LinearAlgebra's identity, brought into
    # scope by whatever includes this file — confirm.
    image_vec ~ MvNormal(class_means, I)
end

# NOTE(review): `@register` is not defined in this file; presumably it is supplied
# by the surrounding benchmark harness — confirm.
@register dppl_naive_bayes(image_vec, labels, C, D)
0 commit comments