Skip to content

Commit 965f154

Browse files
David Eriksson authored and facebook-github-bot committed
Fix sample_all_priors to not sample one value for all lengthscales (#2404)
Summary: Pull Request resolved: #2404 #2371 added support for `sample_all_priors` to handle multivariate priors, but these changes resulted in sampling the same value for all entries of a multi-dimensional hyperparameter if a univariate prior is used. In particular, this means that `sample_all_priors` will sample exactly the same lengthscale for all dimensions when using a univariate prior. This diff changes this behavior to instead sample according to the shape of the closure when a univariate prior is specified. This results in sampling different lengthscales for each dimension and batch dimension. Reviewed By: saitcakmak Differential Revision: D58855726 fbshipit-source-id: 5c70ba83ff1710fe678c60f0ed2f11f57e671ad5
1 parent bf529df commit 965f154

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

botorch/optim/utils/model_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,16 @@ def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
163163
try:
164164
# Set sample shape, so that the prior samples have the same shape
165165
# as `closure(module)` without having to be repeated.
166-
closure_shape = closure(module).shape
167166
prior_shape = prior._extended_shape()
168-
sample_shape = closure_shape[: -len(prior_shape)]
169-
setting_closure(module, prior.sample(sample_shape=sample_shape))
167+
if prior_shape.numel() == 1:
168+
# For a univariate prior we can sample the size of the closure.
169+
# Otherwise we will sample exactly the same value for all
170+
# lengthscales where we commonly specify a univariate prior.
171+
setting_closure(module, prior.sample(closure(module).shape))
172+
else:
173+
closure_shape = closure(module).shape
174+
sample_shape = closure_shape[: -len(prior_shape)]
175+
setting_closure(module, prior.sample(sample_shape=sample_shape))
170176
break
171177
except NotImplementedError:
172178
warn(

test/optim/utils/test_model_utils.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import warnings
1111
from copy import deepcopy
1212
from string import ascii_lowercase
13+
from typing import Any, Dict
1314
from unittest.mock import MagicMock, patch
1415

1516
import torch
@@ -246,19 +247,40 @@ def test_sample_all_priors(self):
246247
with self.assertRaises(RuntimeError):
247248
sample_all_priors(model)
248249

250+
def test_univariate_prior(self) -> None:
    """Check that a univariate lengthscale prior yields distinct samples.

    With a univariate prior (e.g. a Gamma over the lengthscale),
    ``sample_all_priors`` should draw an independent value for every ARD
    dimension and every batch entry rather than broadcasting one sample.
    """
    tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.double}
    for batch_shape in (torch.Size([]), torch.Size([2, 2])):
        kernel = RBFKernel(
            ard_num_dims=3,
            batch_shape=batch_shape,
            lengthscale_prior=GammaPrior(6.0, 3.0),  # univariate
        )
        gp = SingleTaskGP(
            train_X=torch.rand(*batch_shape, 5, 3, **tkwargs),
            train_Y=torch.randn(*batch_shape, 5, 1, **tkwargs),
            covar_module=kernel,
        )
        before = gp.covar_module.lengthscale
        # The default initialization shares a single lengthscale value.
        self.assertEqual(len(torch.unique(before)), 1)
        sample_all_priors(gp)
        after = gp.covar_module.lengthscale
        self.assertFalse(torch.allclose(before, after))
        # Every entry should now be distinct (happens with probability 1).
        self.assertEqual(len(torch.unique(after)), 3 * batch_shape.numel())
249270
def test_with_multivariate_prior(self) -> None:
250271
# This is modified from https://github.com/pytorch/botorch/issues/780.
272+
tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.double}
251273
for batch in (torch.Size([]), torch.Size([3])):
252274
model = SingleTaskGP(
253-
train_X=torch.randn(*batch, 2, 2),
254-
train_Y=torch.randn(*batch, 2, 1),
275+
train_X=torch.rand(*batch, 2, 2, **tkwargs),
276+
train_Y=torch.randn(*batch, 2, 1, **tkwargs),
255277
covar_module=RBFKernel(
256278
ard_num_dims=2,
257279
batch_shape=batch,
258280
lengthscale_prior=NormalPrior(
259281
# Make this almost singular for easy comparison below.
260-
torch.tensor([[1.0, 1.0]]),
261-
torch.tensor(1e-10),
282+
torch.tensor([[1.0, 1.0]], **tkwargs),
283+
torch.tensor(1e-10, **tkwargs),
262284
),
263285
),
264286
)
@@ -267,4 +289,4 @@ def test_with_multivariate_prior(self) -> None:
267289
sample_all_priors(model)
268290
new_lengthscale = model.covar_module.lengthscale
269291
self.assertFalse(torch.allclose(original_lengthscale, new_lengthscale))
270-
self.assertAllClose(new_lengthscale, torch.ones(*batch, 1, 2))
292+
self.assertAllClose(new_lengthscale, torch.ones(*batch, 1, 2, **tkwargs))

0 commit comments

Comments (0)