Skip to content

Commit 6e08992

Browse files
Balandatfacebook-github-bot
authored andcommitted
Suppress warnings in PairwiseGP. (#428)
Summary: Suppresses `RuntimeWarning` within the `fsolve` calls. Also, the warning in `_add_jitter` was not emitting a warning of type `RuntimeWarning`, so it was not properly suppressed in the tests. Pull Request resolved: #428 Test Plan: unit tests Reviewed By: danielrjiang Differential Revision: D21247256 Pulled By: Balandat fbshipit-source-id: 9deb70317cdfbd27c895cf46135dae70be720c90
1 parent a948093 commit 6e08992

File tree

2 files changed

+31
-26
lines changed

2 files changed

+31
-26
lines changed

botorch/models/pairwise_gp.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __init__(
149149

150150
self.to(self.datapoints)
151151

152-
def __deepcopy__(self, memo):
152+
def __deepcopy__(self, memo) -> PairwiseGP:
153153
attrs = (
154154
"datapoints",
155155
"comparisons",
@@ -260,27 +260,28 @@ def _prior_predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
260260
pred_covar = self._calc_covar(X, X)
261261
return pred_mean, pred_covar
262262

263-
def _add_jitter(self, X) -> Tensor:
264-
X = X.clone()
263+
def _add_jitter(self, X: Tensor) -> Tensor:
265264
jitter_prev = 0
265+
Eye = torch.eye(X.size(-1)).expand(X.shape)
266266
for i in range(3):
267267
jitter_new = self._jitter * (10 ** i)
268-
X = X + torch.eye(X.size(-1)).expand(X.shape) * (jitter_new - jitter_prev)
268+
X = X + (jitter_new - jitter_prev) * Eye
269269
jitter_prev = jitter_new
270270
# This may be VERY slow given upstream pytorch issue:
271271
# https://github.com/pytorch/pytorch/issues/34272
272272
try:
273273
_ = torch.cholesky(X)
274274
warnings.warn(
275-
f"X is not a p.d. matrix; "
276-
f"Added jitter of {jitter_new} to the diagonal",
275+
"X is not a p.d. matrix; "
276+
f"Added jitter of {jitter_new:.2e} to the diagonal",
277277
RuntimeWarning,
278278
)
279279
return X
280280
except RuntimeError:
281281
continue
282282
warnings.warn(
283-
f"Failed to turn X into p.d. after trying to add adding {jitter_new} jitter"
283+
f"Failed to render X p.d. after adding {jitter_new:.2e} jitter",
284+
RuntimeWarning,
284285
)
285286
return X
286287

@@ -524,15 +525,17 @@ def _update(self, **kwargs) -> None:
524525
ci_v[i],
525526
True,
526527
)
527-
x[i] = optimize.fsolve(
528-
x0=x0[i],
529-
func=self._grad_posterior_f,
530-
fprime=self._hess_posterior_f,
531-
xtol=xtol,
532-
maxfev=maxfev,
533-
args=fsolve_args,
534-
**kwargs,
535-
)
528+
with warnings.catch_warnings():
529+
warnings.filterwarnings("ignore", category=RuntimeWarning)
530+
x[i] = optimize.fsolve(
531+
x0=x0[i],
532+
func=self._grad_posterior_f,
533+
fprime=self._hess_posterior_f,
534+
xtol=xtol,
535+
maxfev=maxfev,
536+
args=fsolve_args,
537+
**kwargs,
538+
)
536539
x = x.reshape(*init_x0_size)
537540
else:
538541
fsolve_args = (
@@ -544,15 +547,17 @@ def _update(self, **kwargs) -> None:
544547
self.covar_inv,
545548
True,
546549
)
547-
x = optimize.fsolve(
548-
x0=x0,
549-
func=self._grad_posterior_f,
550-
fprime=self._hess_posterior_f,
551-
xtol=xtol,
552-
maxfev=maxfev,
553-
args=fsolve_args,
554-
**kwargs,
555-
)
550+
with warnings.catch_warnings():
551+
warnings.filterwarnings("ignore", category=RuntimeWarning)
552+
x = optimize.fsolve(
553+
x0=x0,
554+
func=self._grad_posterior_f,
555+
fprime=self._hess_posterior_f,
556+
xtol=xtol,
557+
maxfev=maxfev,
558+
args=fsolve_args,
559+
**kwargs,
560+
)
556561

557562
self._x0 = x.copy() # save for warm-starting
558563
f = torch.tensor(x, **self.tkwargs)

test/models/test_pairwise_gp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_condition_on_observations(self):
247247
"datapoints": model_kwargs["datapoints"][0],
248248
"comparisons": model_kwargs["comparisons"][0],
249249
}
250-
model_non_batch = type(model)(**model_kwargs_non_batch)
250+
model_non_batch = model.__class__(**model_kwargs_non_batch)
251251
model_non_batch.load_state_dict(state_dict_non_batch)
252252
model_non_batch.eval()
253253
model_non_batch.posterior(

0 commit comments

Comments
 (0)