
Commit 6137f18

esantorella authored and facebook-github-bot committed
Reduce repetition in test_model_list_gp_regression (#1520)
Summary: Pull Request resolved: #1520

`TestModelListGP.test_ModelListGP` and `TestModelListGP.test_ModelListGP_fixed_noise` share almost all of their code, so most of that code now lives in a single helper method that both tests call; only the parts that depend on whether the noise is fixed remain in `TestModelListGP.test_ModelListGP` and `TestModelListGP.test_ModelListGP_fixed_noise`. This also adds checks to `test_ModelListGP_fixed_noise` that it was missing but `test_ModelListGP` already had.

Reviewed By: Balandat

Differential Revision: D41508834

fbshipit-source-id: 018ab84f9017e5e86f876368e772fceb40f84f12
1 parent 706e2a1 commit 6137f18
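
The refactor follows a common pattern: move the shared test body into one private helper that takes the varying behavior (here, whether the noise is fixed) as a parameter, and keep each public test as a thin wrapper around it. A minimal, self-contained sketch of that pattern, with hypothetical names rather than the code from this commit:

import unittest


class ExampleTest(unittest.TestCase):
    def _base_test(self, fixed_noise: bool) -> dict:
        # Shared checks run for both variants; only the noise handling differs.
        cond_kwargs = {"noise": 0.1} if fixed_noise else {}
        self.assertIsInstance(cond_kwargs, dict)
        return cond_kwargs

    def test_inferred_noise(self) -> None:
        # Thin wrapper: delegate to the helper, then assert the variant-specific part.
        self.assertEqual(self._base_test(fixed_noise=False), {})

    def test_fixed_noise(self) -> None:
        self.assertEqual(self._base_test(fixed_noise=True), {"noise": 0.1})


if __name__ == "__main__":
    unittest.main()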

File tree

2 files changed: +39 -77 lines changed

botorch/models/model_list_gp_regression.py
test/models/test_model_list_gp_regression.py


botorch/models/model_list_gp_regression.py

Lines changed: 2 additions & 0 deletions

@@ -52,6 +52,8 @@ def __init__(self, *gp_models: GPyTorchModel) -> None:
         """
         super().__init__(*gp_models)
 
+    # pyre-fixme[14]: Inconsistent override. Here `X` is a List[Tensor], but in the
+    # parent method it's a Tensor.
     def condition_on_observations(
         self, X: List[Tensor], Y: Tensor, **kwargs: Any
     ) -> ModelListGP:
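
For context on the signature above: `ModelListGP.condition_on_observations` takes `X` as a list with one input tensor per sub-model and a single `Y` whose last dimension stacks the sub-models' outputs, which is the shape pattern the test file below exercises. A rough usage sketch; the model construction, shapes, and the preliminary posterior call are illustrative choices that mirror the test flow, not something this commit prescribes:

# Sketch: conditioning a ModelListGP on new observations. Two single-output
# sub-models, two new points each; shapes mirror f_x/f_y in the test below.
import torch
from botorch.models import ModelListGP, SingleTaskGP

train_X = torch.rand(10, 1, dtype=torch.double)
model = ModelListGP(
    SingleTaskGP(train_X, torch.sin(train_X)),
    SingleTaskGP(train_X, torch.cos(train_X)),
)

# Evaluate a posterior first, as the test does, so the sub-models' prediction
# caches exist before conditioning on new data.
model.posterior(torch.rand(2, 1, dtype=torch.double))

# X is a List[Tensor] (one entry per sub-model); Y stacks outputs on the last dim.
f_x = [torch.rand(2, 1, dtype=torch.double) for _ in range(2)]
f_y = torch.rand(2, 2, dtype=torch.double)
cm = model.condition_on_observations(f_x, f_y)
print(type(cm).__name__)  # ModelListGP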

test/models/test_model_list_gp_regression.py

Lines changed: 37 additions & 77 deletions

@@ -72,12 +72,15 @@ def _get_model(fixed_noise=False, use_octf=False, use_intf=False, **tkwargs):
 
 
 class TestModelListGP(BotorchTestCase):
-    def test_ModelListGP(self):
-        for dtype, use_octf in itertools.product(
-            (torch.float, torch.double), (False, True)
-        ):
+    def _base_test_ModelListGP(
+        self, fixed_noise: bool, dtype, use_octf: bool
+    ) -> ModelListGP:
+        # this is to make review easier -- will be removed in the next
+        # commit in the stack and never landed
+        unneccessary_condition_for_indentation_remove_me = True
+        if unneccessary_condition_for_indentation_remove_me:
             tkwargs = {"device": self.device, "dtype": dtype}
-            model = _get_model(use_octf=use_octf, **tkwargs)
+            model = _get_model(fixed_noise=fixed_noise, use_octf=use_octf, **tkwargs)
             self.assertIsInstance(model, ModelListGP)
             self.assertIsInstance(model.likelihood, LikelihoodList)
             for m in model.models:
@@ -135,11 +138,6 @@ def test_ModelListGP(self):
                 expected_var = tmp_tf.untransform_posterior(p0_tf).variance
                 self.assertTrue(torch.allclose(p0.variance, expected_var))
 
-            # test observation_noise
-            posterior = model.posterior(test_x, observation_noise=True)
-            self.assertIsInstance(posterior, GPyTorchPosterior)
-            self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
-
             # test output_indices
             posterior = model.posterior(
                 test_x, output_indices=[0], observation_noise=True
@@ -150,28 +148,35 @@
             # test condition_on_observations
             f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
             f_y = torch.rand(2, 2, **tkwargs)
-            cm = model.condition_on_observations(f_x, f_y)
+            if fixed_noise:
+                noise = 0.1 + 0.1 * torch.rand_like(f_y)
+                cond_kwargs = {"noise": noise}
+            else:
+                cond_kwargs = {}
+            cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
             self.assertIsInstance(cm, ModelListGP)
 
             # test condition_on_observations batched
             f_x = [torch.rand(3, 2, 1, **tkwargs) for _ in range(2)]
             f_y = torch.rand(3, 2, 2, **tkwargs)
-            cm = model.condition_on_observations(f_x, f_y)
+            cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
             self.assertIsInstance(cm, ModelListGP)
 
             # test condition_on_observations batched (fast fantasies)
             f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
             f_y = torch.rand(3, 2, 2, **tkwargs)
-            cm = model.condition_on_observations(f_x, f_y)
+            cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
             self.assertIsInstance(cm, ModelListGP)
 
             # test condition_on_observations (incorrect input shape error)
             with self.assertRaises(BotorchTensorDimensionError):
-                model.condition_on_observations(f_x, torch.rand(3, 2, 3, **tkwargs))
+                model.condition_on_observations(
+                    f_x, torch.rand(3, 2, 3, **tkwargs), **cond_kwargs
+                )
 
             # test X having wrong size
             with self.assertRaises(AssertionError):
-                cm = model.condition_on_observations(f_x[:1], f_y)
+                model.condition_on_observations(f_x[:1], f_y)
 
             # test posterior transform
             X = torch.rand(3, 1, **tkwargs)
@@ -185,82 +190,37 @@ def test_ModelListGP(self):
                 )
             )
 
-    def test_ModelListGP_fixed_noise(self):
+            return model
+
+    def test_ModelListGP(self) -> None:
         for dtype, use_octf in itertools.product(
            (torch.float, torch.double), (False, True)
        ):
-            tkwargs = {"device": self.device, "dtype": dtype}
-            model = _get_model(fixed_noise=True, use_octf=use_octf, **tkwargs)
-            self.assertIsInstance(model, ModelListGP)
-            self.assertIsInstance(model.likelihood, LikelihoodList)
-            for m in model.models:
-                self.assertIsInstance(m.mean_module, ConstantMean)
-                self.assertIsInstance(m.covar_module, ScaleKernel)
-                matern_kernel = m.covar_module.base_kernel
-                self.assertIsInstance(matern_kernel, MaternKernel)
-                self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
 
-            # test model fitting
-            mll = SumMarginalLogLikelihood(model.likelihood, model)
-            for mll_ in mll.mlls:
-                self.assertIsInstance(mll_, ExactMarginalLogLikelihood)
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=OptimizationWarning)
-                mll = fit_gpytorch_mll(
-                    mll, optimizer_kwargs={"options": {"maxiter": 1}}, max_attempts=1
-                )
+            model = self._base_test_ModelListGP(
+                fixed_noise=False, dtype=dtype, use_octf=use_octf
+            )
+            tkwargs = {"device": self.device, "dtype": dtype}
 
-            # test posterior
+            # test observation_noise
             test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
-            posterior = model.posterior(test_x)
+            posterior = model.posterior(test_x, observation_noise=True)
             self.assertIsInstance(posterior, GPyTorchPosterior)
             self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
-            if use_octf:
-                # ensure un-transformation is applied
-                submodel = model.models[0]
-                p0 = submodel.posterior(test_x)
-                tmp_tf = submodel.outcome_transform
-                del submodel.outcome_transform
-                p0_tf = submodel.posterior(test_x)
-                submodel.outcome_transform = tmp_tf
-                expected_var = tmp_tf.untransform_posterior(p0_tf).variance
-                self.assertTrue(torch.allclose(p0.variance, expected_var))
 
-            # test output_indices
-            posterior = model.posterior(
-                test_x, output_indices=[0], observation_noise=True
-            )
-            self.assertIsInstance(posterior, GPyTorchPosterior)
-            self.assertIsInstance(posterior.distribution, MultivariateNormal)
+    def test_ModelListGP_fixed_noise(self) -> None:
 
-            # test condition_on_observations
+        for dtype, use_octf in itertools.product(
+            (torch.float, torch.double), (False, True)
+        ):
+            model = self._base_test_ModelListGP(
+                fixed_noise=True, dtype=dtype, use_octf=use_octf
+            )
+            tkwargs = {"device": self.device, "dtype": dtype}
             f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
             f_y = torch.rand(2, 2, **tkwargs)
-            noise = 0.1 + 0.1 * torch.rand_like(f_y)
-            cm = model.condition_on_observations(f_x, f_y, noise=noise)
-            self.assertIsInstance(cm, ModelListGP)
-
-            # test condition_on_observations batched
-            f_x = [torch.rand(3, 2, 1, **tkwargs) for _ in range(2)]
-            f_y = torch.rand(3, 2, 2, **tkwargs)
-            noise = 0.1 + 0.1 * torch.rand_like(f_y)
-            cm = model.condition_on_observations(f_x, f_y, noise=noise)
-            self.assertIsInstance(cm, ModelListGP)
-
-            # test condition_on_observations batched (fast fantasies)
-            f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
-            f_y = torch.rand(3, 2, 2, **tkwargs)
-            noise = 0.1 + 0.1 * torch.rand(2, 2, **tkwargs)
-            cm = model.condition_on_observations(f_x, f_y, noise=noise)
-            self.assertIsInstance(cm, ModelListGP)
 
-            # test condition_on_observations (incorrect input shape error)
-            with self.assertRaises(BotorchTensorDimensionError):
-                model.condition_on_observations(
-                    f_x, torch.rand(3, 2, 3, **tkwargs), noise=noise
-                )
             # test condition_on_observations (incorrect noise shape error)
-            f_y = torch.rand(2, 2, **tkwargs)
             with self.assertRaises(BotorchTensorDimensionError):
                 model.condition_on_observations(
                     f_x, f_y, noise=torch.rand(2, 3, **tkwargs)
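
As the helper's `cond_kwargs` shows, the only behavioral difference on the fixed-noise path is that `condition_on_observations` also receives a `noise` tensor shaped like `f_y`. A rough sketch of that variant, assuming `FixedNoiseGP` sub-models (the fixed-observation-noise model class in BoTorch around the time of this commit); shapes and values are illustrative, analogous to the example under the first file:

# Sketch: fixed-noise variant of the conditioning call, passing a
# per-observation `noise` tensor shaped like f_y.
import torch
from botorch.models import FixedNoiseGP, ModelListGP

train_X = torch.rand(10, 1, dtype=torch.double)
train_Yvar = torch.full((10, 1), 0.01, dtype=torch.double)
model = ModelListGP(
    FixedNoiseGP(train_X, torch.sin(train_X), train_Yvar),
    FixedNoiseGP(train_X, torch.cos(train_X), train_Yvar),
)
model.posterior(torch.rand(2, 1, dtype=torch.double))  # warm up prediction caches

f_x = [torch.rand(2, 1, dtype=torch.double) for _ in range(2)]
f_y = torch.rand(2, 2, dtype=torch.double)
noise = 0.1 + 0.1 * torch.rand_like(f_y)  # observation noise for the new points
cm = model.condition_on_observations(f_x, f_y, noise=noise)
print(type(cm).__name__)  # ModelListGP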
