Commit 475a1b4

Improve Kronecker example (pymc-devs#653)
* Improve Kronecker example
* Use HalfNormal instead of Truncated for eta and sigma
1 parent 0ef1220 commit 475a1b4

File tree

2 files changed (+473, -285 lines)


examples/gaussian_processes/GP-Kron.ipynb

Lines changed: 414 additions & 252 deletions
Large diffs are not rendered by default.

examples/gaussian_processes/GP-Kron.myst.md

Lines changed: 59 additions & 33 deletions
@@ -5,9 +5,9 @@ jupytext:
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: Python 3 (ipykernel)
+  display_name: pymc-examples
   language: python
-  name: python3
+  name: pymc-examples
 ---
 
 (GP-Kron)=
@@ -16,7 +16,7 @@ kernelspec:
 :::{post} October, 2022
 :tags: gaussian process
 :category: intermediate
-:author: Bill Engels, Raul-ing Average, Christopher Krapu, Danh Phan
+:author: Bill Engels, Raul-ing Average, Christopher Krapu, Danh Phan, Alex Andorra
 :::
 
 +++
@@ -59,16 +59,18 @@ import arviz as az
 import matplotlib as mpl
 import numpy as np
 import pymc as pm
+```

+```{code-cell} ipython3
+az.style.use("arviz-whitegrid")
 plt = mpl.pyplot
 %matplotlib inline
 %config InlineBackend.figure_format = 'retina'
+seed = sum(map(ord, "gpkron"))
+rng = np.random.default_rng(seed)
 ```

 ```{code-cell} ipython3
-RANDOM_SEED = 12345
-rng = np.random.default_rng(RANDOM_SEED)
-
 # One dimensional column vectors of inputs
 n1, n2 = (50, 30)
 x1 = np.linspace(0, 5, n1)
@@ -91,7 +93,7 @@ cov = (
 K = cov(X).eval()
 f_true = rng.multivariate_normal(np.zeros(X.shape[0]), K, 1).flatten()

-sigma_true = 0.25
+sigma_true = 0.5
 y = f_true + sigma_true * rng.standard_normal(X.shape[0])
 ```

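For orientation, the hunks above and below lean on unchanged notebook context that builds the two-dimensional inputs as a cartesian product and defines the true covariance. A rough sketch of that context follows; the `x2` range and the `*_true` values sit outside the changed lines, so the numbers here are illustrative assumptions, not part of the commit:

```python
# Sketch of the surrounding (unchanged) setup, not part of this commit.
# The x2 range and the *_true values are assumptions for illustration.
import numpy as np
import pymc as pm

n1, n2 = (50, 30)
x1 = np.linspace(0, 5, n1)
x2 = np.linspace(0, 3, n2)  # assumed range

# Kronecker structure: the full grid X is the cartesian product of the
# two 1D inputs; Xs keeps them separate for the Kron implementations
X = pm.math.cartesian(x1[:, None], x2[:, None])
Xs = [x1[:, None], x2[:, None]]

l1_true, l2_true, eta_true = 0.8, 1.0, 1.0  # assumed true values
cov = (
    eta_true**2
    * pm.gp.cov.Matern52(2, ls=l1_true, active_dims=[0])
    * pm.gp.cov.Cosine(2, ls=l2_true, active_dims=[1])
)
```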
@@ -174,61 +176,84 @@ plt.title("observed data 'y' (circles) with predicted mean (squares)");
 
 Like the `gp.Latent` implementation, the `gp.LatentKron` implementation specifies a Kronecker structured GP regardless of context. **It can be used with any likelihood function, or can be used to model a variance or some other unobserved process**. The syntax follows that of `gp.Latent` exactly.
 
-### Example 1
+### Model
 
 To compare with `MarginalLikelihood`, we use the same example as before, where the noise is normal but the GP itself is not marginalized out. Instead, it is sampled directly using NUTS. It is very important to note that `gp.LatentKron` does not require a Gaussian likelihood like `gp.MarginalKron`; rather, any likelihood is admissible.
 
+Here though, we'll need to be more informative with our priors, at least those for the GP hyperparameters. This is a general rule when using GPs: **use the most informative priors you can** -- sampling lengthscales and amplitudes is a challenging business, so you want to make the sampler's work as easy as possible.
+
+Thankfully, here we have a lot of information about our amplitude and lengthscales -- we're the ones who created them ;) So we could fix them, but we'll show how you could encode that prior knowledge in your own models with, e.g., Truncated Normal distributions:
+
 ```{code-cell} ipython3
 with pm.Model() as model:
     # Set priors on the hyperparameters of the covariance
-    ls1 = pm.Gamma("ls1", alpha=2, beta=2)
-    ls2 = pm.Gamma("ls2", alpha=2, beta=2)
-    eta = pm.HalfNormal("eta", sigma=2)
+    ls1 = pm.TruncatedNormal("ls1", lower=0.5, upper=1.5, mu=1, sigma=0.5)
+    ls2 = pm.TruncatedNormal("ls2", lower=0.5, upper=1.5, mu=1, sigma=0.5)
+    eta = pm.HalfNormal("eta", sigma=0.5)
 
     # Specify the covariance functions for each Xi
     cov_x1 = pm.gp.cov.Matern52(1, ls=ls1)
     cov_x2 = eta**2 * pm.gp.cov.Cosine(1, ls=ls2)
 
-    # Set the prior on the variance for the Gaussian noise
-    sigma = pm.HalfNormal("sigma", sigma=2)
-
-    # Specify the GP. The default mean function is `Zero`.
+    # Specify the GP. The default mean function is `Zero`
     gp = pm.gp.LatentKron(cov_funcs=[cov_x1, cov_x2])
 
-    # Place a GP prior over the function f.
+    # Place a GP prior over the function f
     f = gp.prior("f", Xs=Xs)
 
+    # Set the prior on the variance for the Gaussian noise
+    sigma = pm.HalfNormal("sigma", sigma=0.5)
+
     y_ = pm.Normal("y_", mu=f, sigma=sigma, observed=y)
 ```
 
+```{code-cell} ipython3
+pm.model_to_graphviz(model)
+```
+
 ```{code-cell} ipython3
 with model:
-    tr = pm.sample(500, chains=1, return_inferencedata=True, target_accept=0.90)
+    idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9, tune=1500, draws=1500)
 ```
 
-The posterior distribution of the unknown lengthscale parameters, covariance scaling `eta`, and white noise `sigma` are shown below. The vertical lines are the true values that were used to generate the original data set.
+```{code-cell} ipython3
+idata.sample_stats.diverging.sum().data
+```
+
+### Posterior convergence
+
++++
+
+The posterior distributions of the unknown lengthscale parameters, covariance scaling `eta`, and white noise `sigma` are shown below. The vertical lines are the true values that were used to generate the original data set:
 
 ```{code-cell} ipython3
-az.plot_trace(
-    tr,
-    var_names=["ls1", "ls2", "eta", "sigma"],
-    lines={"ls1": l1_true, "ls2": l2_true, "eta": eta_true, "sigma": sigma_true},
-)
-plt.tight_layout()
+var_names = ["ls1", "ls2", "eta", "sigma"]
 ```
 
 ```{code-cell} ipython3
-x1new = np.linspace(5.1, 7.1, 20)
-x2new = np.linspace(-0.5, 3.5, 40)
-Xnew = pm.math.cartesian(x1new[:, None], x2new[:, None])
+az.plot_posterior(
+    idata,
+    var_names=var_names,
+    ref_val=[l1_true, l2_true, eta_true, sigma_true],
+    grid=(2, 2),
+    figsize=(12, 6),
+);
+```
 
-with model:
-    fnew = gp.conditional("fnew3", Xnew, jitter=1e-6)
+We can see how challenging sampling can be in these situations. Here, all went well because we were careful with our choice of priors -- especially in this simulated case, where the parameters don't have a real-world interpretation.
 
-with model:
-    ppc = pm.sample_posterior_predictive(tr, var_names=["fnew3"])
+What does the trace plot look like?
+
+```{code-cell} ipython3
+az.plot_trace(idata, var_names=var_names);
 ```
 
+All good, so let's go ahead with out-of-sample predictions!
+
++++
+
+### Out-of-sample predictions
+
 ```{code-cell} ipython3
 x1new = np.linspace(5.1, 7.1, 20)[:, None]
 x2new = np.linspace(-0.5, 3.5, 40)[:, None]
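The claim above that any likelihood is admissible with `gp.LatentKron` is worth a concrete illustration. A minimal sketch, not part of this commit, assuming hypothetical count observations `y_counts` on the same Kronecker grid:

```python
# Minimal sketch (not from the commit): gp.LatentKron under a Poisson
# likelihood. `y_counts` is a hypothetical array of counts, one per
# point of the cartesian grid.
with pm.Model() as poisson_model:
    ls1 = pm.TruncatedNormal("ls1", lower=0.5, upper=1.5, mu=1, sigma=0.5)
    ls2 = pm.TruncatedNormal("ls2", lower=0.5, upper=1.5, mu=1, sigma=0.5)
    eta = pm.HalfNormal("eta", sigma=0.5)

    cov_x1 = pm.gp.cov.Matern52(1, ls=ls1)
    cov_x2 = eta**2 * pm.gp.cov.Cosine(1, ls=ls2)

    gp = pm.gp.LatentKron(cov_funcs=[cov_x1, cov_x2])
    f = gp.prior("f", Xs=Xs)

    # The latent GP feeds a non-Gaussian likelihood through a log link
    y_ = pm.Poisson("y_", mu=pm.math.exp(f), observed=y_counts)
```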
@@ -243,7 +268,7 @@ with model:
 
 ```{code-cell} ipython3
 with model:
-    ppc = pm.sample_posterior_predictive(tr, var_names=["fnew"])
+    ppc = pm.sample_posterior_predictive(idata, var_names=["fnew"], compile_kwargs={"mode": "JAX"})
 ```
 
 Below we show the original data set as colored circles, and the mean of the conditional samples as colored squares. The results closely follow those given by the `gp.MarginalKron` implementation.
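The plotting code that draws those squares is outside this hunk; the predicted means it needs would typically be reduced out of `ppc` along these lines (a sketch, assuming the default InferenceData return of `sample_posterior_predictive`):

```python
# Sketch (not from the commit): collapse the posterior predictive draws
# of fnew to one mean value per row of Xnew, i.e. per prediction grid point
fnew_mean = ppc.posterior_predictive["fnew"].mean(dim=("chain", "draw")).values
```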
@@ -291,14 +316,15 @@ for i, ax in enumerate(axs):
 * Updated by [Raul-ing Average](https://github.com/CloudChaoszero), March 2021
 * Updated by [Christopher Krapu](https://github.com/ckrapu), July 2021
 * Updated to PyMC 4.x by [Danh Phan](https://github.com/danhphan), November 2022
+* Updated with some new plots and priors by [Alex Andorra](https://github.com/AlexAndorra), April 2024
 
 +++
 
 ## Watermark
 
 ```{code-cell} ipython3
 %load_ext watermark
-%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
+%watermark -n -u -v -iv -w -p pytensor,xarray
 ```
 
 :::{include} ../page_footer.md
