@@ -406,9 +406,8 @@ def _grad_posterior_f(
406406 utility : Union [Tensor , np .ndarray ],
407407 datapoints : Tensor ,
408408 D : Tensor ,
409- DT : Tensor ,
410409 covar_chol : Tensor ,
411- covar_inv : Tensor ,
410+ covar_inv : Optional [ Tensor ] = None ,
412411 ret_np : bool = False ,
413412 ) -> Union [Tensor , np .ndarray ]:
414413 r"""Compute the gradient of S loss wrt f/utility in [Chu2005preference]_.
@@ -421,10 +420,12 @@ def _grad_posterior_f(
421420 utility: A Tensor of shape `batch_size x n`
422421 datapoints: A Tensor of shape `batch_size x n x d` as in self.datapoints
423422 D: A Tensor of shape `batch_size x m x n` as in self.D
424- DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT
425423 covar_chol: A Tensor of shape `batch_size x n x n`, as in self.covar_chol
426- covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv
427- ret_np: return a numpy array if true, otherwise a Tensor
424+ covar_inv: `None` or a Tensor of shape `batch_size x n x n`, as in
425+ self.covar_inv. This is not used but is needed so that
426+ PairwiseGP._grad_posterior_f has the same signature as
427+ PairwiseGP._hess_posterior_f.
428+ ret_np: return a numpy array if True, otherwise a Tensor
428429 """
429430 prior_mean = self ._prior_mean (datapoints )
430431
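A minimal sketch of why the two callables must keep identical signatures (the toy objective and the names `toy_grad`/`toy_hess` are illustrative, not from BoTorch): `scipy.optimize.fsolve` forwards one `args` tuple to both `func` and `fprime`, so an argument needed by only one of them still has to appear in both signatures, as with the unused `covar_inv` above.

```python
import numpy as np
from scipy import optimize

def toy_grad(x, scale, unused=None):
    # Gradient of 0.5 * scale * ||x||^2; `unused` mirrors the placeholder
    # argument kept only for signature parity with the Hessian callable.
    return scale * x

def toy_hess(x, scale, unused=None):
    # Hessian of the same toy objective; fsolve hands it the same args tuple.
    return scale * np.eye(len(x))

root = optimize.fsolve(
    func=toy_grad,
    x0=np.ones(3),
    fprime=toy_hess,
    args=(2.0, "ignored"),  # one args tuple is forwarded to both callables
)
print(np.allclose(root, 0.0))  # -> True
```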
@@ -442,15 +443,13 @@ def _grad_posterior_f(
442443 g = g_ + b
443444 if ret_np :
444445 return g .cpu ().numpy ()
445- else :
446- return g
446+ return g
447447
448448 def _hess_posterior_f (
449449 self ,
450450 utility : Union [Tensor , np .ndarray ],
451451 datapoints : Tensor ,
452452 D : Tensor ,
453- DT : Tensor ,
454453 covar_chol : Tensor ,
455454 covar_inv : Tensor ,
456455 ret_np : bool = False ,
@@ -463,10 +462,15 @@ def _hess_posterior_f(
463462
464463 Args:
465464 utility: A Tensor of shape `batch_size x n`
466- datapoints: A Tensor of shape `batch_size x n x d` as in self.datapoints
465+ datapoints: A Tensor of shape `batch_size x n x d`, as in
466+ self.datapoints. This is not used but is needed so that
467+ `_hess_posterior_f` has the same signature as
468+ `_grad_posterior_f`.
467469 D: A Tensor of shape `batch_size x m x n` as in self.D
468- DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT
469- covar_chol: A Tensor of shape `batch_size x n x n`, as in self.covar_chol
470+ covar_chol: A Tensor of shape `batch_size x n x n`, as in
471+ self.covar_chol. This is not used but is needed so that
472+ `_hess_posterior_f` has the same signature as
473+ `_grad_posterior_f`.
470474 covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv
471475 ret_np: return a numpy array if True, otherwise a Tensor
472476 """
@@ -478,12 +482,16 @@ def _hess_posterior_f(
478482 return hess .numpy () if ret_np else hess
479483
480484 def _update_utility_derived_values (self ) -> None :
481- r"""Calculate utility-derived values not needed during optimization
485+ r"""
486+ Set self.hlcov_eye to self.likelihood_hess @ self.covar + I.
487+
488+ `self.hlcov_eye` is a utility-derived value not needed during
489+ optimization. It is precomputed here so that the predictive covariance
490+ (in PairwiseGP.forward in posterior mode) can later be computed with
491+ better numerical stability via the substitution method:
482492
483- Using subsitution method for better numerical stability
484- Let `pred_cov_fac = (covar + hl^-1)`, which is needed for calculate
485- predictive covariance = `K - k.T @ pred_cov_fac^-1 @ k`
486- (Also see posterior mode in `forward`)
493+ Let `pred_cov_fac = (covar + hl^-1)`, which is needed for calculating
494+ the predictive covariance = `K - k.T @ pred_cov_fac^-1 @ k`.
487495 Instead of inverting `pred_cov_fac`, let `hlcov_eye = (hl @ covar + I)`
488496 Then we can obtain `pred_cov_fac^-1 @ k` by solving for p in
489497 `hlcov_eye @ p = (hl @ k)`
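A small numerical sketch of the substitution trick described in this docstring, with synthetic SPD matrices standing in for `covar` and `hl` (none of these values come from the model): solving `hlcov_eye @ p = hl @ k` recovers `(covar + hl^-1)^-1 @ k` without forming either inverse.

```python
import torch

torch.manual_seed(0)
n, dtype = 4, torch.float64
A = torch.randn(n, n, dtype=dtype)
covar = A @ A.T + n * torch.eye(n, dtype=dtype)   # stand-in for the prior covariance
B = torch.randn(n, n, dtype=dtype)
hl = B @ B.T + n * torch.eye(n, dtype=dtype)      # stand-in for the likelihood Hessian
k = torch.randn(n, 1, dtype=dtype)

# Naive route: two explicit inverses.
direct = torch.inverse(covar + torch.inverse(hl)) @ k

# Substitution route: a single solve against hlcov_eye = hl @ covar + I.
hlcov_eye = hl @ covar + torch.eye(n, dtype=dtype)
p = torch.linalg.solve(hlcov_eye, hl @ k)

print(torch.allclose(direct, p))  # -> True
```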
@@ -554,12 +562,11 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
554562 x0 = x0 .reshape (- 1 , self .n )
555563 dp_v = datapoints .view (- 1 , self .n , self .dim ).cpu ()
556564 D_v = self .D .view (- 1 , self .m , self .n ).cpu ()
557- DT_v = self .DT .view (- 1 , self .n , self .m ).cpu ()
558565 ch_v = self .covar_chol .view (- 1 , self .n , self .n ).cpu ()
559566 ci_v = self .covar_inv .view (- 1 , self .n , self .n ).cpu ()
560567 x = np .empty (x0 .shape )
561568 for i in range (x0 .shape [0 ]):
562- fsolve_args = (dp_v [i ], D_v [i ], DT_v [ i ], ch_v [i ], ci_v [i ], True )
569+ fsolve_args = (dp_v [i ], D_v [i ], ch_v [i ], ci_v [i ], True )
563570 with warnings .catch_warnings ():
564571 warnings .filterwarnings ("ignore" , category = RuntimeWarning )
565572 x [i ] = optimize .fsolve (
@@ -577,7 +584,6 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
577584 fsolve_args = (
578585 datapoints .cpu (),
579586 self .D .cpu (),
580- self .DT .cpu (),
581587 self .covar_chol .cpu (),
582588 self .covar_inv .cpu (),
583589 True ,
@@ -616,7 +622,7 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
616622 # the first step results in gradients on the order of 1e-7 while the 2nd step
617623 # allows them to go down further to the order of 1e-12 and stay there.
618624 self .utility = self ._util_newton_updates (
619- datapoints , f .clone ().requires_grad_ (True ), max_iter = 2
625+ dp = datapoints , x0 = f .clone ().requires_grad_ (True ), max_iter = 2
620626 )
621627
622628 def _transform_batch_shape (self , X : Tensor , X_new : Tensor ) -> Tuple [Tensor , Tensor ]:
@@ -650,7 +656,9 @@ def _transform_batch_shape(self, X: Tensor, X_new: Tensor) -> Tuple[Tensor, Tens
650656 # if X has fewer dimensions, try to expand it to X_new's shape
651657 return X .expand (X_new_bs + X .shape [- 2 :]), X_new
652658
653- def _util_newton_updates (self , dp , x0 , max_iter = 1 , xtol = None ) -> Tensor :
659+ def _util_newton_updates (
660+ self , dp : Tensor , x0 : Tensor , max_iter : int = 1 , xtol : Optional [float ] = None
661+ ) -> Tensor :
654662 r"""Make `max_iter` Newton updates on utility.
655663
656664 This is used in `forward` to calculate gradients and fill them into tensors.
@@ -659,19 +667,15 @@ def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:
659667 By default, only one iteration is needed, just to fill in the gradients.
660668
661669 Args:
662- dp: (Transformed) datapoints.
670+ dp: (Transformed) datapoints. A Tensor of shape `batch_size x n x d`
671+ as in self.datapoints
663672 x0: A `batch_size x n` dimension tensor, initial values.
664673 max_iter: Max number of iterations.
665674 xtol: Stopping criterion. If `None`, do not stop until
666675 finishing `max_iter` updates.
667676 """
668677 xtol = float ("-Inf" ) if xtol is None else xtol
669- D , DT , ch , ci = (
670- self .D ,
671- self .DT ,
672- self .covar_chol ,
673- self .covar_inv ,
674- )
678+ D , ch = self .D , self .covar_chol
675679 covar = self .covar
676680 diff = float ("Inf" )
677681 i = 0
@@ -688,7 +692,12 @@ def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:
688692 )
689693 )
690694 cov_hl = cov_hl + eye # add I to cov_hl
691- g = self ._grad_posterior_f (x , dp , D , DT , ch , ci )
695+ g = self ._grad_posterior_f (
696+ utility = x ,
697+ datapoints = dp ,
698+ D = D ,
699+ covar_chol = ch ,
700+ )
692701 cov_g = covar @ g .unsqueeze (- 1 )
693702 x_update = torch .linalg .solve (cov_hl , cov_g ).squeeze (- 1 )
694703 x_next = x - x_update
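As a hedged illustration of the update formed in this loop (synthetic SPD matrices, not the model's `covar`/`likelihood_hess`): solving `(covar @ hess + I) x_update = covar @ g` yields the same step as the textbook Newton direction `(hess + covar^-1)^-1 @ g`, so no inverse of `covar` is ever formed.

```python
import torch

torch.manual_seed(0)
n, dtype = 5, torch.float64
A = torch.randn(n, n, dtype=dtype)
covar = A @ A.T + n * torch.eye(n, dtype=dtype)   # prior covariance stand-in
B = torch.randn(n, n, dtype=dtype)
hess = B @ B.T + n * torch.eye(n, dtype=dtype)    # likelihood Hessian stand-in
g = torch.randn(n, 1, dtype=dtype)                # full gradient of the objective

# Step as formed in the loop: solve (covar @ hess + I) x_update = covar @ g.
cov_hl = covar @ hess + torch.eye(n, dtype=dtype)
x_update = torch.linalg.solve(cov_hl, covar @ g)

# Textbook Newton direction (hess + covar^-1)^-1 @ g, using an explicit inverse.
reference = torch.linalg.solve(hess + torch.inverse(covar), g)

print(torch.allclose(x_update, reference))  # -> True
```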
@@ -961,8 +970,8 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
961970 device = self .datapoints .device ,
962971 ).expand (hl_cov .shape )
963972 hl_cov_I = hl_cov + eye # add I to hl_cov
964- train_covar_map = covar - covar @ torch .linalg .solve (hl_cov_I , hl_cov )
965- output_mean , output_covar = self .utility , train_covar_map
973+ output_covar = covar - covar @ torch .linalg .solve (hl_cov_I , hl_cov )
974+ output_mean = self .utility
966975
967976 # Prior mode
968977 elif settings .prior_mode .on () or self ._has_no_data ():
@@ -999,10 +1008,17 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
9991008 pred_mean = (covar_xnew_x @ covar_inv_p ).squeeze (- 1 )
10001009 pred_mean = pred_mean + self ._prior_mean (X_new )
10011010
1002- # [Brochu2010tutorial]_ page 27
1003- # Preictive covariance fatcor: hlcov_eye = (K + C^-1)
1004- # fac = (K + C^-1)^-1 @ k = pred_cov_fac_inv @ covar_x_xnew
1005- # used substitution method here to calculate fac
1011+ # Using the terminology from [Brochu2010tutorial]_ page 27:
1012+ # hl = C; hlcov_eye = CK + I; k = covar_x_xnew
1013+ #
1014+ # To compute the predictive covariance, one term we need is
1015+ # k^T (K + C^{-1})^{-1} k.
1016+ # Rather than performing two matrix inversions, we can compute this
1017+ # in a more efficient and numerically stable way by using
1018+ # fac = hlcov_eye^-1 @ hl @ covar_x_xnew
1019+ # = (CK + I)^-1 @ C @ k
1020+ # = (K + C^-1)^-1 @ k
1021+ # This is the substitution method.
10061022 fac = torch .linalg .solve (hlcov_eye , hl @ covar_x_xnew )
10071023 pred_covar = covar_xnew - (covar_xnew_x @ fac )
10081024
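A hedged end-to-end check of the comment above, with small synthetic matrices standing in for `K`, `C`, `covar_x_xnew`, and `covar_xnew` (sizes and values are arbitrary, not model quantities): the substitution form used here matches the textbook predictive covariance term.

```python
import torch

torch.manual_seed(0)
n, m, dtype = 5, 3, torch.float64
A = torch.randn(n, n, dtype=dtype)
K = A @ A.T + n * torch.eye(n, dtype=dtype)        # train-train covariance
B = torch.randn(n, n, dtype=dtype)
C = B @ B.T + n * torch.eye(n, dtype=dtype)        # likelihood Hessian (hl)
k = torch.randn(n, m, dtype=dtype)                 # covar_x_xnew
K_new = torch.eye(m, dtype=dtype)                  # covar_xnew (placeholder)

# Textbook expression: K_new - k^T (K + C^-1)^-1 k, two explicit inverses.
textbook = K_new - k.T @ torch.inverse(K + torch.inverse(C)) @ k

# Substitution used in forward: fac = (C K + I)^-1 C k via one linear solve.
hlcov_eye = C @ K + torch.eye(n, dtype=dtype)
fac = torch.linalg.solve(hlcov_eye, C @ k)
substitution = K_new - k.T @ fac

print(torch.allclose(textbook, substitution))  # -> True
```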
@@ -1058,8 +1074,7 @@ def posterior(
10581074 posterior = GPyTorchPosterior (post )
10591075 if posterior_transform is not None :
10601076 return posterior_transform (posterior )
1061- else :
1062- return posterior
1077+ return posterior
10631078
10641079 def condition_on_observations (self , X : Tensor , Y : Tensor , ** kwargs : Any ) -> Model :
10651080 r"""Condition the model on new observations.
@@ -1071,6 +1086,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
10711086 X: A `batch_shape x n x d` dimension tensor X
10721087 Y: A tensor of size `batch_shape x m x 2`. (i, j) means
10731088 f_i is preferred over f_j
1089+ kwargs: Not used.
10741090
10751091 Returns:
10761092 A (deepcopied) `Model` object of the same type, representing the