Skip to content

Commit 504ccea

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Adds test_helpers. Do not use relative imports. (#2133)
Summary: Pull Request resolved: #2133 Moves test utilities that were used in multiple test files to `botorch/utils/test_helpers.py`. This makes it possible to remove all relative imports from the test files, which does not play well with some internal tooling we want to use. Going forward, tests, like the rest of BoTorch, should only use absolute imports. Reviewed By: esantorella Differential Revision: D51767702 fbshipit-source-id: 7d7d32a4e06f8217159a65baa96240c49909b5d0
1 parent c14808f commit 504ccea

15 files changed

+209
-174
lines changed

botorch/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
from botorch.utils import manual_seed
3131

3232
try:
33-
from botorch.version import version as __version__
33+
# Marking this as a manual import to avoid autodeps complaints
34+
# due to imports from non-existent file.
35+
from botorch.version import version as __version__ # @manual
3436
except Exception: # pragma: no cover
35-
__version__ = "Unknown" # pragma: no cover
37+
__version__ = "Unknown"
3638

3739
logger.info(
3840
"Turning off `fast_computations` in linear operator and increasing "

botorch/utils/test_helpers.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Dummy classes and other helpers that are used in multiple test files
9+
should be defined here to avoid relative imports.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import math
15+
from typing import Optional, Tuple
16+
17+
import torch
18+
from botorch.acquisition.objective import PosteriorTransform
19+
from botorch.models.gpytorch import GPyTorchModel
20+
from botorch.models.model import FantasizeMixin, Model
21+
from botorch.models.transforms.outcome import Standardize
22+
from botorch.models.utils import add_output_dim
23+
from botorch.models.utils.assorted import fantasize
24+
from botorch.posteriors.posterior import Posterior
25+
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
26+
from gpytorch.distributions.multivariate_normal import MultivariateNormal
27+
from gpytorch.kernels import RBFKernel, ScaleKernel
28+
from gpytorch.likelihoods.gaussian_likelihood import (
29+
FixedNoiseGaussianLikelihood,
30+
GaussianLikelihood,
31+
)
32+
from gpytorch.means import ConstantMean
33+
from gpytorch.models.exact_gp import ExactGP
34+
from torch import Size, Tensor
35+
from torch.nn.functional import pad
36+
37+
38+
def get_sample_moments(samples: Tensor, sample_shape: Size) -> Tuple[Tensor, Tensor]:
39+
"""Computes the mean and covariance of a set of samples.
40+
41+
Args:
42+
samples: A tensor of shape `sample_shape x batch_shape x q`.
43+
sample_shape: The sample_shape input used while generating the samples using
44+
the pathwise sampling API.
45+
"""
46+
sample_dim = len(sample_shape)
47+
samples = samples.view(-1, *samples.shape[sample_dim:])
48+
loc = samples.mean(dim=0)
49+
residuals = (samples - loc).permute(*range(1, samples.ndim), 0)
50+
return loc, (residuals @ residuals.transpose(-2, -1)) / sample_shape.numel()
51+
52+
53+
def standardize_moments(
54+
transform: Standardize,
55+
loc: Tensor,
56+
covariance_matrix: Tensor,
57+
) -> Tuple[Tensor, Tensor]:
58+
"""Standardizes the loc and covariance_matrix using the mean and standard
59+
deviations from a Standardize transform.
60+
"""
61+
m = transform.means.squeeze().unsqueeze(-1)
62+
s = transform.stdvs.squeeze().reciprocal().unsqueeze(-1)
63+
loc = s * (loc - m)
64+
correlation_matrix = s.unsqueeze(-1) * covariance_matrix * s.unsqueeze(-2)
65+
return loc, correlation_matrix
66+
67+
68+
def gen_multi_task_dataset(
69+
yvar: Optional[float] = None, **tkwargs
70+
) -> Tuple[MultiTaskDataset, Tuple[Tensor, Tensor, Tensor]]:
71+
"""Constructs a multi-task dataset with two tasks, each with 10 data points."""
72+
X = torch.linspace(0, 0.95, 10, **tkwargs) + 0.05 * torch.rand(10, **tkwargs)
73+
X = X.unsqueeze(dim=-1)
74+
Y1 = torch.sin(X * (2 * math.pi)) + torch.randn_like(X) * 0.2
75+
Y2 = torch.cos(X * (2 * math.pi)) + torch.randn_like(X) * 0.2
76+
train_X = torch.cat([pad(X, (1, 0), value=i) for i in range(2)])
77+
train_Y = torch.cat([Y1, Y2])
78+
79+
Yvar1 = None if yvar is None else torch.full_like(Y1, yvar)
80+
Yvar2 = None if yvar is None else torch.full_like(Y2, yvar)
81+
train_Yvar = None if yvar is None else torch.cat([Yvar1, Yvar2])
82+
datasets = [
83+
SupervisedDataset(
84+
X=train_X[:10],
85+
Y=Y1,
86+
Yvar=Yvar1,
87+
feature_names=["task", "X"],
88+
outcome_names=["y"],
89+
),
90+
SupervisedDataset(
91+
X=train_X[10:],
92+
Y=Y2,
93+
Yvar=Yvar2,
94+
feature_names=["task", "X"],
95+
outcome_names=["y1"],
96+
),
97+
]
98+
dataset = MultiTaskDataset(
99+
datasets=datasets, target_outcome_name="y", task_feature_index=0
100+
)
101+
return dataset, (train_X, train_Y, train_Yvar)
102+
103+
104+
def get_pvar_expected(posterior: Posterior, model: Model, X: Tensor, m: int) -> Tensor:
105+
"""Computes the expected variance of a posterior after adding the
106+
predictive noise from the likelihood.
107+
"""
108+
X = model.transform_inputs(X)
109+
lh_kwargs = {}
110+
if isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
111+
lh_kwargs["noise"] = model.likelihood.noise.mean().expand(X.shape[:-1])
112+
if m == 1:
113+
return model.likelihood(
114+
posterior.distribution, X, **lh_kwargs
115+
).variance.unsqueeze(-1)
116+
X_, odi = add_output_dim(X=X, original_batch_shape=model._input_batch_shape)
117+
pvar_exp = model.likelihood(model(X_), X_, **lh_kwargs).variance
118+
return torch.stack([pvar_exp.select(dim=odi, index=i) for i in range(m)], dim=-1)
119+
120+
121+
class DummyNonScalarizingPosteriorTransform(PosteriorTransform):
122+
scalarize = False
123+
124+
def evaluate(self, Y):
125+
pass # pragma: no cover
126+
127+
def forward(self, posterior):
128+
pass # pragma: no cover
129+
130+
131+
class SimpleGPyTorchModel(GPyTorchModel, ExactGP, FantasizeMixin):
132+
last_fantasize_flag: bool = False
133+
134+
def __init__(self, train_X, train_Y, outcome_transform=None, input_transform=None):
135+
r"""
136+
Args:
137+
train_X: A tensor of inputs, passed to self.transform_inputs.
138+
train_Y: Passed to outcome_transform.
139+
outcome_transform: Transform applied to train_Y.
140+
input_transform: A Module that performs the input transformation, passed to
141+
self.transform_inputs.
142+
"""
143+
with torch.no_grad():
144+
transformed_X = self.transform_inputs(
145+
X=train_X, input_transform=input_transform
146+
)
147+
if outcome_transform is not None:
148+
train_Y, _ = outcome_transform(train_Y)
149+
self._validate_tensor_args(transformed_X, train_Y)
150+
train_Y = train_Y.squeeze(-1)
151+
likelihood = GaussianLikelihood()
152+
super().__init__(train_X, train_Y, likelihood)
153+
self.mean_module = ConstantMean()
154+
self.covar_module = ScaleKernel(RBFKernel())
155+
if outcome_transform is not None:
156+
self.outcome_transform = outcome_transform
157+
if input_transform is not None:
158+
self.input_transform = input_transform
159+
self._num_outputs = 1
160+
self.to(train_X)
161+
self.transformed_call_args = []
162+
163+
def forward(self, x):
164+
self.last_fantasize_flag = fantasize.on()
165+
if self.training:
166+
x = self.transform_inputs(x)
167+
self.transformed_call_args.append(x)
168+
mean_x = self.mean_module(x)
169+
covar_x = self.covar_module(x)
170+
return MultivariateNormal(mean_x, covar_x)

botorch/utils/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
from botorch.utils.safe_math import logmeanexp
1919
from torch import Tensor
2020

21-
if TYPE_CHECKING:
22-
from botorch.acquisition import AcquisitionFunction # pragma: no cover
23-
from botorch.model import Model # pragma: no cover
21+
if TYPE_CHECKING: # pragma: no cover
22+
from botorch.acquisition import AcquisitionFunction
23+
from botorch.models.model import Model
2424

2525

2626
def standardize(Y: Tensor) -> Tensor:

sphinx/source/utils.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,16 @@ Sampling from GP priors
5757
.. automodule:: botorch.utils.gp_sampling
5858
:members:
5959

60-
6160
Testing
6261
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6362
.. automodule:: botorch.utils.testing
6463
:members:
6564

65+
Test Helpers
66+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
67+
.. automodule:: botorch.utils.test_helpers
68+
:members:
69+
6670
Torch
6771
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6872
.. automodule:: botorch.utils.torch

test/acquisition/test_knowledge_gradient.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@
3030
from botorch.optim.utils import _filter_kwargs
3131
from botorch.posteriors.gpytorch import GPyTorchPosterior
3232
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
33+
from botorch.utils.test_helpers import DummyNonScalarizingPosteriorTransform
3334
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
3435
from gpytorch.distributions import MultitaskMultivariateNormal
3536

36-
from .test_monte_carlo import DummyNonScalarizingPosteriorTransform
37-
3837
NO = "botorch.utils.testing.MockModel.num_outputs"
3938

4039

test/acquisition/test_monte_carlo.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626
ConstrainedMCObjective,
2727
GenericMCObjective,
2828
IdentityMCObjective,
29-
PosteriorTransform,
3029
ScalarizedPosteriorTransform,
3130
)
3231
from botorch.acquisition.utils import prune_inferior_points
3332
from botorch.exceptions import BotorchWarning, UnsupportedError
3433
from botorch.models import SingleTaskGP
3534
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
3635
from botorch.utils.low_rank import sample_cached_cholesky
36+
from botorch.utils.test_helpers import DummyNonScalarizingPosteriorTransform
3737
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
3838
from botorch.utils.transforms import standardize
3939
from torch import Tensor
@@ -49,16 +49,6 @@ def _sample_forward(self, X):
4949
pass
5050

5151

52-
class DummyNonScalarizingPosteriorTransform(PosteriorTransform):
53-
scalarize = False
54-
55-
def evaluate(self, Y):
56-
pass # pragma: no cover
57-
58-
def forward(self, posterior):
59-
pass # pragma: no cover
60-
61-
6252
def infeasible_con(samples: Tensor) -> Tensor:
6353
return torch.ones_like(samples[..., 0])
6454

test/models/test_converter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020
)
2121
from botorch.models.transforms.input import AppendFeatures, Normalize
2222
from botorch.models.transforms.outcome import Standardize
23+
from botorch.utils.test_helpers import SimpleGPyTorchModel
2324
from botorch.utils.testing import BotorchTestCase
2425
from gpytorch.kernels import RBFKernel
2526
from gpytorch.likelihoods import GaussianLikelihood
2627
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
2728

28-
from .test_gpytorch import SimpleGPyTorchModel
29-
3029

3130
class TestConverters(BotorchTestCase):
3231
def test_batched_to_model_list(self):

test/models/test_fully_bayesian_multitask.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,13 @@
4242
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
4343
NondominatedPartitioning,
4444
)
45+
from botorch.utils.test_helpers import gen_multi_task_dataset
4546
from botorch.utils.testing import BotorchTestCase
4647
from gpytorch.kernels import MaternKernel, ScaleKernel
4748
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
4849
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
4950
from gpytorch.means import ConstantMean
5051

51-
from .test_multitask import _gen_multi_task_dataset
52-
5352
EXPECTED_KEYS = [
5453
"latent_features",
5554
"mean_module.raw_constant",
@@ -566,7 +565,7 @@ def test_construct_inputs(self):
566565
for dtype, infer_noise in [(torch.float, False), (torch.double, True)]:
567566
tkwargs = {"device": self.device, "dtype": dtype}
568567
task_feature = 0
569-
datasets, (train_X, train_Y, train_Yvar) = _gen_multi_task_dataset(
568+
datasets, (train_X, train_Y, train_Yvar) = gen_multi_task_dataset(
570569
yvar=None if infer_noise else 0.05, **tkwargs
571570
)
572571

test/models/test_gp_regression.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
)
1818
from botorch.models.transforms import Normalize, Standardize
1919
from botorch.models.transforms.input import InputStandardize
20-
from botorch.models.utils import add_output_dim
2120
from botorch.posteriors import GPyTorchPosterior
2221
from botorch.sampling import SobolQMCNormalSampler
2322
from botorch.utils.datasets import SupervisedDataset
2423
from botorch.utils.sampling import manual_seed
24+
from botorch.utils.test_helpers import get_pvar_expected
2525
from botorch.utils.testing import _get_random_data, BotorchTestCase
2626
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
2727
from gpytorch.likelihoods import (
@@ -142,7 +142,7 @@ def test_gp(self, double_only: bool = False):
142142
self.assertAllClose(posterior_pred.variance, expected_var)
143143
else:
144144
pvar = posterior_pred.variance
145-
pvar_exp = _get_pvar_expected(posterior, model, X, m)
145+
pvar_exp = get_pvar_expected(posterior, model, X, m)
146146
self.assertAllClose(pvar, pvar_exp, rtol=1e-4, atol=1e-5)
147147

148148
# Tensor valued observation noise.
@@ -176,7 +176,7 @@ def test_gp(self, double_only: bool = False):
176176
self.assertAllClose(posterior_pred.variance, expected_var)
177177
else:
178178
pvar = posterior_pred.variance
179-
pvar_exp = _get_pvar_expected(posterior, model, X, m)
179+
pvar_exp = get_pvar_expected(posterior, model, X, m)
180180
self.assertAllClose(pvar, pvar_exp, rtol=1e-4, atol=1e-5)
181181

182182
def test_custom_init(self):
@@ -599,17 +599,3 @@ def test_condition_on_observations(self):
599599
def test_subset_model(self):
600600
with self.assertRaises(NotImplementedError):
601601
super().test_subset_model()
602-
603-
604-
def _get_pvar_expected(posterior, model, X, m):
605-
X = model.transform_inputs(X)
606-
lh_kwargs = {}
607-
if isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
608-
lh_kwargs["noise"] = model.likelihood.noise.mean().expand(X.shape[:-1])
609-
if m == 1:
610-
return model.likelihood(
611-
posterior.distribution, X, **lh_kwargs
612-
).variance.unsqueeze(-1)
613-
X_, odi = add_output_dim(X=X, original_batch_shape=model._input_batch_shape)
614-
pvar_exp = model.likelihood(model(X_), X_, **lh_kwargs).variance
615-
return torch.stack([pvar_exp.select(dim=odi, index=i) for i in range(m)], dim=-1)

test/models/test_gp_regression_mixed.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from botorch.posteriors import GPyTorchPosterior
1818
from botorch.sampling import SobolQMCNormalSampler
1919
from botorch.utils.datasets import SupervisedDataset
20+
from botorch.utils.test_helpers import get_pvar_expected
2021
from botorch.utils.testing import _get_random_data, BotorchTestCase
2122
from gpytorch.kernels.kernel import AdditiveKernel, ProductKernel
2223
from gpytorch.kernels.matern_kernel import MaternKernel
@@ -26,8 +27,6 @@
2627
from gpytorch.means import ConstantMean
2728
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
2829

29-
from .test_gp_regression import _get_pvar_expected
30-
3130

3231
class TestMixedSingleTaskGP(BotorchTestCase):
3332
observed_noise = False
@@ -119,7 +118,7 @@ def test_gp(self):
119118
self.assertEqual(posterior_pred.mean.shape, expected_shape)
120119
self.assertEqual(posterior_pred.variance.shape, expected_shape)
121120
pvar = posterior_pred.variance
122-
pvar_exp = _get_pvar_expected(posterior, model, X, m)
121+
pvar_exp = get_pvar_expected(posterior, model, X, m)
123122
self.assertAllClose(pvar, pvar_exp, rtol=1e-4, atol=1e-5)
124123

125124
# test batch evaluation
@@ -133,7 +132,7 @@ def test_gp(self):
133132
self.assertIsInstance(posterior_pred, GPyTorchPosterior)
134133
self.assertEqual(posterior_pred.mean.shape, expected_shape)
135134
pvar = posterior_pred.variance
136-
pvar_exp = _get_pvar_expected(posterior, model, X, m)
135+
pvar_exp = get_pvar_expected(posterior, model, X, m)
137136
self.assertAllClose(pvar, pvar_exp, rtol=1e-4, atol=1e-5)
138137

139138
# test that model converter throws an exception

0 commit comments

Comments
 (0)