
Commit 4d77dab

Updates to several models
1 parent eacb2bf commit 4d77dab

2 files changed: +411 additions, -476 deletions

examples/gaussian_processes/GP-Heteroskedastic.ipynb

Lines changed: 316 additions & 417 deletions
Large diffs are not rendered by default.

examples/gaussian_processes/GP-Heteroskedastic.myst.md

Lines changed: 95 additions & 59 deletions
@@ -156,9 +156,9 @@ def plot_total(ax, mean_samples, var_samples=None, bootstrap=True, n_boots=100):
     # Estimate the aggregate behavior using samples from each normal distribution in the posterior
     samples = (
         rng.normal(
-            mean_samples.T[:, :, None],
-            np.sqrt(var_samples).T[:, :, None],
-            (*mean_samples.T.shape, n_boots),
+            mean_samples.values.T[..., None],
+            np.sqrt(var_samples.values).T[..., None],
+            (*mean_samples.values.T.shape, n_boots),
         )
         .reshape(len(Xnew_), -1)
         .T
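
Editor's note: the hunk above switches to the `.values` attribute because `mean_samples` and `var_samples` now arrive as xarray `DataArray`s (extracted from the trace) rather than plain NumPy arrays. A minimal sketch of the same bootstrap-style aggregation with hypothetical shapes shows what the broadcast is doing:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical shapes: 500 posterior draws of the mean/variance at 100 prediction points.
n_draws, n_points, n_boots = 500, 100, 20
mean_samples = rng.normal(size=(n_draws, n_points))
var_samples = rng.uniform(0.1, 1.0, size=(n_draws, n_points))

# For each (point, draw) pair, draw n_boots values from Normal(mean, sd);
# the trailing None adds the axis that n_boots broadcasts into.
samples = (
    rng.normal(
        mean_samples.T[..., None],          # (n_points, n_draws, 1)
        np.sqrt(var_samples).T[..., None],  # (n_points, n_draws, 1)
        (n_points, n_draws, n_boots),
    )
    .reshape(n_points, -1)  # pool posterior draws and bootstrap replicates per point
    .T
)
print(samples.shape)  # (n_draws * n_boots, n_points)
```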
@@ -269,7 +269,7 @@ This approach captured slightly more nuance in the overall uncertainty than the
 
 Now let's model the mean and the log of the variance as separate GPs through PyMC's `Latent` implementation, feeding both into a `Normal` likelihood. Note that we add a small amount of diagonal noise to the individual covariances in order to stabilize them for inversion.
 
-The `Latent` parameterization takes signifiantly longer to sample than the `Marginal` approach, so we are going to accerelate the sampling with the Numpyro NUTS sampler.
+The `Latent` parameterization takes significantly longer to sample than the `Marginal` model, so we are going to accelerate the sampling with the Numpyro NUTS sampler.
 
 ```{code-cell} ipython3
 with pm.Model() as model_ht:
@@ -295,7 +295,7 @@ with pm.Model() as model_ht:
     trace_ht = pm.sample(
         target_accept=0.95,
         chains=2,
-        nuts_sampler="nutpie",
+        nuts_sampler="numpyro",
         return_inferencedata=True,
         random_seed=SEED,
     )
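
Editor's note: for readers skimming the diff, the surrounding `model_ht` block (only partially shown in these hunks) pairs one `Latent` GP for the mean with a second `Latent` GP for the log of the noise standard deviation, both feeding a `Normal` likelihood. A minimal sketch of that structure with hypothetical data and priors, assuming `numpyro` and `jax` are installed for the `nuts_sampler="numpyro"` option (the notebook gives the noise GP its own lengthscale and amplitude priors; this sketch shares one covariance for brevity):

```python
import numpy as np
import pymc as pm

# Hypothetical 1D data; the notebook's real X_obs / y_obs_ are defined earlier in the file.
rng = np.random.default_rng(1)
X_obs = np.linspace(0, 10, 50)[:, None]
y_obs_ = np.sin(X_obs).ravel() + rng.normal(0, 0.1 + 0.05 * X_obs.ravel())

with pm.Model() as sketch_ht:
    ell = pm.InverseGamma("ell", mu=2.0, sigma=1.0)
    eta = pm.Gamma("eta", alpha=2, beta=1)
    cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell)

    # One latent GP for the mean ...
    gp_mu = pm.gp.Latent(cov_func=cov)
    mu_f = gp_mu.prior("mu_f", X=X_obs)

    # ... and a second latent GP for the log of the noise standard deviation.
    gp_lg_sigma = pm.gp.Latent(cov_func=cov)
    lg_sigma_f = gp_lg_sigma.prior("lg_sigma_f", X=X_obs)
    sigma_f = pm.Deterministic("sigma_f", pm.math.exp(lg_sigma_f))

    pm.Normal("lik", mu=mu_f, sigma=sigma_f, observed=y_obs_)

    # nuts_sampler="numpyro" requires the numpyro and jax packages.
    idata = pm.sample(chains=2, target_accept=0.95, nuts_sampler="numpyro")
```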
@@ -311,15 +311,6 @@ with model_ht:
     )
 ```
 
-```{code-cell} ipython3
-_, axs = plt.subplots(1, 3, figsize=(18, 4))
-μ_samples = samples_ht["μ_pred_ht"]
-σ_samples = np.exp(samples_ht["lg_σ_pred_ht"])
-plot_mean(axs[0], μ_samples)
-plot_var(axs[1], σ_samples**2)
-plot_total(axs[2], μ_samples, σ_samples**2)
-```
-
 ```{code-cell} ipython3
 _, axs = plt.subplots(1, 3, figsize=(18, 4))
 mu_samples = az.extract(trace_ht.predictions["mu_pred_ht"])["mu_pred_ht"]
@@ -330,7 +321,7 @@ plot_var(axs[1], sigma_samples.T**2)
 plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
 ```
 
-That looks much better! We've accurately captured the mean behavior of our system along with an understanding of the underlying trend in the variance, with appropriate uncertainty. Crucially, the aggregate behavior of the model integrates both epistemic *and* aleatoric uncertainty, and the ~5% of our observations fall outside the 2σ band are more or less evenly distributed across the domain. However, that took *over two hours* to sample only 4k NUTS iterations. Due to the expense of the requisite matrix inversions, GPs are notoriously inefficient for large data sets. Let's reformulate this model using a sparse approximation.
+That looks much better! We've accurately captured the mean behavior of our system, as well as the underlying trend in the variance (with appropriate uncertainty). Crucially, the aggregate behavior of the model integrates both epistemic *and* aleatoric uncertainty, and the ~5% of our observations that fall outside the 2σ band are more or less evenly distributed across the domain. However, even with the Numpyro sampler, this took nearly an hour on a Ryzen 7040 laptop to sample only 4k NUTS iterations. Due to the expense of the requisite matrix inversions, GPs are notoriously inefficient for large data sets. Let's reformulate this model using a sparse approximation.
 
 +++
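
Editor's note: the paragraph above relies on the total predictive uncertainty combining epistemic uncertainty (spread of the posterior mean function) with aleatoric uncertainty (the learned noise variance). A small sketch of that decomposition, using hypothetical posterior sample arrays shaped `(n_samples, n_points)`, makes the bookkeeping explicit:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior samples: mean-function draws and noise-variance draws.
n_samples, n_points = 1000, 50
mu_samples = rng.normal(0.0, 0.3, size=(n_samples, n_points))      # epistemic spread
var_samples = rng.uniform(0.05, 0.15, size=(n_samples, n_points))  # aleatoric variance

# Law of total variance: Var[y] = E[Var[y | f]] + Var[E[y | f]]
aleatoric = var_samples.mean(axis=0)  # average learned noise variance per point
epistemic = mu_samples.var(axis=0)    # variance of the posterior mean draws per point
total_sd = np.sqrt(aleatoric + epistemic)

# A 2-sigma band built from total_sd should contain roughly 95% of the observations.
print(total_sd[:5])
```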

@@ -347,25 +338,25 @@ class SparseLatent:
 
     def prior(self, name, X, Xu):
         Kuu = self.cov(Xu)
-        self.L = pm.gp.util.cholesky(pm.gp.util.stabilize(Kuu))
+        self.L = pt.linalg.cholesky(pm.gp.util.stabilize(Kuu))
 
-        self.v = pm.Normal(f"u_rotated_{name}", mu=0.0, sd=1.0, shape=len(Xu))
-        self.u = pm.Deterministic(f"u_{name}", tt.dot(self.L, self.v))
+        self.v = pm.Normal(f"u_rotated_{name}", mu=0.0, sigma=1.0, shape=len(Xu))
+        self.u = pm.Deterministic(f"u_{name}", pt.dot(self.L, self.v))
 
         Kfu = self.cov(X, Xu)
-        self.Kuiu = tt.slinalg.solve_upper_triangular(
-            self.L.T, tt.slinalg.solve_lower_triangular(self.L, self.u)
+        self.Kuiu = pt.slinalg.solve_triangular(
+            self.L.T, pt.slinalg.solve_triangular(self.L, self.u, lower=True), lower=False
         )
-        self.mu = pm.Deterministic(f"mu_{name}", tt.dot(Kfu, self.Kuiu))
+        self.mu = pm.Deterministic(f"mu_{name}", pt.dot(Kfu, self.Kuiu))
         return self.mu
 
     def conditional(self, name, Xnew, Xu):
         Ksu = self.cov(Xnew, Xu)
-        mus = tt.dot(Ksu, self.Kuiu)
-        tmp = tt.slinalg.solve_lower_triangular(self.L, Ksu.T)
-        Qss = tt.dot(tmp.T, tmp) # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
+        mus = pt.dot(Ksu, self.Kuiu)
+        tmp = pt.slinalg.solve_triangular(self.L, Ksu.T, lower=True)
+        Qss = pt.dot(tmp.T, tmp)
         Kss = self.cov(Xnew)
-        Lss = pm.gp.util.cholesky(pm.gp.util.stabilize(Kss - Qss))
+        Lss = pt.linalg.cholesky(pm.gp.util.stabilize(Kss - Qss))
         mu_pred = pm.MvNormal(name, mu=mus, chol=Lss, shape=len(Xnew))
         return mu_pred
 ```
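
Editor's note: the rewritten `SparseLatent` methods swap the removed `solve_upper_triangular`/`solve_lower_triangular` helpers for `pt.slinalg.solve_triangular(..., lower=...)`, but the linear algebra is unchanged: with `Kuu = L @ L.T`, two triangular solves give `Kuu⁻¹ u`, and `Qss = Ksu Kuu⁻¹ Kus` falls out of a single lower-triangular solve. A NumPy sketch with hypothetical small matrices verifying those identities:

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)

# Hypothetical SPD inducing-point covariance Kuu, plus vectors/matrices to match.
n_u, n_s = 5, 3
A = rng.normal(size=(n_u, n_u))
Kuu = A @ A.T + n_u * np.eye(n_u)
u = rng.normal(size=n_u)
Ksu = rng.normal(size=(n_s, n_u))

L = np.linalg.cholesky(Kuu)

# Kuu^{-1} u via two triangular solves: first L z = u, then L.T x = z.
z = solve_triangular(L, u, lower=True)
Kuiu = solve_triangular(L.T, z, lower=False)
assert np.allclose(Kuiu, np.linalg.solve(Kuu, u))

# Qss = Ksu Kuu^{-1} Kus from one lower-triangular solve: tmp = L^{-1} Kus.
tmp = solve_triangular(L, Ksu.T, lower=True)
Qss = tmp.T @ tmp
assert np.allclose(Qss, Ksu @ np.linalg.solve(Kuu, Ksu.T))
print("identities hold")
```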
@@ -375,39 +366,51 @@ class SparseLatent:
 Xu = X[1::2]
 
 with pm.Model() as model_hts:
-    ℓ = pm.InverseGamma("ℓ", mu=ℓ_μ, sigma=ℓ_σ)
-    η = pm.Gamma("η", alpha=2, beta=1)
-    cov = η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ℓ)
+    ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma)
+    eta = pm.Gamma("eta", alpha=2, beta=1)
+    cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell)
 
-    μ_gp = SparseLatent(cov)
-    μ_f = μ_gp.prior("μ", X_obs, Xu)
+    mu_gp = SparseLatent(cov)
+    mu_f = mu_gp.prior("mu", X_obs, Xu)
 
-    σ_ℓ = pm.InverseGamma("σ_ℓ", mu=ℓ_μ, sigma=ℓ_σ)
-    σ_η = pm.Gamma("σ_η", alpha=2, beta=1)
-    σ_cov = σ_η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=σ_ℓ)
+    sigma_ell = pm.InverseGamma("sigma_ell", mu=ell_mu, sigma=ell_sigma)
+    sigma_η = pm.Gamma("sigma_η", alpha=2, beta=1)
+    sigma_cov = sigma_η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=sigma_ell)
 
-    lg_σ_gp = SparseLatent(σ_cov)
-    lg_σ_f = lg_σ_gp.prior("lg_σ_f", X_obs, Xu)
-    σ_f = pm.Deterministic("σ_f", pm.math.exp(lg_σ_f))
+    lg_sigma_gp = SparseLatent(sigma_cov)
+    lg_sigma_f = lg_sigma_gp.prior("lg_sigma_f", X_obs, Xu)
+    sigma_f = pm.Deterministic("sigma_f", pm.math.exp(lg_sigma_f))
 
-    lik_hts = pm.Normal("lik_hts", mu=μ_f, sd=σ_f, observed=y_obs_)
-    trace_hts = pm.sample(target_accept=0.95, return_inferencedata=True, random_seed=SEED)
+    lik_hts = pm.Normal("lik_hts", mu=mu_f, sigma=sigma_f, observed=y_obs_)
+    trace_hts = pm.sample(
+        target_accept=0.95,
+        nuts_sampler="numpyro",
+        chains=2,
+        return_inferencedata=True,
+        random_seed=SEED,
+    )
 
 with model_hts:
-    μ_pred = μ_gp.conditional("μ_pred", Xnew, Xu)
-    lg_σ_pred = lg_σ_gp.conditional("lg_σ_pred", Xnew, Xu)
-    samples_hts = pm.sample_posterior_predictive(trace_hts, var_names=["μ_pred", "lg_σ_pred"])
+    mu_pred = mu_gp.conditional("mu_pred", Xnew, Xu)
+    lg_sigma_pred = lg_sigma_gp.conditional("lg_sigma_pred", Xnew, Xu)
+    pm.sample_posterior_predictive(
+        trace_hts,
+        var_names=["mu_pred", "lg_sigma_pred"],
+        extend_inferencedata=True,
+        predictions=True,
+    )
 ```
 
 ```{code-cell} ipython3
 _, axs = plt.subplots(1, 3, figsize=(18, 4))
-μ_samples = samples_hts["μ_pred"]
-σ_samples = np.exp(samples_hts["lg_σ_pred"])
-plot_mean(axs[0], μ_samples)
+mu_samples = az.extract(trace_hts.predictions["mu_pred"])["mu_pred"]
+sigma_samples = np.exp(az.extract(trace_hts.predictions["lg_sigma_pred"])["lg_sigma_pred"])
+
+plot_mean(axs[0], mu_samples.T)
 plot_inducing_points(axs[0])
-plot_var(axs[1], σ_samples**2)
+plot_var(axs[1], sigma_samples.T**2)
 plot_inducing_points(axs[1])
-plot_total(axs[2], μ_samples, σ_samples**2)
+plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
 plot_inducing_points(axs[2])
 ```
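
Editor's note: several hunks in this commit replace the old `samples_* = pm.sample_posterior_predictive(...)` dictionaries with `extend_inferencedata=True, predictions=True`, which stores the draws in a `predictions` group on the trace, and then pull them out with `az.extract`. A hedged sketch of that access pattern with a hypothetical toy trace; the `.T` calls in the plotting cells assume `az.extract` stacks chain and draw into a trailing `sample` dimension (the hunks pass the `DataArray` to `az.extract` directly, which should be equivalent to the explicit `group` argument used here):

```python
import arviz as az
import numpy as np

# Hypothetical stand-in for the notebook's trace: 2 chains x 100 draws of a
# 50-point predictive function, stored in the "predictions" group.
rng = np.random.default_rng(0)
idata = az.from_dict(predictions={"mu_pred": rng.normal(size=(2, 100, 50))})

# az.extract stacks (chain, draw) into a single "sample" dimension ...
mu = az.extract(idata, group="predictions")["mu_pred"]
print(mu.shape)  # (50, 200): point dimension first, sample dimension last

# ... so plotting helpers that expect samples on the first axis receive mu.T.
print(mu.values.T.shape)  # (200, 50)
```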

@@ -429,31 +432,60 @@ def add_coreg_idx(x):
 Xu_c, X_obs_c, Xnew_c = [add_coreg_idx(x) for x in [Xu, X_obs, Xnew]]
 
 with pm.Model() as model_htsc:
-    ℓ = pm.InverseGamma("ℓ", mu=ℓ_μ, sigma=ℓ_σ)
-    η = pm.Gamma("η", alpha=2, beta=1)
-    EQcov = η**2 * pm.gp.cov.ExpQuad(input_dim=1, active_dims=[0], ls=ℓ)
+    ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma)
+    eta = pm.Gamma("eta", alpha=2, beta=1)
+    cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell)
 
     D_out = 2 # two output dimensions, mean and variance
     rank = 2 # two basis GPs
-    W = pm.Normal("W", mu=0, sd=3, shape=(D_out, rank), testval=np.full([D_out, rank], 0.1))
+    W = pm.Normal("W", mu=0, sigma=3, shape=(D_out, rank), initval=np.full([D_out, rank], 0.1))
     kappa = pm.Gamma("kappa", alpha=1.5, beta=1, shape=D_out)
     coreg = pm.gp.cov.Coregion(input_dim=1, active_dims=[0], kappa=kappa, W=W)
 
-    cov = pm.gp.cov.Kron([EQcov, coreg])
+    cov = pm.gp.cov.Kron([cov, coreg])
 
     gp_LMC = SparseLatent(cov)
     LMC_f = gp_LMC.prior("LMC", X_obs_c, Xu_c)
 
-    μ_f = LMC_f[: len(y_obs_)]
-    lg_σ_f = LMC_f[len(y_obs_) :]
-    σ_f = pm.Deterministic("σ_f", pm.math.exp(lg_σ_f))
+    mu_f = LMC_f[: len(y_obs_)]
+    lg_sigma_f = LMC_f[len(y_obs_) :]
+    sigma_f = pm.Deterministic("sigma_f", pm.math.exp(lg_sigma_f))
 
-    lik_htsc = pm.Normal("lik_htsc", mu=μ_f, sd=σ_f, observed=y_obs_)
-    trace_htsc = pm.sample(target_accept=0.95, return_inferencedata=True, random_seed=SEED)
+    lik_htsc = pm.Normal("lik_htsc", mu=mu_f, sigma=sigma_f, observed=y_obs_)
+    trace_htsc = pm.sample(
+        target_accept=0.95,
+        chains=2,
+        nuts_sampler="numpyro",
+        return_inferencedata=True,
+        random_seed=SEED,
+    )
 
 with model_htsc:
     c_mu_pred = gp_LMC.conditional("c_mu_pred", Xnew_c, Xu_c)
-    samples_htsc = pm.sample_posterior_predictive(trace_htsc, var_names=["c_mu_pred"])
+    pm.sample_posterior_predictive(
+        trace_htsc, var_names=["c_mu_pred"], extend_inferencedata=True, predictions=True
+    )
+```
+
+```{code-cell} ipython3
+sigma_samples.shape
+```
+
+```{code-cell} ipython3
+# μ_samples = samples_htsc["c_mu_pred"][:, : len(Xnew)]
+# σ_samples = np.exp(samples_htsc["c_mu_pred"][:, len(Xnew) :])
+mu_samples = az.extract(trace_htsc.predictions["c_mu_pred"])["c_mu_pred"][: len(Xnew)]
+sigma_samples = np.exp(az.extract(trace_htsc.predictions["c_mu_pred"])["c_mu_pred"])[len(Xnew) :]
+
+_, axs = plt.subplots(1, 3, figsize=(18, 4))
+plot_mean(axs[0], mu_samples.T)
+plot_inducing_points(axs[0])
+plot_var(axs[1], sigma_samples.T**2)
+axs[1].set_ylim(-0.01, 0.2)
+axs[1].legend(loc="upper left")
+plot_inducing_points(axs[1])
+plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
+plot_inducing_points(axs[2])
 ```
 
 ```{code-cell} ipython3
@@ -478,13 +510,17 @@ with model_htsc:
     B_samples = pm.sample_posterior_predictive(trace_htsc, var_names=["W", "kappa"])
 ```
 
+```{code-cell} ipython3
+kappa.shape
+```
+
 ```{code-cell} ipython3
 # Keep in mind that the first dimension in all arrays is the sampling dimension
-W = B_samples["W"]
+W = az.extract(B_samples.posterior_predictive["W"])["W"].values.T
 W_T = np.swapaxes(W, 1, 2)
 WW_T = np.matmul(W, W_T)
 
-kappa = B_samples["kappa"]
+kappa = az.extract(B_samples.posterior_predictive["kappa"])["kappa"].values.T
 I = np.tile(np.identity(2), [kappa.shape[0], 1, 1])
 # einsum is just a concise way of doing multiplication and summation over arbitrary axes
 diag_kappa = np.einsum("ij,ijk->ijk", kappa, I)
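
Editor's note: the cell this hunk modifies reconstructs the coregionalization matrix B = W Wᵀ + diag(κ) from the posterior draws of `W` and `kappa`; the `einsum` call builds a per-draw diagonal matrix. A NumPy sketch with hypothetical draw counts shows the same computation end to end:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior draws: n_samples draws of W (D_out x rank) and kappa (D_out,).
n_samples, D_out, rank = 400, 2, 2
W = rng.normal(size=(n_samples, D_out, rank))
kappa = rng.gamma(1.5, 1.0, size=(n_samples, D_out))

# Per-draw W W^T via a batched matmul.
W_T = np.swapaxes(W, 1, 2)
WW_T = np.matmul(W, W_T)  # (n_samples, D_out, D_out)

# Per-draw diag(kappa): scale each row of an identity matrix by kappa.
I = np.tile(np.identity(D_out), [n_samples, 1, 1])
diag_kappa = np.einsum("ij,ijk->ijk", kappa, I)

# Coregionalization matrix B = W W^T + diag(kappa), one per posterior draw.
B = WW_T + diag_kappa
print(B.shape)          # (400, 2, 2)
print(B.mean(axis=0))   # posterior-mean coregionalization matrix
```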
