@@ -321,7 +321,7 @@ plot_var(axs[1], sigma_samples.T**2)
plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
```

- That looks much better! We've accurately captured the mean behavior of our system, as well as the underlying trend in the variance (with appropriate uncertainty). Crucially, the aggregate behavior of the model integrates both epistemic *and* aleatoric uncertainty, and the ~5% of our observations that fall outside the 2σ band are more or less evenly distributed across the domain. However, even with the Numpyro sampler, this took nearly an hour on a Ryen 7040 laptop to sample only 4k NUTS iterations. Due to the expense of the requisite matrix inversions, GPs are notoriously inefficient for large data sets. Let's reformulate this model using a sparse approximation.
+ That looks much better! We've accurately captured the mean behavior of our system, as well as the underlying trend in the variance (with appropriate uncertainty). Crucially, the aggregate behavior of the model integrates both epistemic *and* aleatoric uncertainty, and the ~5% of our observations that fall outside the 2σ band are more or less evenly distributed across the domain. However, even with the NumPyro sampler, this took nearly an hour on a Ryzen 7040 laptop to sample only 4k NUTS iterations. Due to the expense of the requisite matrix inversions, GPs are notoriously inefficient for large data sets. Let's reformulate this model using a sparse approximation.

+++
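As a rough illustration of the projected-process idea that the `SparseLatent` class below implements, here is a minimal NumPy sketch: the GP is parameterized by its values `u` at a small set of inducing points `Xu`, and the latent function at the observed inputs is recovered as $K_{fu} K_{uu}^{-1} u$ using Cholesky solves. The `rbf` helper and the toy inputs are purely illustrative, not part of the notebook.

```python
import numpy as np


def rbf(XA, XB, ell=1.0, eta=1.0):
    # Illustrative squared-exponential kernel, analogous to pm.gp.cov.ExpQuad
    d2 = (XA[:, None, 0] - XB[None, :, 0]) ** 2
    return eta**2 * np.exp(-0.5 * d2 / ell**2)


X = np.linspace(0, 10, 200)[:, None]  # n "observed" inputs
Xu = X[::20]  # m << n inducing points

Kuu = rbf(Xu, Xu) + 1e-6 * np.eye(len(Xu))  # jittered, like pm.gp.util.stabilize
Kfu = rbf(X, Xu)
L = np.linalg.cholesky(Kuu)

# Whitened draw of the latent GP at the inducing points: u = L v, with v ~ N(0, I)
v = np.random.default_rng(0).normal(size=len(Xu))
u = L @ v

# Project to the full input set: mu = Kfu Kuu^{-1} u, via two triangular solves,
# so the n x n covariance never has to be formed or factorized
Kuiu = np.linalg.solve(L.T, np.linalg.solve(L, u))
mu = Kfu @ Kuiu
```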
@@ -341,20 +341,20 @@ class SparseLatent:
        self.L = pt.linalg.cholesky(pm.gp.util.stabilize(Kuu))

        self.v = pm.Normal(f"u_rotated_{name}", mu=0.0, sigma=1.0, shape=len(Xu))
-         self.u = pm.Deterministic(f"u_{name}", pt.dot(self.L, self.v))
+         self.u = pm.Deterministic(f"u_{name}", self.L @ self.v)

        Kfu = self.cov(X, Xu)
        self.Kuiu = pt.slinalg.solve_triangular(
            self.L.T, pt.slinalg.solve_triangular(self.L, self.u, lower=True), lower=False
        )
-         self.mu = pm.Deterministic(f"mu_{name}", pt.dot(Kfu, self.Kuiu))
+         self.mu = pm.Deterministic(f"mu_{name}", Kfu @ self.Kuiu)
        return self.mu

    def conditional(self, name, Xnew, Xu):
        Ksu = self.cov(Xnew, Xu)
-         mus = pt.dot(Ksu, self.Kuiu)
+         mus = Ksu @ self.Kuiu
        tmp = pt.slinalg.solve_triangular(self.L, Ksu.T, lower=True)
-         Qss = pt.dot(tmp.T, tmp)
+         Qss = tmp.T @ tmp
        Kss = self.cov(Xnew)
        Lss = pt.linalg.cholesky(pm.gp.util.stabilize(Kss - Qss))
        mu_pred = pm.MvNormal(name, mu=mus, chol=Lss, shape=len(Xnew))
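For reference, the `conditional` method samples from the sparse predictive distribution implied by this construction. Writing $s$ for the new inputs and $u$ for the values at the inducing points,

$$
m_s = K_{su} K_{uu}^{-1} u, \qquad \Sigma_s = K_{ss} - K_{su} K_{uu}^{-1} K_{us},
$$

where `tmp.T @ tmp` in the code evaluates $K_{su} K_{uu}^{-1} K_{us}$ through the Cholesky factor of $K_{uu}$, and `pm.gp.util.stabilize` adds jitter before factorizing $\Sigma_s$.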
@@ -374,8 +374,8 @@ with pm.Model() as model_hts:
    mu_f = mu_gp.prior("mu", X_obs, Xu)

    sigma_ell = pm.InverseGamma("sigma_ell", mu=ell_mu, sigma=ell_sigma)
-     sigma_η = pm.Gamma("sigma_η", alpha=2, beta=1)
-     sigma_cov = sigma_η**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=sigma_ell)
+     sigma_eta = pm.Gamma("sigma_eta", alpha=2, beta=1)
+     sigma_cov = sigma_eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=sigma_ell)

    lg_sigma_gp = SparseLatent(sigma_cov)
    lg_sigma_f = lg_sigma_gp.prior("lg_sigma_f", X_obs, Xu)
@@ -429,24 +429,75 @@ def add_coreg_idx(x):
    return np.hstack([np.tile(x, (2, 1)), np.vstack([np.zeros(x.shape), np.ones(x.shape)])])


+ Xu = X[1::2]
Xu_c, X_obs_c, Xnew_c = [add_coreg_idx(x) for x in [Xu, X_obs, Xnew]]
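+ # For example (illustrative values, not from the notebook): for x = np.array([[0.2], [0.5]]),
+ # add_coreg_idx(x) returns
+ #     [[0.2, 0.0],
+ #      [0.5, 0.0],
+ #      [0.2, 1.0],
+ #      [0.5, 1.0]]
+ # i.e. the inputs stacked twice, with an output-index column: 0 marks the mean GP and
+ # 1 marks the log-scale GP.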
+ ```

+ ```{code-cell} ipython3
with pm.Model() as model_htsc:
-     ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma)
-     eta = pm.Gamma("eta", alpha=2, beta=1)
-     cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell)
+     ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma, initval=1.0)
+     eta = pm.Gamma("eta", alpha=2, beta=1, initval=1.0)
+     eq_cov = eta**2 * pm.gp.cov.Matern52(input_dim=1, active_dims=[0], ls=ell)

    D_out = 2  # two output dimensions, mean and variance
    rank = 2  # two basis GPs
-     W = pm.Normal("W", mu=0, sigma=3, shape=(D_out, rank), initval=np.full([D_out, rank], 0.1))
+     W = pm.Normal("W", mu=0, sigma=3, shape=(D_out, rank))
    kappa = pm.Gamma("kappa", alpha=1.5, beta=1, shape=D_out)
    coreg = pm.gp.cov.Coregion(input_dim=1, active_dims=[0], kappa=kappa, W=W)

-     cov = pm.gp.cov.Kron([cov, coreg])
+     cov = pm.gp.cov.Kron([eq_cov, coreg])

    gp_LMC = SparseLatent(cov)
    LMC_f = gp_LMC.prior("LMC", X_obs_c, Xu_c)

+     # First half of LMC_f is the mean, second half is the log-scale
+     mu_f = LMC_f[: len(y_obs_)]
+     lg_sigma_f = LMC_f[len(y_obs_) :]
+     sigma_f = pm.Deterministic("sigma_f", pm.math.exp(lg_sigma_f))
+
+     lik_htsc = pm.Normal("lik_htsc", mu=mu_f, sigma=sigma_f, observed=y_obs_)
+     trace_htsc = pm.sample(
+         target_accept=0.9,
+         chains=2,
+         nuts_sampler="nutpie",
+         return_inferencedata=True,
+         random_seed=SEED,
+     )
+
+ with model_htsc:
+     c_mu_pred = gp_LMC.conditional("c_mu_pred", Xnew_c, Xu_c)
+     pm.sample_posterior_predictive(
+         trace_htsc, var_names=["c_mu_pred"], extend_inferencedata=True, predictions=True
+     )
+ ```
+
+ ```{code-cell} ipython3
+ def get_icm(input_dim, kernel, W=None, kappa=None, B=None, active_dims=None):
+     """
+     Generate an ICM kernel from an input kernel and coregionalization parameters
+     (either W and kappa, or B).
+     """
+     coreg = pm.gp.cov.Coregion(input_dim=input_dim, W=W, kappa=kappa, B=B, active_dims=active_dims)
+     icm_cov = kernel * coreg  # Use Hadamard Product for separate inputs
+     return icm_cov
+ ```
+
+ ```{code-cell} ipython3
+ with pm.Model() as model_htsc:
+     ell = pm.InverseGamma("ell", mu=ell_mu, sigma=ell_sigma, initval=1.0)
+     eta = pm.Gamma("eta", alpha=2, beta=1, initval=1.0)
+     eq_cov = eta**2 * pm.gp.cov.Matern52(input_dim=2, active_dims=[0], ls=ell)
+
+     D_out = 2  # two output dimensions, mean and variance
+     rank = 2  # two basis GPs
+     W = pm.Normal("W", mu=0, sigma=1, shape=(D_out, rank), initval=-1 * np.ones((D_out, rank)))
+     kappa = pm.Gamma("kappa", alpha=1.5, beta=1, shape=D_out)
+     B = pm.Deterministic("B", pt.dot(W, W.T) + pt.diag(kappa))
+     cov_icm = get_icm(input_dim=2, kernel=eq_cov, B=B, active_dims=[1])
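+     # Note: with B = W @ W.T + diag(kappa) as above, pm.gp.cov.Coregion encodes the ICM
+     # covariance K((x, i), (x', j)) = B[i, j] * k(x, x'), where the output index i is read
+     # from column 1 (active_dims=[1]) and k is the Matern52 kernel on column 0.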
+
+     gp_LMC = SparseLatent(cov_icm)
+     LMC_f = gp_LMC.prior("LMC", X_obs_c, Xu_c)
+
+     # First half of LMC_f is the mean, second half is the log-scale
    mu_f = LMC_f[: len(y_obs_)]
    lg_sigma_f = LMC_f[len(y_obs_) :]
    sigma_f = pm.Deterministic("sigma_f", pm.math.exp(lg_sigma_f))
@@ -455,7 +506,7 @@ with pm.Model() as model_htsc:
    trace_htsc = pm.sample(
        target_accept=0.95,
        chains=2,
-         nuts_sampler="numpyro",
+         nuts_sampler="nutpie",
        return_inferencedata=True,
        random_seed=SEED,
    )
@@ -468,21 +519,37 @@ with model_htsc:
```

```{code-cell} ipython3
- sigma_samples.shape
+ mu_f.eval()
+ ```
+
+ ```{code-cell} ipython3
+ az.plot_trace(trace_htsc, var_names=["ell", "eta", "kappa"])
+ ```
+
+ ```{code-cell} ipython3
+ trace_htsc.predictions["c_mu_pred"].shape
+ az.extract(trace_htsc.predictions)["c_mu_pred"].shape
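+ # az.extract flattens the (chain, draw) dimensions into a single "sample" dimension, so the
+ # 2 * len(Xnew) prediction points remain the variable's own dimension: the first len(Xnew)
+ # entries are the mean output and the rest the log-scale output (sliced that way below).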
```

```{code-cell} ipython3
# μ_samples = samples_htsc["c_mu_pred"][:, : len(Xnew)]
# σ_samples = np.exp(samples_htsc["c_mu_pred"][:, len(Xnew) :])
- mu_samples = az.extract(trace_htsc.predictions["c_mu_pred"])["c_mu_pred"][: len(Xnew)]
- sigma_samples = np.exp(az.extract(trace_htsc.predictions["c_mu_pred"])["c_mu_pred"])[len(Xnew) :]
+ samples_htsc = az.extract(trace_htsc.predictions)["c_mu_pred"]
+ mu_samples = samples_htsc[: len(Xnew)]
+ sigma_samples = np.exp(samples_htsc[len(Xnew) :])

_, axs = plt.subplots(1, 3, figsize=(18, 4))
plot_mean(axs[0], mu_samples.T)
plot_inducing_points(axs[0])
plot_var(axs[1], sigma_samples.T**2)
- axs[1].set_ylim(-0.01, 0.2)
- axs[1].legend(loc="upper left")
plot_inducing_points(axs[1])
plot_total(axs[2], mu_samples.T, sigma_samples.T**2)
plot_inducing_points(axs[2])