Commit 98c1504

saitcakmak authored and facebook-github-bot committed
Update HigherOrderGP to use new priors & standardize by default (#2555)
Summary: Pull Request resolved: #2555. As titled. This was leading to some test failures when I first tried it; it turned out the reference point was just a bit too aggressive and was filtering out all of X_baseline.

Reviewed By: Balandat

Differential Revision: D63478493

fbshipit-source-id: 834c73b453356a3a310fe6ebd8ca03c941f91701
1 parent 8924d1b commit 98c1504
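
The reference-point anecdote in the summary is easiest to see with a small sketch. The snippet below is not part of this commit; the tensors and the dominance check are assumptions chosen to mimic how hypervolume-based acquisition functions prune baseline points that do not improve on the reference point. Once outcomes are standardized by default, baseline values are roughly zero-mean, so a reference point at the origin can filter out all of X_baseline:

import torch

torch.manual_seed(0)
# Stand-in for standardized baseline objectives: roughly zero-mean, unit-variance.
Y_baseline = torch.randn(10, 2)

for ref_point in (torch.tensor([0.0, 0.0]), torch.tensor([-1.0, -1.0])):
    # Keep only points that improve on the reference point in every objective.
    better = (Y_baseline > ref_point).all(dim=-1)
    print(ref_point.tolist(), "->", int(better.sum()), "baseline points kept")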

File tree: 6 files changed, +41 -27 lines changed


botorch/models/higher_order_gp.py

Lines changed: 15 additions & 11 deletions
@@ -27,18 +27,20 @@
 from botorch.models.utils import gpt_posterior_settings
 from botorch.models.utils.assorted import fantasize as fantasize_flag
 from botorch.models.utils.gpytorch_modules import (
-    get_gaussian_likelihood_with_gamma_prior,
+    get_covar_module_with_dim_scaled_prior,
+    get_gaussian_likelihood_with_lognormal_prior,
 )
 from botorch.posteriors import (
     GPyTorchPosterior,
     HigherOrderGPPosterior,
     TransformedPosterior,
 )
+from botorch.utils.types import _DefaultType, DEFAULT
 from gpytorch.distributions import MultivariateNormal
-from gpytorch.kernels import Kernel, MaternKernel
+from gpytorch.kernels import Kernel
 from gpytorch.likelihoods import Likelihood
 from gpytorch.models import ExactGP
-from gpytorch.priors.torch_priors import GammaPrior, MultivariateNormalPrior
+from gpytorch.priors.torch_priors import MultivariateNormalPrior
 from gpytorch.settings import fast_pred_var, skip_posterior_variances
 from linear_operator.operators import (
     BatchRepeatLinearOperator,
@@ -183,7 +185,7 @@ def __init__(
         num_latent_dims: Optional[list[int]] = None,
         learn_latent_pars: bool = True,
         latent_init: str = "default",
-        outcome_transform: Optional[OutcomeTransform] = None,
+        outcome_transform: Union[OutcomeTransform, _DefaultType, None] = DEFAULT,
         input_transform: Optional[InputTransform] = None,
     ):
         r"""
@@ -196,7 +198,6 @@ def __init__(
             learn_latent_pars: If true, learn the latent parameters.
             latent_init: [default or gp] how to initialize the latent parameters.
         """
-
         if input_transform is not None:
             input_transform.to(train_X)

@@ -207,7 +208,11 @@ def __init__(
             raise NotImplementedError(
                 "HigherOrderGP currently only supports 1-dim `batch_shape`."
             )
-
+        if outcome_transform == DEFAULT:
+            outcome_transform = FlattenedStandardize(
+                output_shape=train_Y.shape[-num_output_dims:],
+                batch_shape=batch_shape,
+            )
         if outcome_transform is not None:
             if isinstance(outcome_transform, Standardize) and not isinstance(
                 outcome_transform, FlattenedStandardize
@@ -218,6 +223,7 @@ def __init__(
                     f"{train_Y.shape[- num_output_dims:]} and batch_shape="
                     f"{batch_shape} instead.",
                     RuntimeWarning,
+                    stacklevel=2,
                 )
                 outcome_transform = FlattenedStandardize(
                     output_shape=train_Y.shape[-num_output_dims:],
@@ -232,7 +238,7 @@ def __init__(
         self._input_batch_shape = batch_shape

         if likelihood is None:
-            likelihood = get_gaussian_likelihood_with_gamma_prior(
+            likelihood = get_gaussian_likelihood_with_lognormal_prior(
                 batch_shape=self._aug_batch_shape
             )
         else:
@@ -249,11 +255,9 @@ def __init__(
         else:
             self.covar_modules = ModuleList(
                 [
-                    MaternKernel(
-                        nu=2.5,
-                        lengthscale_prior=GammaPrior(3.0, 6.0),
-                        batch_shape=self._aug_batch_shape,
+                    get_covar_module_with_dim_scaled_prior(
                         ard_num_dims=1 if dim > 0 else train_X.shape[-1],
+                        batch_shape=self._aug_batch_shape,
                     )
                     for dim in range(self._num_dimensions)
                 ]
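
Taken together, the model changes above mean that a plain HigherOrderGP(train_X, train_Y) now standardizes its outcomes (via FlattenedStandardize) and uses the lognormal noise prior and dimension-scaled kernel prior helpers instead of the old Gamma priors. A minimal sketch of the new default, assuming the import paths below; the attribute checks mirror the test changes later in this commit:

import torch
from botorch.models.higher_order_gp import HigherOrderGP
from botorch.models.transforms.outcome import FlattenedStandardize

train_X = torch.rand(10, 1)
train_Y = torch.randn(10, 3, 5)

# With outcome_transform left at DEFAULT, a FlattenedStandardize matching
# train_Y's output shape is constructed automatically.
model = HigherOrderGP(train_X, train_Y)
assert isinstance(model.outcome_transform, FlattenedStandardize)

# Passing outcome_transform=None opts out and recovers the old behavior.
model_raw = HigherOrderGP(train_X, train_Y, outcome_transform=None)
assert not hasattr(model_raw, "outcome_transform")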

botorch/sampling/normal.py

Lines changed: 11 additions & 7 deletions
@@ -66,7 +66,7 @@ def _construct_base_samples(self, posterior: Posterior) -> None:
         pass  # pragma: no cover

     def _update_base_samples(
-        self, posterior: Posterior, base_sampler: NormalMCSampler
+        self, posterior: Posterior, base_sampler: MCSampler
     ) -> None:
         r"""Update the sampler to use the original base samples for X_baseline.

@@ -102,7 +102,15 @@ def _update_base_samples(
             expanded_samples = current_base_samples.view(view_shape).expand(
                 expanded_shape
             )
-            if isinstance(posterior, (HigherOrderGPPosterior, MultitaskGPPosterior)):
+            non_transformed_posterior = (
+                posterior._posterior
+                if isinstance(posterior, TransformedPosterior)
+                else posterior
+            )
+            if isinstance(
+                non_transformed_posterior,
+                (HigherOrderGPPosterior, MultitaskGPPosterior),
+            ):
                 n_train_samples = current_base_samples.shape[-1] // 2
                 # The train base samples.
                 self.base_samples[..., :n_train_samples] = expanded_samples[
@@ -113,11 +121,7 @@ def _update_base_samples(
                     ..., -n_train_samples:
                 ]
             else:
-                batch_shape = (
-                    posterior._posterior.batch_shape
-                    if isinstance(posterior, TransformedPosterior)
-                    else posterior.batch_shape
-                )
+                batch_shape = non_transformed_posterior.batch_shape
                 single_output = (
                     len(posterior.base_sample_shape) - len(batch_shape)
                 ) == 1
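
The refactor above hoists the TransformedPosterior unwrapping so both branches can use it: now that the default outcome transform wraps HigherOrderGP posteriors in a TransformedPosterior, the isinstance check has to look one level down. Below is a standalone sketch of the pattern; the helper name is hypothetical, but its body mirrors the non_transformed_posterior logic in the diff:

from botorch.posteriors.posterior import Posterior
from botorch.posteriors.transformed import TransformedPosterior

def unwrap_posterior(posterior: Posterior) -> Posterior:
    """Return the wrapped posterior when an outcome transform is in play.

    Outcome transforms (e.g. the new FlattenedStandardize default) wrap the
    model's posterior in a TransformedPosterior, so type checks against
    HigherOrderGPPosterior or MultitaskGPPosterior must inspect the inner
    posterior rather than the wrapper.
    """
    return (
        posterior._posterior
        if isinstance(posterior, TransformedPosterior)
        else posterior
    )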

test/acquisition/multi_objective/test_monte_carlo.py

Lines changed: 1 addition & 1 deletion
@@ -1857,7 +1857,7 @@ def _test_with_multitask(self, acqf_class: type[AcquisitionFunction]):
         def get_acqf(model):
             return acqf_class(
                 model=model,
-                ref_point=torch.tensor([0.0, 0.0], **tkwargs),
+                ref_point=torch.tensor([-1.0, -1.0], **tkwargs),
                 X_baseline=train_x,
                 sampler=IIDNormalSampler(sample_shape=torch.Size([2])),
                 objective=hogp_obj if isinstance(model, HigherOrderGP) else None,

test/acquisition/test_analytic.py

Lines changed: 1 addition & 1 deletion
@@ -342,7 +342,7 @@ def test_posterior_stddev_batch(self):
             acqf = PosteriorStandardDeviation(model=mm)
             X = torch.empty(3, 1, 1, device=self.device, dtype=dtype)
             pm = acqf(X)
-            self.assertTrue(torch.equal(pm, std.view(-1)))
+            self.assertAllClose(pm, std.view(-1))
             # check for proper error if multi-output model
             mean2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
             std2 = torch.rand_like(mean2)
348348
std2 = torch.rand_like(mean2)

test/models/test_higher_order_gp.py

Lines changed: 2 additions & 1 deletion
@@ -39,7 +39,7 @@ def setUp(self):
         train_x = torch.rand(2, 10, 1, device=self.device)
         train_y = torch.randn(2, 10, 3, 5, device=self.device)

-        self.model = HigherOrderGP(train_x, train_y)
+        self.model = HigherOrderGP(train_x, train_y, outcome_transform=None)

         # check that we can assign different kernels and likelihoods
         model_2 = HigherOrderGP(
@@ -48,6 +48,7 @@ def setUp(self):
             covar_modules=[RBFKernel(), RBFKernel(), RBFKernel()],
             likelihood=GaussianLikelihood(),
         )
+        self.assertIsInstance(model_2.outcome_transform, FlattenedStandardize)

         model_3 = HigherOrderGP(
             train_X=train_x,

test/posteriors/test_higher_order.py

Lines changed: 11 additions & 6 deletions
@@ -9,6 +9,7 @@
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.higher_order_gp import HigherOrderGP
 from botorch.posteriors.higher_order import HigherOrderGPPosterior
+from botorch.posteriors.transformed import TransformedPosterior
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import BotorchTestCase

@@ -22,7 +23,7 @@ def setUp(self):
         train_y = torch.randn(2, 10, 3, 5, device=self.device)

         m1 = HigherOrderGP(train_x, train_y)
-        m2 = HigherOrderGP(train_x[0], train_y[0])
+        m2 = HigherOrderGP(train_x[0], train_y[0], outcome_transform=None)

         torch.random.manual_seed(0)
         test_x = torch.rand(2, 5, 1, device=self.device)
@@ -32,18 +33,18 @@ def setUp(self):
         posterior3 = m2.posterior(test_x)

         self.post_list = [
-            [m1, test_x, posterior1],
-            [m2, test_x[0], posterior2],
-            [m2, test_x, posterior3],
+            [m1, test_x, posterior1, TransformedPosterior],
+            [m2, test_x[0], posterior2, HigherOrderGPPosterior],
+            [m2, test_x, posterior3, HigherOrderGPPosterior],
         ]

     def test_HigherOrderGPPosterior(self):
         sample_shaping = torch.Size([5, 3, 5])

         for post_collection in self.post_list:
-            model, test_x, posterior = post_collection
+            model, test_x, posterior, posterior_class = post_collection

-            self.assertIsInstance(posterior, HigherOrderGPPosterior)
+            self.assertIsInstance(posterior, posterior_class)

             batch_shape = test_x.shape[:-2]
             expected_extended_shape = batch_shape + sample_shaping
@@ -105,6 +106,10 @@ def test_HigherOrderGPPosterior(self):

             model.eval()
             eval_mode_variance = model(test_x).variance.reshape_as(posterior_variance)
+            if hasattr(model, "outcome_transform"):
+                eval_mode_variance = model.outcome_transform.untransform(
+                    eval_mode_variance, eval_mode_variance
+                )[1]
             self.assertLess(
                 (posterior_variance - eval_mode_variance).norm()
                 / eval_mode_variance.norm(),
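
A note on the last hunk: with the new default transform, model.posterior(test_x) returns variances on the original outcome scale, while calling model(test_x) directly yields the standardized-scale distribution. Since OutcomeTransform.untransform takes (Y, Yvar) and returns the untransformed pair, passing the variance in both slots and keeping index [1] appears to be a convenient way to map the raw variance back to the original scale before comparing.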
