
Commit d81a674

esantorella authored and facebook-github-bot committed
Clarifying comments and expanded docstrings for PairwiseGP linear algebra (#2072)
Summary:

## Motivation

* Added clarifying comments and a minor correction or two while working through this code myself.
* Expanded on docstrings.
* Removed DT, an unused transpose argument, from two private methods, and stopped computing it. This may provide a minor speedup.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2072

Test Plan: Ran existing unit tests for PairwiseGP.

Reviewed By: ItsMrLin

Differential Revision: D50667342

Pulled By: esantorella

fbshipit-source-id: 8c27736844060b9b2d6a3c83d2bf2cf3628aeb61
1 parent 84245d5 commit d81a674

File tree

1 file changed: botorch/models/pairwise_gp.py (53 additions, 37 deletions)
@@ -406,9 +406,8 @@ def _grad_posterior_f(
         utility: Union[Tensor, np.ndarray],
         datapoints: Tensor,
         D: Tensor,
-        DT: Tensor,
         covar_chol: Tensor,
-        covar_inv: Tensor,
+        covar_inv: Optional[Tensor] = None,
         ret_np: bool = False,
     ) -> Union[Tensor, np.ndarray]:
         r"""Compute the gradient of S loss wrt to f/utility in [Chu2005preference]_.
@@ -421,10 +420,12 @@ def _grad_posterior_f(
             utility: A Tensor of shape `batch_size x n`
             datapoints: A Tensor of shape `batch_size x n x d` as in self.datapoints
             D: A Tensor of shape `batch_size x m x n` as in self.D
-            DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT
             covar_chol: A Tensor of shape `batch_size x n x n`, as in self.covar_chol
-            covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv
-            ret_np: return a numpy array if true, otherwise a Tensor
+            covar_inv: `None` or a Tensor of shape `batch_size x n x n`, as in
+                self.covar_inv. This is not used but is needed so that
+                PairwiseGP._grad_posterior_f has the same signature as
+                PairwiseGP._hess_posterior_f.
+            ret_np: return a numpy array if True, otherwise a Tensor
         """
         prior_mean = self._prior_mean(datapoints)
@@ -442,15 +443,13 @@ def _grad_posterior_f(
         g = g_ + b
         if ret_np:
             return g.cpu().numpy()
-        else:
-            return g
+        return g

     def _hess_posterior_f(
         self,
         utility: Union[Tensor, np.ndarray],
         datapoints: Tensor,
         D: Tensor,
-        DT: Tensor,
         covar_chol: Tensor,
         covar_inv: Tensor,
         ret_np: bool = False,
@@ -463,10 +462,15 @@ def _hess_posterior_f(

         Args:
             utility: A Tensor of shape `batch_size x n`
-            datapoints: A Tensor of shape `batch_size x n x d` as in self.datapoints
+            datapoints: A Tensor of shape `batch_size x n x d`, as in
+                self.datapoints. This is not used but is needed so that
+                `_hess_posterior_f` has the same signature as
+                `_grad_posterior_f`.
             D: A Tensor of shape `batch_size x m x n` as in self.D
-            DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT
-            covar_chol: A Tensor of shape `batch_size x n x n`, as in self.covar_chol
+            covar_chol: A Tensor of shape `batch_size x n x n`, as in
+                self.covar_chol. This is not used but is needed so that
+                `_hess_posterior_f` has the same signature as
+                `_grad_posterior_f`.
             covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv
             ret_np: return a numpy array if true, otherwise a Tensor
         """
@@ -478,12 +482,16 @@ def _hess_posterior_f(
         return hess.numpy() if ret_np else hess

     def _update_utility_derived_values(self) -> None:
-        r"""Calculate utility-derived values not needed during optimization
+        r"""
+        Set self.hlcov_eye to self.likelihood_hess @ self.covar + I.
+
+        `self.hlcov_eye` is a utility-derived value not used during
+        optimization. This quantity is used so that we will be able to compute
+        the predictive covariance (in PairwiseGP.forward in posterior mode) with
+        better numerical stability using the substitution method:

-        Using subsitution method for better numerical stability
-        Let `pred_cov_fac = (covar + hl^-1)`, which is needed for calculate
-        predictive covariance = `K - k.T @ pred_cov_fac^-1 @ k`
-        (Also see posterior mode in `forward`)
+        Let `pred_cov_fac = (covar + hl^-1)`, which is needed for calculating
+        the predictive covariance = `K - k.T @ pred_cov_fac^-1 @ k`.
         Instead of inverting `pred_cov_fac`, let `hlcov_eye = (hl @ covar + I)`
         Then we can obtain `pred_cov_fac^-1 @ k` by solving for p in
         `(hl @ k) p = hlcov_eye`
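The substitution this docstring describes can be checked numerically. A minimal sketch, with hypothetical SPD matrices standing in for `covar` (K) and the likelihood Hessian `hl`, showing that a single solve against `hlcov_eye` reproduces `pred_cov_fac^-1 @ k`:

import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.T + n * torch.eye(n, dtype=torch.float64)   # stand-in for covar
B = torch.randn(n, n, dtype=torch.float64)
hl = B @ B.T + n * torch.eye(n, dtype=torch.float64)  # stand-in for likelihood hess
k = torch.randn(n, 1, dtype=torch.float64)

# Direct route, with two inversions: (K + hl^-1)^-1 @ k
direct = torch.linalg.solve(K + torch.linalg.inv(hl), k)

# Substitution route, with a single solve: (hl @ K + I)^-1 @ (hl @ k)
hlcov_eye = hl @ K + torch.eye(n, dtype=torch.float64)
substituted = torch.linalg.solve(hlcov_eye, hl @ k)

assert torch.allclose(direct, substituted)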
@@ -554,12 +562,11 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
             x0 = x0.reshape(-1, self.n)
             dp_v = datapoints.view(-1, self.n, self.dim).cpu()
             D_v = self.D.view(-1, self.m, self.n).cpu()
-            DT_v = self.DT.view(-1, self.n, self.m).cpu()
             ch_v = self.covar_chol.view(-1, self.n, self.n).cpu()
             ci_v = self.covar_inv.view(-1, self.n, self.n).cpu()
             x = np.empty(x0.shape)
             for i in range(x0.shape[0]):
-                fsolve_args = (dp_v[i], D_v[i], DT_v[i], ch_v[i], ci_v[i], True)
+                fsolve_args = (dp_v[i], D_v[i], ch_v[i], ci_v[i], True)
                 with warnings.catch_warnings():
                     warnings.filterwarnings("ignore", category=RuntimeWarning)
                     x[i] = optimize.fsolve(
@@ -577,7 +584,6 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
             fsolve_args = (
                 datapoints.cpu(),
                 self.D.cpu(),
-                self.DT.cpu(),
                 self.covar_chol.cpu(),
                 self.covar_inv.cpu(),
                 True,
@@ -616,7 +622,7 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
         # the first step results in gradients in the order of 1e-7 while the 2nd step
         # allows it go down further to the order of 1e-12 and stay there.
         self.utility = self._util_newton_updates(
-            datapoints, f.clone().requires_grad_(True), max_iter=2
+            dp=datapoints, x0=f.clone().requires_grad_(True), max_iter=2
         )

     def _transform_batch_shape(self, X: Tensor, X_new: Tensor) -> Tuple[Tensor, Tensor]:
@@ -650,7 +656,9 @@ def _transform_batch_shape(self, X: Tensor, X_new: Tensor) -> Tuple[Tensor, Tensor]:
         # if X has fewer dimension, try to expand it to X_new's shape
         return X.expand(X_new_bs + X.shape[-2:]), X_new

-    def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:
+    def _util_newton_updates(
+        self, dp: Tensor, x0: Tensor, max_iter: int = 1, xtol: Optional[float] = None
+    ) -> Tensor:
         r"""Make `max_iter` newton updates on utility.

         This is used in `forward` to calculate and fill in gradient into tensors.
@@ -659,19 +667,15 @@ def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:
         By default only need to run one iteration just to fill the the gradients.

         Args:
-            dp: (Transformed) datapoints.
+            dp: (Transformed) datapoints. A Tensor of shape `batch_size x n x d`
+                as in self.datapoints
             x0: A `batch_size x n` dimension tensor, initial values.
             max_iter: Max number of iterations.
             xtol: Stop creteria. If `None`, do not stop until
                 finishing `max_iter` updates.
         """
         xtol = float("-Inf") if xtol is None else xtol
-        D, DT, ch, ci = (
-            self.D,
-            self.DT,
-            self.covar_chol,
-            self.covar_inv,
-        )
+        D, ch = self.D, self.covar_chol
         covar = self.covar
         diff = float("Inf")
         i = 0
@@ -688,7 +692,12 @@ def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:
                 )
             )
             cov_hl = cov_hl + eye  # add 1 to cov_hl
-            g = self._grad_posterior_f(x, dp, D, DT, ch, ci)
+            g = self._grad_posterior_f(
+                utility=x,
+                datapoints=dp,
+                D=D,
+                covar_chol=ch,
+            )
             cov_g = covar @ g.unsqueeze(-1)
             x_update = torch.linalg.solve(cov_hl, cov_g).squeeze(-1)
             x_next = x - x_update
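The loop in this hunk is a plain Newton iteration on the utility: each step solves `(covar @ hl + I) x_update = covar @ g` rather than inverting the Hessian. A generic sketch of the same pattern on a toy objective (not the PairwiseGP likelihood):

import torch

def newton(grad_fn, hess_fn, x0, max_iter=20, xtol=1e-10):
    # Newton's method: x <- x - H(x)^-1 g(x), using a solve instead of an inverse
    x = x0
    for _ in range(max_iter):
        step = torch.linalg.solve(hess_fn(x), grad_fn(x).unsqueeze(-1)).squeeze(-1)
        x = x - step
        if step.norm() < xtol:
            break
    return x

# minimize f(x) = sum(x**4) + x @ x, whose unique minimum is at 0
grad = lambda x: 4 * x**3 + 2 * x
hess = lambda x: torch.diag(12 * x**2 + 2)
x_star = newton(grad, hess, torch.tensor([1.0, -2.0]))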
@@ -961,8 +970,8 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
                 device=self.datapoints.device,
             ).expand(hl_cov.shape)
             hl_cov_I = hl_cov + eye  # add I to hl_cov
-            train_covar_map = covar - covar @ torch.linalg.solve(hl_cov_I, hl_cov)
-            output_mean, output_covar = self.utility, train_covar_map
+            output_covar = covar - covar @ torch.linalg.solve(hl_cov_I, hl_cov)
+            output_mean = self.utility

         # Prior mode
         elif settings.prior_mode.on() or self._has_no_data():
@@ -999,10 +1008,17 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
             pred_mean = (covar_xnew_x @ covar_inv_p).squeeze(-1)
             pred_mean = pred_mean + self._prior_mean(X_new)

-            # [Brochu2010tutorial]_ page 27
-            # Preictive covariance fatcor: hlcov_eye = (K + C^-1)
-            # fac = (K + C^-1)^-1 @ k = pred_cov_fac_inv @ covar_x_xnew
-            # used substitution method here to calculate fac
+            # Using the terminology from [Brochu2010tutorial]_ page 27:
+            # hl = C; hlcov_eye = CK + I; k = covar_x_xnew
+            #
+            # To compute the predictive covariance, one term we need is
+            # k^T @ (K + C^-1)^-1 @ k.
+            # Rather than performing two matrix inversions, we can compute this
+            # in a more efficient and numerically stable way by using
+            # fac = hlcov_eye^-1 @ hl @ covar_x_xnew
+            #     = (CK + I)^-1 @ C @ k
+            #     = (K + C^-1)^-1 @ k
+            # This is the substitution method.
             fac = torch.linalg.solve(hlcov_eye, hl @ covar_x_xnew)
             pred_covar = covar_xnew - (covar_xnew_x @ fac)
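The rewritten comment block can likewise be sanity-checked end to end. A minimal sketch of the predictive covariance computed via the `fac` substitution above versus explicit inversion, with hypothetical matrices in place of the model's `covar`, `hl`, and cross-covariance:

import torch

torch.manual_seed(0)
n, t = 5, 3
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.T + n * torch.eye(n, dtype=torch.float64)   # train covariance K
B = torch.randn(n, n, dtype=torch.float64)
C = B @ B.T + n * torch.eye(n, dtype=torch.float64)   # likelihood Hessian hl (C)
k = torch.randn(n, t, dtype=torch.float64)            # stand-in for covar_x_xnew
K_new = torch.eye(t, dtype=torch.float64)             # placeholder covar_xnew

# Substitution method: fac = (CK + I)^-1 @ C @ k
hlcov_eye = C @ K + torch.eye(n, dtype=torch.float64)
fac = torch.linalg.solve(hlcov_eye, C @ k)
pred_covar = K_new - k.T @ fac

# Same quantity via explicit inversion: K_new - k^T @ (K + C^-1)^-1 @ k
expected = K_new - k.T @ torch.linalg.solve(K + torch.linalg.inv(C), k)
assert torch.allclose(pred_covar, expected)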

@@ -1058,8 +1074,7 @@ def posterior(
         posterior = GPyTorchPosterior(post)
         if posterior_transform is not None:
             return posterior_transform(posterior)
-        else:
-            return posterior
+        return posterior

     def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
         r"""Condition the model on new observations.
@@ -1071,6 +1086,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
             X: A `batch_shape x n x d` dimension tensor X
             Y: A tensor of size `batch_shape x m x 2`. (i, j) means
                 f_i is preferred over f_j
+            kwargs: Not used.

         Returns:
             A (deepcopied) `Model` object of the same type, representing the