Commit 5a3f2f3

Trying to get multi-output working

1 parent d28a610

File tree: 2 files changed, +653 -130 lines


examples/gaussian_processes/GP-Heteroskedastic.ipynb

568 additions, 112 deletions (large diff not rendered by default)

examples/gaussian_processes/GP-Heteroskedastic.myst.md

85 additions, 18 deletions
@@ -321,7 +321,7 @@ plot_var(axs[1], sigma_samples.T**2)
 plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
 ```
 
-That looks much better! We've accurately captured the mean behavior of our system, as well as the underlying trend in the variance (with appropriate uncertainty). Crucially, the aggregate behavior of the model integrates both epistemic *and* aleatoric uncertainty, and the ~5% of our observations fall outside the 2σ band are more or less evenly distributed across the domain. However, even with the Numpyro sampler, this took nearly an hour on a Ryen 7040 laptop to sample only 4k NUTS iterations. Due to the expense of the requisite matrix inversions, GPs are notoriously inefficient for large data sets. Let's reformulate this model using a sparse approximation.
+That looks much better! We've accurately captured the mean behavior of our system, as well as the underlying trend in the variance (with appropriate uncertainty). Crucially, the aggregate behavior of the model integrates both epistemic *and* aleatoric uncertainty, and the ~5% of our observations that fall outside the 2σ band are more or less evenly distributed across the domain. However, even with the NumPyro sampler, this took nearly an hour on a Ryzen 7040 laptop to sample only 4k NUTS iterations. Due to the expense of the requisite matrix inversions, GPs are notoriously inefficient for large data sets. Let's reformulate this model using a sparse approximation.
 
 +++
 
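For intuition before the `SparseLatent` changes in the next hunk: the class approximates the full GP by conditioning on m inducing points, computing the latent mean as Kfu Kuu⁻¹ u via Cholesky solves. A minimal numpy sketch of that projected mean, assuming a hypothetical `kernel_fn(A, B)` that returns a covariance matrix (not a notebook variable):

```python
import numpy as np

def projected_mean(kernel_fn, X, Xu, u, jitter=1e-6):
    """Sparse-GP mean f ~= Kfu @ Kuu^{-1} @ u; O(n m^2) instead of the O(n^3) exact GP."""
    Kuu = kernel_fn(Xu, Xu) + jitter * np.eye(len(Xu))  # (m, m) inducing-point covariance
    Kfu = kernel_fn(X, Xu)  # (n, m) cross-covariance
    L = np.linalg.cholesky(Kuu)
    Kuiu = np.linalg.solve(L.T, np.linalg.solve(L, u))  # Kuu^{-1} u via two solves
    return Kfu @ Kuiu
```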
@@ -341,20 +341,20 @@ class SparseLatent:
         self.L = pt.linalg.cholesky(pm.gp.util.stabilize(Kuu))
 
         self.v = pm.Normal(f"u_rotated_{name}", mu=0.0, sigma=1.0, shape=len(Xu))
-        self.u = pm.Deterministic(f"u_{name}", pt.dot(self.L, self.v))
+        self.u = pm.Deterministic(f"u_{name}", self.L @ self.v)
 
         Kfu = self.cov(X, Xu)
         self.Kuiu = pt.slinalg.solve_triangular(
             self.L.T, pt.slinalg.solve_triangular(self.L, self.u, lower=True), lower=False
         )
-        self.mu = pm.Deterministic(f"mu_{name}", pt.dot(Kfu, self.Kuiu))
+        self.mu = pm.Deterministic(f"mu_{name}", Kfu @ self.Kuiu)
         return self.mu
 
     def conditional(self, name, Xnew, Xu):
         Ksu = self.cov(Xnew, Xu)
-        mus = pt.dot(Ksu, self.Kuiu)
+        mus = Ksu @ self.Kuiu
         tmp = pt.slinalg.solve_triangular(self.L, Ksu.T, lower=True)
-        Qss = pt.dot(tmp.T, tmp)
+        Qss = tmp.T @ tmp
         Kss = self.cov(Xnew)
         Lss = pt.linalg.cholesky(pm.gp.util.stabilize(Kss - Qss))
         mu_pred = pm.MvNormal(name, mu=mus, chol=Lss, shape=len(Xnew))
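The `Kuiu` lines kept in this hunk never form `Kuu⁻¹` explicitly: solving twice against the Cholesky factor is cheaper and numerically stabler than a dense inverse. A quick numpy check of that identity on a small synthetic SPD matrix (all values hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
Kuu = A @ A.T + 4 * np.eye(4)  # small SPD stand-in for Kuu
u = rng.normal(size=4)

L = np.linalg.cholesky(Kuu)  # Kuu = L @ L.T
x = np.linalg.solve(L.T, np.linalg.solve(L, u))  # two triangular solves
assert np.allclose(x, np.linalg.inv(Kuu) @ u)  # matches the explicit inverse
```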
@@ -374,8 +374,8 @@ with pm.Model() as model_hts:
     mu_f = mu_gp.prior("mu", X_obs, Xu)
 
     sigma_ell = pm.InverseGamma("sigma_ell", mu=ell_mu, sigma=ell_sigma)
-    sigma_η = pm.Gamma("sigma_η", alpha=2, beta=1)
-    sigma_cov = sigma_η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=sigma_ell)
+    sigma_eta = pm.Gamma("sigma_eta", alpha=2, beta=1)
+    sigma_cov = sigma_eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=sigma_ell)
 
     lg_sigma_gp = SparseLatent(sigma_cov)
     lg_sigma_f = lg_sigma_gp.prior("lg_sigma_f", X_obs, Xu)
@@ -429,24 +429,75 @@ def add_coreg_idx(x):
     return np.hstack([np.tile(x, (2, 1)), np.vstack([np.zeros(x.shape), np.ones(x.shape)])])
 
 
+Xu = X[1::2]
 Xu_c, X_obs_c, Xnew_c = [add_coreg_idx(x) for x in [Xu, X_obs, Xnew]]
+```
 
+```{code-cell} ipython3
 with pm.Model() as model_htsc:
-    ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma)
-    eta = pm.Gamma("eta", alpha=2, beta=1)
-    cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell)
+    ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma, initval=1.0)
+    eta = pm.Gamma("eta", alpha=2, beta=1, initval=1.0)
+    eq_cov = eta**2 * pm.gp.cov.Matern52(input_dim=1, active_dims=[0], ls=ell)
 
     D_out = 2  # two output dimensions, mean and variance
     rank = 2  # two basis GPs
-    W = pm.Normal("W", mu=0, sigma=3, shape=(D_out, rank), initval=np.full([D_out, rank], 0.1))
+    W = pm.Normal("W", mu=0, sigma=3, shape=(D_out, rank))
     kappa = pm.Gamma("kappa", alpha=1.5, beta=1, shape=D_out)
     coreg = pm.gp.cov.Coregion(input_dim=1, active_dims=[0], kappa=kappa, W=W)
 
-    cov = pm.gp.cov.Kron([cov, coreg])
+    cov = pm.gp.cov.Kron([eq_cov, coreg])
 
     gp_LMC = SparseLatent(cov)
    LMC_f = gp_LMC.prior("LMC", X_obs_c, Xu_c)
 
+    # First half of LMC_f is the mean, second half is the log-scale
+    mu_f = LMC_f[: len(y_obs_)]
+    lg_sigma_f = LMC_f[len(y_obs_) :]
+    sigma_f = pm.Deterministic("sigma_f", pm.math.exp(lg_sigma_f))
+
+    lik_htsc = pm.Normal("lik_htsc", mu=mu_f, sigma=sigma_f, observed=y_obs_)
+    trace_htsc = pm.sample(
+        target_accept=0.9,
+        chains=2,
+        nuts_sampler="nutpie",
+        return_inferencedata=True,
+        random_seed=SEED,
+    )
+
+with model_htsc:
+    c_mu_pred = gp_LMC.conditional("c_mu_pred", Xnew_c, Xu_c)
+    pm.sample_posterior_predictive(
+        trace_htsc, var_names=["c_mu_pred"], extend_inferencedata=True, predictions=True
+    )
+```
+
+```{code-cell} ipython3
+def get_icm(input_dim, kernel, W=None, kappa=None, B=None, active_dims=None):
+    """
+    This function generates an ICM kernel from an input kernel and a Coregion kernel.
+    """
+    coreg = pm.gp.cov.Coregion(input_dim=input_dim, W=W, kappa=kappa, B=B, active_dims=active_dims)
+    icm_cov = kernel * coreg  # Use Hadamard product for separate inputs
+    return icm_cov
+```
+
+```{code-cell} ipython3
+with pm.Model() as model_htsc:
+    ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma, initval=1.0)
+    eta = pm.Gamma("eta", alpha=2, beta=1, initval=1.0)
+    eq_cov = eta**2 * pm.gp.cov.Matern52(input_dim=2, active_dims=[0], ls=ell)
+
+    D_out = 2  # two output dimensions, mean and variance
+    rank = 2  # two basis GPs
+    W = pm.Normal("W", mu=0, sigma=1, shape=(D_out, rank), initval=-1 * np.ones((D_out, rank)))
+    kappa = pm.Gamma("kappa", alpha=1.5, beta=1, shape=D_out)
+    B = pm.Deterministic("B", pt.dot(W, W.T) + pt.diag(kappa))
+    cov_icm = get_icm(input_dim=2, kernel=eq_cov, B=B, active_dims=[1])
+
+    gp_LMC = SparseLatent(cov_icm)
+    LMC_f = gp_LMC.prior("LMC", X_obs_c, Xu_c)
+
+    # First half of LMC_f is the mean, second half is the log-scale
     mu_f = LMC_f[: len(y_obs_)]
     lg_sigma_f = LMC_f[len(y_obs_) :]
     sigma_f = pm.Deterministic("sigma_f", pm.math.exp(lg_sigma_f))
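Both variants of `model_htsc` above encode cross-output structure through the coregionalization matrix B = W Wᵀ + diag(κ); under the ICM, the covariance between output i at x and output j at x' is B[i, j] · k(x, x'). A small numpy illustration with hypothetical values:

```python
import numpy as np

W = np.array([[1.0, 0.5], [0.2, -0.3]])  # hypothetical loadings, shape (D_out, rank)
kappa = np.array([0.1, 0.4])  # hypothetical per-output variances

B = W @ W.T + np.diag(kappa)  # (D_out, D_out), positive definite by construction
k_xx = 0.8  # hypothetical input-kernel value k(x, x')

# ICM covariance between output 0 at x and output 1 at x'
print(B[0, 1] * k_xx)
```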
@@ -455,7 +506,7 @@ with pm.Model() as model_htsc:
     trace_htsc = pm.sample(
         target_accept=0.95,
         chains=2,
-        nuts_sampler="numpyro",
+        nuts_sampler="nutpie",
         return_inferencedata=True,
         random_seed=SEED,
     )
@@ -468,21 +519,37 @@ with model_htsc:
 ```
 
 ```{code-cell} ipython3
-sigma_samples.shape
+mu_f.eval()
+```
+
+```{code-cell} ipython3
+az.plot_trace(trace_htsc, var_names=["ell", "eta", "kappa"])
+```
+
+```{code-cell} ipython3
+trace_htsc.predictions["c_mu_pred"].shape
+az.extract(trace_htsc.predictions)["c_mu_pred"].shape
 ```
 
 ```{code-cell} ipython3
 # μ_samples = samples_htsc["c_mu_pred"][:, : len(Xnew)]
 # σ_samples = np.exp(samples_htsc["c_mu_pred"][:, len(Xnew) :])
-mu_samples = az.extract(trace_htsc.predictions["c_mu_pred"])["c_mu_pred"][: len(Xnew)]
-sigma_samples = np.exp(az.extract(trace_htsc.predictions["c_mu_pred"])["c_mu_pred"])[len(Xnew) :]
+samples_htsc = az.extract(trace_htsc.predictions)["c_mu_pred"]
+mu_samples = samples_htsc[: len(Xnew)]
+sigma_samples = np.exp(samples_htsc[len(Xnew) :])
 
 _, axs = plt.subplots(1, 3, figsize=(18, 4))
+# plot_mean(axs[0], mu_samples.T)
+# plot_inducing_points(axs[0])
+# plot_var(axs[1], sigma_samples.T**2)
+# axs[1].set_ylim(-0.01, 0.2)
+# axs[1].legend(loc="upper left")
+# plot_inducing_points(axs[1])
+# plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
+# plot_inducing_points(axs[2])
 plot_mean(axs[0], mu_samples.T)
 plot_inducing_points(axs[0])
 plot_var(axs[1], sigma_samples.T**2)
-axs[1].set_ylim(-0.01, 0.2)
-axs[1].legend(loc="upper left")
 plot_inducing_points(axs[1])
 plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
 plot_inducing_points(axs[2])
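The slicing kept at the end of this hunk relies on `add_coreg_idx` stacking the mean rows above the log-scale rows, so the extracted `c_mu_pred` splits cleanly in half. A sketch of that bookkeeping with dummy data (sizes hypothetical, standing in for `len(Xnew)` and the number of posterior draws):

```python
import numpy as np

n_new, n_draws = 5, 3  # hypothetical stand-ins
stacked = np.zeros((2 * n_new, n_draws))  # rows: [mean block; log-scale block]

mu_samples = stacked[:n_new]  # first half: mean output
sigma_samples = np.exp(stacked[n_new:])  # second half: log-scale, exponentiated
assert mu_samples.shape == sigma_samples.shape == (n_new, n_draws)
```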
