
Commit eacb2bf

Additional edits
1 parent 7c214ef commit eacb2bf

File tree

2 files changed: +459, -276 lines


examples/gaussian_processes/GP-Heteroskedastic.ipynb

Lines changed: 369 additions & 220 deletions
Large diffs are not rendered by default.

examples/gaussian_processes/GP-Heteroskedastic.myst.md

Lines changed: 90 additions & 56 deletions
@@ -5,9 +5,9 @@ jupytext:
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: Python [conda env:pymc3]
+  display_name: default
   language: python
-  name: conda-env-pymc3-py
+  name: python3
 ---

 # Heteroskedastic Gaussian Processes
@@ -25,14 +25,18 @@ This notebook will work through several approaches to heteroskedastic modeling w
 ## Data

 ```{code-cell} ipython3
+import warnings
+
 import arviz as az
 import matplotlib.pyplot as plt
 import numpy as np
-import pymc3 as pm
-import theano.tensor as tt
+import pymc as pm
+import pytensor.tensor as pt

 from scipy.spatial.distance import pdist

+warnings.filterwarnings("ignore", category=UserWarning)
+
 %config InlineBackend.figure_format ='retina'
 %load_ext watermark
 ```
@@ -56,9 +60,9 @@ X = np.linspace(0.1, 1, 20)[:, None]
 X = np.vstack([X, X + 2])
 X_ = X.flatten()
 y = signal(X_)
-σ_fun = noise(y)
+sigma_fun = noise(y)

-y_err = rng.lognormal(np.log(σ_fun), 0.1)
+y_err = rng.lognormal(np.log(sigma_fun), 0.1)
 y_obs = rng.normal(y, y_err, size=(5, len(y)))
 y_obs_ = y_obs.T.flatten()
 X_obs = np.tile(X.T, (5, 1)).T.reshape(-1, 1)
@@ -70,24 +74,24 @@ Xnew_ = Xnew.flatten()
 ynew = signal(Xnew)

 plt.plot(X, y, "C0o")
-plt.errorbar(X_, y, y_err, color="C0")
+plt.errorbar(X_, y, y_err, color="C0");
 ```

 ## Helper and plotting functions

 ```{code-cell} ipython3
-def get_ℓ_prior(points):
+def get_ell_prior(points):
     """Calculates mean and sd for InverseGamma prior on lengthscale"""
     distances = pdist(points[:, None])
     distinct = distances != 0
-    ℓ_l = distances[distinct].min() if sum(distinct) > 0 else 0.1
-    ℓ_u = distances[distinct].max() if sum(distinct) > 0 else 1
-    ℓ_σ = max(0.1, (ℓ_u - ℓ_l) / 6)
-    ℓ_μ = ℓ_l + 3 * ℓ_σ
-    return ℓ_μ, ℓ_σ
+    ell_l = distances[distinct].min() if sum(distinct) > 0 else 0.1
+    ell_u = distances[distinct].max() if sum(distinct) > 0 else 1
+    ell_sigma = max(0.1, (ell_u - ell_l) / 6)
+    ell_mu = ell_l + 3 * ell_sigma
+    return ell_mu, ell_sigma


-ℓ_μ, ℓ_σ = [stat for stat in get_ℓ_prior(X_)]
+ell_mu, ell_sigma = [stat for stat in get_ell_prior(X_)]
 ```

 ```{code-cell} ipython3
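
The `get_ell_prior` helper above encodes a simple heuristic for the lengthscale prior. As a quick worked illustration (mine, not part of the commit), the values it produces for the training inputs defined earlier come out roughly as follows:

```python
# Rough worked example of the lengthscale-prior heuristic (illustrative, not from the commit).
import numpy as np
from scipy.spatial.distance import pdist

X_demo = np.linspace(0.1, 1, 20)[:, None]
X_demo = np.vstack([X_demo, X_demo + 2]).flatten()

d = pdist(X_demo[:, None])
d = d[d != 0]
ell_l, ell_u = d.min(), d.max()            # smallest / largest pairwise distance: ~0.047 and 2.9
ell_sigma = max(0.1, (ell_u - ell_l) / 6)  # ~0.48
ell_mu = ell_l + 3 * ell_sigma             # ~1.47
print(ell_mu, ell_sigma)
```

So the prior mean sits three prior standard deviations above the smallest resolvable distance while keeping most of its mass below the span of the data.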
@@ -179,35 +183,40 @@ def plot_total(ax, mean_samples, var_samples=None, bootstrap=True, n_boots=100):

 +++

-First let's fit a standard homoskedastic GP using PyMC3's `Marginal Likelihood` implementation. Here and throughout this notebook we'll use an informative prior for length scale as suggested by [Michael Betancourt](https://betanalpha.github.io/assets/case_studies/gp_part3/part3.html#4_adding_an_informative_prior_for_the_length_scale). We could use `pm.find_MAP()` and `.predict`for even faster inference and prediction, with similar results, but for direct comparison to the other models we'll use NUTS and `.conditional` instead, which run fast enough.
+First let's fit a standard homoskedastic GP using PyMC's `Marginal Likelihood` implementation. Here and throughout this notebook we'll use an informative prior for the length scale as suggested by [Michael Betancourt](https://betanalpha.github.io/assets/case_studies/gp_part3/part3.html#4_adding_an_informative_prior_for_the_length_scale). We could use `pm.find_MAP()` and `.predict` for even faster inference and prediction, with similar results, but for direct comparison to the other models we'll use NUTS and `.conditional` instead, which run fast enough.

 ```{code-cell} ipython3
 with pm.Model() as model_hm:
-    ℓ = pm.InverseGamma("ℓ", mu=ℓ_μ, sigma=ℓ_σ)
-    η = pm.Gamma("η", alpha=2, beta=1)
-    cov = η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ℓ)
+    ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma)
+    eta = pm.Gamma("eta", alpha=2, beta=1)
+    cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell)

     gp_hm = pm.gp.Marginal(cov_func=cov)

-    σ = pm.Exponential("σ", lam=1)
+    sigma = pm.Exponential("sigma", lam=1)

-    ml_hm = gp_hm.marginal_likelihood("ml_hm", X=X_obs, y=y_obs_, noise=σ)
+    ml_hm = gp_hm.marginal_likelihood("ml_hm", X=X_obs, y=y_obs_, noise=sigma)

-    trace_hm = pm.sample(return_inferencedata=True, random_seed=SEED)
+    trace_hm = pm.sample(random_seed=SEED)

 with model_hm:
     mu_pred_hm = gp_hm.conditional("mu_pred_hm", Xnew=Xnew)
     noisy_pred_hm = gp_hm.conditional("noisy_pred_hm", Xnew=Xnew, pred_noise=True)
-    samples_hm = pm.sample_posterior_predictive(trace_hm, var_names=["mu_pred_hm", "noisy_pred_hm"])
+    pm.sample_posterior_predictive(
+        trace_hm,
+        var_names=["mu_pred_hm", "noisy_pred_hm"],
+        extend_inferencedata=True,
+        predictions=True,
+    )
 ```

 ```{code-cell} ipython3
 _, axs = plt.subplots(1, 3, figsize=(18, 4))
-mu_samples = samples_hm["mu_pred_hm"]
-noisy_samples = samples_hm["noisy_pred_hm"]
-plot_mean(axs[0], mu_samples)
-plot_var(axs[1], noisy_samples.var(axis=0))
-plot_total(axs[2], noisy_samples)
+mu_samples = az.extract(trace_hm.predictions["mu_pred_hm"])["mu_pred_hm"]
+noisy_samples = az.extract(trace_hm.predictions["noisy_pred_hm"])["noisy_pred_hm"]
+plot_mean(axs[0], mu_samples.T)
+plot_var(axs[1], noisy_samples.var(dim=["sample"]))
+plot_total(axs[2], noisy_samples.T)
 ```

 Here we've plotted our understanding of the mean behavior with the corresponding epistemic uncertainty on the left, our understanding of the variance or aleatoric uncertainty in the middle, and the integration of all sources of uncertainty on the right. This model captures the mean behavior well, but we can see that it overestimates the noise in the lower regime while underestimating the noise in the upper regime, as expected.
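
The prose above mentions a faster `pm.find_MAP()` / `.predict` route that the notebook doesn't show. A minimal sketch of that alternative (assuming the `model_hm` and `gp_hm` objects from the cell above; not part of this commit) might look like:

```python
# Hedged sketch: MAP point estimate plus analytic GP prediction instead of NUTS + .conditional.
with model_hm:
    mp = pm.find_MAP()
    # Predictive mean and (diagonal) variance at the new inputs, including observation noise.
    mu, var = gp_hm.predict(Xnew, point=mp, diag=True, pred_noise=True)
```

This skips posterior sampling entirely, so the resulting uncertainty reflects only the GP conditional at the MAP hyperparameters.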
@@ -222,28 +231,32 @@ The simplest approach to modeling a heteroskedastic system is to fit a GP on the

 ```{code-cell} ipython3
 with pm.Model() as model_wt:
-    ℓ = pm.InverseGamma("ℓ", mu=ℓ_μ, sigma=ℓ_σ)
-    η = pm.Gamma("η", alpha=2, beta=1)
-    cov = η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ℓ)
+    ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma)
+    eta = pm.Gamma("eta", alpha=2, beta=1)
+    cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell)

     gp_wt = pm.gp.Marginal(cov_func=cov)

+    # Using the observed noise now
     ml_wt = gp_wt.marginal_likelihood("ml_wt", X=X, y=y, noise=y_err)

     trace_wt = pm.sample(return_inferencedata=True, random_seed=SEED)

 with model_wt:
     mu_pred_wt = gp_wt.conditional("mu_pred_wt", Xnew=Xnew)
-    samples_wt = pm.sample_posterior_predictive(trace_wt, var_names=["mu_pred_wt"])
+    pm.sample_posterior_predictive(
+        trace_wt, var_names=["mu_pred_wt"], extend_inferencedata=True, predictions=True
+    )
 ```

 ```{code-cell} ipython3
 _, axs = plt.subplots(1, 3, figsize=(18, 4))
-mu_samples = samples_wt["mu_pred_wt"]
-plot_mean(axs[0], mu_samples)
+mu_samples = az.extract(trace_wt.predictions["mu_pred_wt"])["mu_pred_wt"]
+
+plot_mean(axs[0], mu_samples.T)
 axs[0].errorbar(X_, y, y_err, ls="none", color="C1", label="STDEV")
-plot_var(axs[1], mu_samples.var(axis=0))
-plot_total(axs[2], mu_samples)
+plot_var(axs[1], mu_samples.var(dim=["sample"]))
+plot_total(axs[2], mu_samples.T)
 ```

 This approach captured slightly more nuance in the overall uncertainty than the homoskedastic GP, but still underestimated the variance within both the observed regimes. Note that the variance displayed by this model is purely epistemic: our understanding of the mean behavior is weighted by the uncertainty in our observations, but we didn't include a component to account for aleatoric noise.
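
A note on the `az.extract` pattern introduced in the cells above (my own illustration, not from the commit): `sample_posterior_predictive` with `predictions=True` now stores the draws in the `predictions` group of the `InferenceData`, and `az.extract` stacks the `(chain, draw)` dimensions into a single `sample` dimension, which is why the plotting helpers receive transposed arrays. A small sketch of recovering a plain `(n_samples, n_points)` NumPy array:

```python
# Illustrative only: unpack the predictions group into the array shape that the old
# dict-style sample_posterior_predictive output used to return.
mu = az.extract(trace_wt.predictions, var_names="mu_pred_wt")  # DataArray with a "sample" dim
mu_np = mu.transpose("sample", ...).to_numpy()                 # shape: (n_samples, len(Xnew))
```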
@@ -254,33 +267,48 @@ This approach captured slightly more nuance in the overall uncertainty than the

 +++

-Now let's model the mean and the log of the variance as separate GPs through PyMC3's `Latent` implementation, feeding both into a `Normal` likelihood. Note that we add a small amount of diagonal noise to the individual covariances in order to stabilize them for inversion.
+Now let's model the mean and the log of the variance as separate GPs through PyMC's `Latent` implementation, feeding both into a `Normal` likelihood. Note that we add a small amount of diagonal noise to the individual covariances in order to stabilize them for inversion.
+
+The `Latent` parameterization takes significantly longer to sample than the `Marginal` approach, so we are going to accelerate the sampling with the nutpie NUTS sampler.

 ```{code-cell} ipython3
 with pm.Model() as model_ht:
-    ℓ = pm.InverseGamma("ℓ", mu=ℓ_μ, sigma=ℓ_σ)
-    η = pm.Gamma("η", alpha=2, beta=1)
-    cov = η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ℓ) + pm.gp.cov.WhiteNoise(sigma=1e-6)
+    ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma)
+    eta = pm.Gamma("eta", alpha=2, beta=1)
+    cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell) + pm.gp.cov.WhiteNoise(sigma=1e-6)

     gp_ht = pm.gp.Latent(cov_func=cov)
-    μ_f = gp_ht.prior("μ_f", X=X_obs)
+    mu_f = gp_ht.prior("mu_f", X=X_obs)

-    σ_ℓ = pm.InverseGamma("σ_ℓ", mu=ℓ_μ, sigma=ℓ_σ)
-    σ_η = pm.Gamma("σ_η", alpha=2, beta=1)
-    σ_cov = σ_η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=σ_ℓ) + pm.gp.cov.WhiteNoise(sigma=1e-6)
+    sigma_ell = pm.InverseGamma("sigma_ell", mu=ell_mu, sigma=ell_sigma)
+    sigma_eta = pm.Gamma("sigma_eta", alpha=2, beta=1)
+    sigma_cov = sigma_eta**2 * pm.gp.cov.ExpQuad(
+        input_dim=1, ls=sigma_ell
+    ) + pm.gp.cov.WhiteNoise(sigma=1e-6)

-    σ_gp = pm.gp.Latent(cov_func=σ_cov)
-    lg_σ_f = σ_gp.prior("lg_σ_f", X=X_obs)
-    σ_f = pm.Deterministic("σ_f", pm.math.exp(lg_σ_f))
+    sigma_gp = pm.gp.Latent(cov_func=sigma_cov)
+    lg_sigma_f = sigma_gp.prior("lg_sigma_f", X=X_obs)
+    sigma_f = pm.Deterministic("sigma_f", pm.math.exp(lg_sigma_f))

-    lik_ht = pm.Normal("lik_ht", mu=μ_f, sd=σ_f, observed=y_obs_)
+    lik_ht = pm.Normal("lik_ht", mu=mu_f, sigma=sigma_f, observed=y_obs_)

-    trace_ht = pm.sample(target_accept=0.95, chains=2, return_inferencedata=True, random_seed=SEED)
+    trace_ht = pm.sample(
+        target_accept=0.95,
+        chains=2,
+        nuts_sampler="nutpie",
+        return_inferencedata=True,
+        random_seed=SEED,
+    )

 with model_ht:
-    μ_pred_ht = gp_ht.conditional("μ_pred_ht", Xnew=Xnew)
-    lg_σ_pred_ht = σ_gp.conditional("lg_σ_pred_ht", Xnew=Xnew)
-    samples_ht = pm.sample_posterior_predictive(trace_ht, var_names=["μ_pred_ht", "lg_σ_pred_ht"])
+    mu_pred_ht = gp_ht.conditional("mu_pred_ht", Xnew=Xnew)
+    lg_sigma_pred_ht = sigma_gp.conditional("lg_sigma_pred_ht", Xnew=Xnew)
+    pm.sample_posterior_predictive(
+        trace_ht,
+        var_names=["mu_pred_ht", "lg_sigma_pred_ht"],
+        extend_inferencedata=True,
+        predictions=True,
+    )
 ```

 ```{code-cell} ipython3
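
For reference, a compact statement of the model in the cell above (my own summary, not text from the commit): independent GP priors are placed on the mean and on the log of the noise standard deviation, and both feed into the Normal likelihood,

$$
f \sim \mathcal{GP}\!\left(0,\ \eta^2\, k_{\mathrm{ExpQuad}}(x, x'; \ell)\right), \qquad
g \sim \mathcal{GP}\!\left(0,\ \eta_\sigma^2\, k_{\mathrm{ExpQuad}}(x, x'; \ell_\sigma)\right),
$$

$$
y_i \mid f, g \sim \mathcal{N}\!\big(f(x_i),\ \exp\{g(x_i)\}^2\big),
$$

with a small diagonal jitter ($10^{-6}$) added to each covariance for numerical stability.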
@@ -292,6 +320,16 @@ plot_var(axs[1], σ_samples**2)
 plot_total(axs[2], μ_samples, σ_samples**2)
 ```

+```{code-cell} ipython3
+_, axs = plt.subplots(1, 3, figsize=(18, 4))
+mu_samples = az.extract(trace_ht.predictions["mu_pred_ht"])["mu_pred_ht"]
+sigma_samples = np.exp(az.extract(trace_ht.predictions["lg_sigma_pred_ht"])["lg_sigma_pred_ht"])
+
+plot_mean(axs[0], mu_samples.T)
+plot_var(axs[1], sigma_samples.T**2)
+plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
+```
+
 That looks much better! We've accurately captured the mean behavior of our system along with an understanding of the underlying trend in the variance, with appropriate uncertainty. Crucially, the aggregate behavior of the model integrates both epistemic *and* aleatoric uncertainty, and the ~5% of our observations that fall outside the 2σ band are more or less evenly distributed across the domain. However, that took *over two hours* to sample only 4k NUTS iterations. Due to the expense of the requisite matrix inversions, GPs are notoriously inefficient for large data sets. Let's reformulate this model using a sparse approximation.

 +++
@@ -300,7 +338,7 @@ That looks much better! We've accurately captured the mean behavior of our syste

 +++

-Sparse approximations to GPs use a small set of *inducing points* to condition the model, vastly improve speed of inference and somewhat improving memory consumption. PyMC3 doesn't have an implementation for sparse latent GPs ([yet](https://github.com/pymc-devs/pymc3/pull/2951)), but we can throw together our own real quick using Bill Engel's [DTC latent GP example](https://gist.github.com/bwengals/a0357d75d2083657a2eac85947381a44). These inducing points can be specified in a variety of ways, such as via the popular k-means initialization or even optimized as part of the model, but since our observations are evenly distributed we can make do with simply a subset of our unique input values.
+Sparse approximations to GPs use a small set of *inducing points* to condition the model, vastly improving the speed of inference and somewhat reducing memory consumption. PyMC doesn't have an implementation for sparse latent GPs yet, but we can throw together our own quickly using Bill Engels' [DTC latent GP example](https://gist.github.com/bwengals/a0357d75d2083657a2eac85947381a44). These inducing points can be specified in a variety of ways, such as via the popular k-means initialization or even optimized as part of the model, but since our observations are evenly distributed we can make do with simply a subset of our unique input values.

 ```{code-cell} ipython3
 class SparseLatent:
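
The paragraph above notes that the inducing points can be chosen in several ways. A short sketch of two common options (variable names here are illustrative and not part of the commit):

```python
# Illustrative only: two ways to pick inducing points Xu for the sparse approximation.
import numpy as np

# 1) A subset of the unique training inputs, as the notebook does (e.g. every other point):
Xu = np.unique(X_obs)[::2][:, None]

# 2) k-means centroids (would require scikit-learn, which this notebook does not import):
# from sklearn.cluster import KMeans
# Xu = KMeans(n_clusters=20, random_state=0).fit(X_obs).cluster_centers_
```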
@@ -531,7 +569,3 @@ display(az.summary(trace_htsc).sort_values("ess_bulk").iloc[:5])
 ```{code-cell} ipython3
 %watermark -n -u -v -iv -w -p xarray
 ```
-
-```{code-cell} ipython3
-
-```
