Skip to content

Commit 57decad

Browse files
Balandatfacebook-github-bot
authored andcommitted
Fix PaiwiseGP on GPU (#1388)
Summary: Pull Request resolved: #1388 Yet again device/dtype mismatches... Reviewed By: ItsMrLin Differential Revision: D39360545 fbshipit-source-id: 1b502b1487aee1cafa4076c2ab78836cc9394fc8
1 parent 286012d commit 57decad

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

botorch/models/pairwise_gp.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,11 @@ def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:
577577
self.likelihood_hess = hl
578578
cov_hl = covar @ hl
579579
if eye is None:
580-
eye = torch.diag_embed(torch.ones(cov_hl.shape[:-1]))
580+
eye = torch.diag_embed(
581+
torch.ones(
582+
cov_hl.shape[:-1], device=cov_hl.device, dtype=cov_hl.dtype
583+
)
584+
)
581585
cov_hl = cov_hl + eye # add 1 to cov_hl
582586
g = self._grad_posterior_f(x, dp, D, DT, ch, ci)
583587
cov_g = covar @ g.unsqueeze(-1)
@@ -1008,7 +1012,9 @@ def forward(self, post: Posterior, comp: Tensor) -> Tensor:
10081012
log_posterior = -log_posterior.clamp(min=0)
10091013

10101014
mll = model.covar @ model.likelihood_hess
1011-
mll = mll + torch.diag_embed(torch.ones(mll.shape[:-1]))
1015+
mll = mll + torch.diag_embed(
1016+
torch.ones(mll.shape[:-1], device=mll.device, dtype=mll.dtype)
1017+
)
10121018
mll = -0.5 * torch.logdet(mll)
10131019

10141020
mll = mll + log_posterior

0 commit comments

Comments
 (0)