Skip to content

Commit 8441d62

Browse files
ItsMrLin authored and facebook-github-bot committed
Adjust PairwiseGP ScaleKernel prior (#1460)
Summary: Pull Request resolved: #1460 Updates the prior on PairwiseGP's output scale. Additionally, this change enforces that PairwiseGP is used with a ScaleKernel, improves the initialization of the inferred utility values, and replaces `_batch_chol_inv` with `torch.cholesky_inverse`. TL;DR: we were previously using a significantly restrictive prior on the output scale theta (note theta = 1/sigma^2, where sigma is the probit noise on the function value), which prevented us from accommodating comparison errors outside the range of the green line. Reviewed By: Balandat Differential Revision: D40136741 fbshipit-source-id: 981b974a34f633d09880e670663c4b671f574ff2
1 parent 2ea11a6 commit 8441d62

File tree

2 files changed

+47
-29
lines changed

2 files changed

+47
-29
lines changed

botorch/models/pairwise_gp.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from botorch.posteriors.gpytorch import GPyTorchPosterior
3838
from botorch.posteriors.posterior import Posterior
3939
from gpytorch import settings
40-
from gpytorch.constraints import GreaterThan
40+
from gpytorch.constraints import GreaterThan, Interval
4141
from gpytorch.distributions.multivariate_normal import MultivariateNormal
4242
from gpytorch.kernels.rbf_kernel import RBFKernel
4343
from gpytorch.kernels.scale_kernel import ScaleKernel
@@ -147,7 +147,7 @@ def __init__(
147147

148148
# Set optional parameters
149149
# Explicitly set jitter for numerical stability in psd_safe_cholesky
150-
self._jitter = kwargs.get("jitter", 1e-5)
150+
self._jitter = kwargs.get("jitter", 1e-6)
151151
# Stopping criteria in scipy.optimize.fsolve used to find f_map in _update()
152152
# If None, set to 1e-6 by default in _update
153153
self._xtol = kwargs.get("xtol")
@@ -170,6 +170,7 @@ def __init__(
170170
# estimates away from scale value that would make Phi(f(x)) saturate
171171
# at 0 or 1
172172
if covar_module is None:
173+
os_lb, os_ub = 1e-2, 1e2
173174
ls_prior = GammaPrior(1.2, 0.5)
174175
ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate
175176
covar_module = ScaleKernel(
@@ -181,9 +182,16 @@ def __init__(
181182
lower_bound=1e-4, transform=None, initial_value=ls_prior_mode
182183
),
183184
),
184-
outputscale_prior=SmoothedBoxPrior(a=1, b=4),
185+
outputscale_prior=SmoothedBoxPrior(a=os_lb, b=os_ub),
186+
# make sure we won't get extreme values for the output scale
187+
outputscale_constraint=Interval(
188+
lower_bound=os_lb * 0.5,
189+
upper_bound=os_ub * 2.0,
190+
initial_value=1.0,
191+
),
185192
)
186-
193+
if not isinstance(covar_module, ScaleKernel):
194+
raise UnsupportedError("PairwiseGP must be used with a ScaleKernel.")
187195
self.covar_module = covar_module
188196

189197
self._x0 = None # will store temporary results for warm-starting
@@ -225,6 +233,16 @@ def __deepcopy__(self, memo) -> PairwiseGP:
225233
self.__deepcopy__ = dcp
226234
return new_model
227235

236+
def _scaled_psd_safe_cholesky(
237+
self, M: Tensor, jitter: Optional[float] = None
238+
) -> Tensor:
239+
r"""scale M by 1/outputscale before cholesky for better numerical stability"""
240+
scale = self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1)
241+
M = M / scale
242+
chol = psd_safe_cholesky(M, jitter=jitter)
243+
chol = chol * scale.sqrt()
244+
return chol
245+
228246
def _has_no_data(self):
229247
r"""Return true if the model does not have both datapoints and comparisons"""
230248
return (
@@ -238,24 +256,6 @@ def _calc_covar(self, X1: Tensor, X2: Tensor) -> Union[Tensor, LinearOperator]:
238256
covar = self.covar_module(X1, X2)
239257
return covar.to_dense()
240258

241-
def _batch_chol_inv(self, mat_chol: Tensor) -> Tensor:
242-
r"""Wrapper to perform (batched) cholesky inverse"""
243-
# TODO: get rid of this once cholesky_inverse supports batch mode
244-
batch_eye = torch.eye(
245-
mat_chol.shape[-1],
246-
dtype=self.datapoints.dtype,
247-
device=self.datapoints.device,
248-
)
249-
250-
if len(mat_chol.shape) == 2:
251-
mat_inv = torch.cholesky_inverse(mat_chol)
252-
elif len(mat_chol.shape) > 2 and (mat_chol.shape[-1] == mat_chol.shape[-2]):
253-
batch_eye = batch_eye.repeat(*(mat_chol.shape[:-2]), 1, 1)
254-
chol_inv = torch.linalg.solve_triangular(mat_chol, batch_eye, upper=False)
255-
mat_inv = chol_inv.transpose(-1, -2) @ chol_inv
256-
257-
return mat_inv
258-
259259
def _update_covar(self, datapoints: Tensor) -> None:
260260
r"""Update values derived from the data and hyperparameters
261261
@@ -265,8 +265,10 @@ def _update_covar(self, datapoints: Tensor) -> None:
265265
datapoints: (Transformed) datapoints for finding f_max
266266
"""
267267
self.covar = self._calc_covar(datapoints, datapoints)
268-
self.covar_chol = psd_safe_cholesky(self.covar, jitter=self._jitter)
269-
self.covar_inv = self._batch_chol_inv(self.covar_chol)
268+
self.covar_chol = self._scaled_psd_safe_cholesky(
269+
self.covar, jitter=self._jitter
270+
)
271+
self.covar_inv = torch.cholesky_inverse(self.covar_chol)
270272

271273
def _prior_mean(self, X: Tensor) -> Union[Tensor, LinearOperator]:
272274
r"""Return point prediction using prior only
@@ -417,7 +419,17 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
417419
# warm start
418420
init_x0_size = self.batch_shape + torch.Size([self.n])
419421
if self._x0 is None or torch.Size(self._x0.shape) != init_x0_size:
420-
x0 = np.random.rand(*init_x0_size)
422+
sqrt_scale = (
423+
self.covar_module.outputscale.sqrt()
424+
.unsqueeze(-1)
425+
.detach()
426+
.cpu()
427+
.numpy()
428+
)
429+
# initialize x0 using std normal but clip by 3 std to keep it bounded
430+
x0 = np.random.standard_normal(init_x0_size).clip(min=-3, max=3)
431+
# scale x0 to be on roughly the right scale
432+
x0 = x0 * sqrt_scale
421433
else:
422434
x0 = self._x0
423435

@@ -755,7 +767,6 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
755767
2. Prior predictions (prior mode)
756768
3. Predictive posterior (eval mode)
757769
"""
758-
759770
# Training mode: optimizing
760771
if self.training:
761772
if self._has_no_data():
@@ -839,7 +850,7 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
839850
# output_covar is sometimes non-PSD
840851
# perform a cholesky decomposition to check and amend
841852
covariance_matrix=RootLinearOperator(
842-
psd_safe_cholesky(output_covar, jitter=self._jitter)
853+
self._scaled_psd_safe_cholesky(output_covar, jitter=self._jitter)
843854
),
844855
)
845856
return post

test/models/test_pairwise_gp.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,16 @@ def test_pairwise_gp(self):
105105
self.assertEqual(model.num_outputs, 1)
106106
self.assertEqual(model.batch_shape, batch_shape)
107107

108+
# test not using a ScaleKernel
109+
with self.assertRaisesRegex(UnsupportedError, "used with a ScaleKernel"):
110+
PairwiseGP(**model_kwargs, covar_module=LinearKernel())
111+
108112
# test custom models
109-
custom_m = PairwiseGP(**model_kwargs, covar_module=LinearKernel())
110-
self.assertIsInstance(custom_m.covar_module, LinearKernel)
113+
custom_m = PairwiseGP(
114+
**model_kwargs, covar_module=ScaleKernel(LinearKernel())
115+
)
116+
self.assertIsInstance(custom_m.covar_module, ScaleKernel)
117+
self.assertIsInstance(custom_m.covar_module.base_kernel, LinearKernel)
111118

112119
# prior prediction
113120
prior_m = PairwiseGP(None, None).to(**tkwargs)

0 commit comments

Comments
 (0)