Skip to content

Commit d753706

Browse files
saitcakmak authored and facebook-github-bot committed
Update sample_all_priors to support wider set of priors (#2371)
Summary: Pull Request resolved: #2371 Addresses #780 Previously, this would pass in `closure(module).shape` as the `sample_shape`, which only worked if the prior was a univariate distribution. `Distribution.sample` produces samples of shape `Distribution._extended_shape(sample_shape) = sample_shape + Distribution._extended_shape()`, so we can calculate the `sample_shape` required to support both univariate and multivariate / batched priors. Reviewed By: dme65 Differential Revision: D58377495 fbshipit-source-id: 17510505012838a3fe670492656be4d13bc0db5e
1 parent f3dd493 commit d753706

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

botorch/optim/utils/model_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,12 @@ def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
161161
)
162162
for i in range(max_retries):
163163
try:
164-
setting_closure(module, prior.sample(closure(module).shape))
164+
# Set sample shape, so that the prior samples have the same shape
165+
# as `closure(module)` without having to be repeated.
166+
closure_shape = closure(module).shape
167+
prior_shape = prior._extended_shape()
168+
sample_shape = closure_shape[: -len(prior_shape)]
169+
setting_closure(module, prior.sample(sample_shape=sample_shape))
165170
break
166171
except NotImplementedError:
167172
warn(

test/optim/utils/test_model_utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
)
2626
from botorch.utils.testing import BotorchTestCase
2727
from gpytorch.constraints import GreaterThan
28+
from gpytorch.kernels import RBFKernel
2829
from gpytorch.kernels.matern_kernel import MaternKernel
2930
from gpytorch.kernels.scale_kernel import ScaleKernel
3031
from gpytorch.likelihoods import GaussianLikelihood
3132
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
3233
from gpytorch.priors import UniformPrior
3334
from gpytorch.priors.prior import Prior
34-
from gpytorch.priors.torch_priors import GammaPrior
35+
from gpytorch.priors.torch_priors import GammaPrior, NormalPrior
3536

3637

3738
class DummyPrior(Prior):
@@ -244,3 +245,26 @@ def test_sample_all_priors(self):
244245
original_state_dict = dict(deepcopy(mll.model.state_dict()))
245246
with self.assertRaises(RuntimeError):
246247
sample_all_priors(model)
248+
249+
def test_with_multivariate_prior(self) -> None:
250+
# This is modified from https://github.com/pytorch/botorch/issues/780.
251+
for batch in (torch.Size([]), torch.Size([3])):
252+
model = SingleTaskGP(
253+
train_X=torch.randn(*batch, 2, 2),
254+
train_Y=torch.randn(*batch, 2, 1),
255+
covar_module=RBFKernel(
256+
ard_num_dims=2,
257+
batch_shape=batch,
258+
lengthscale_prior=NormalPrior(
259+
# Make this almost singular for easy comparison below.
260+
torch.tensor([[1.0, 1.0]]),
261+
torch.tensor(1e-10),
262+
),
263+
),
264+
)
265+
# Check that the lengthscale is replaced with the sampled values.
266+
original_lengthscale = model.covar_module.lengthscale
267+
sample_all_priors(model)
268+
new_lengthscale = model.covar_module.lengthscale
269+
self.assertFalse(torch.allclose(original_lengthscale, new_lengthscale))
270+
self.assertAllClose(new_lengthscale, torch.ones(*batch, 1, 2))

0 commit comments

Comments (0)