
Commit 84e716f

Balandat authored and facebook-github-bot committed
Add num_outputs property to the Model API (#330)
Summary: Pull Request resolved: #330. Addresses #295.
Reviewed By: liangshi7
Differential Revision: D18737863
fbshipit-source-id: 3df4d6b2b2b65609ac508e990d735838df33357b
1 parent b26062f commit 84e716f
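For orientation, here is a minimal sketch (not part of the commit) of the pattern this diff introduces: concrete models set `self._num_outputs` in their constructors, and the inherited `num_outputs` property simply returns it. `ConstantModel` below is a hypothetical example class, not a BoTorch class.

import torch
from torch import Tensor

from botorch.models.deterministic import DeterministicModel


class ConstantModel(DeterministicModel):
    """Hypothetical deterministic model returning a fixed constant per outcome."""

    def __init__(self, c: Tensor) -> None:
        super().__init__()
        self.register_buffer("c", c)
        self._num_outputs = c.shape[-1]  # number of outcomes m, as in this commit

    def forward(self, X: Tensor) -> Tensor:
        # broadcast the constant to a `batch_shape x n x m` output
        return self.c.expand(*X.shape[:-1], -1)


model = ConstantModel(torch.tensor([0.5, 1.0]))
model.num_outputs  # -> 2, via the property added to DeterministicModel in this diff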

File tree

10 files changed: +47 −28 lines


botorch/models/cost.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def __init__(
         self.fixed_cost = fixed_cost
         weights = torch.tensor([fidelity_weights[i] for i in self.fidelity_dims])
         self.register_buffer("weights", weights)
+        self._num_outputs = 1

     def forward(self, X: Tensor) -> Tensor:
         r"""Evaluate the cost on a candidate set X.

botorch/models/deterministic.py

Lines changed: 9 additions & 1 deletion
@@ -37,6 +37,11 @@ def forward(self, X: Tensor) -> Tensor:
         """
         pass  # pragma: no cover

+    @property
+    def num_outputs(self) -> int:
+        r"""The number of outputs of the model."""
+        return self._num_outputs
+
     def posterior(
         self, X: Tensor, output_indices: Optional[List[int]] = None, **kwargs: Any
     ) -> DeterministicPosterior:
@@ -55,16 +60,18 @@ def posterior(
 class GenericDeterministicModel(DeterministicModel):
     r"""A generic deterministic model constructed from a callable."""

-    def __init__(self, f: Callable[[Tensor], Tensor]) -> None:
+    def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
         r"""A generic deterministic model constructed from a callable.

         Args:
             f: A callable mapping a `batch_shape x n x d`-dim input tensor `X`
                 to a `batch_shape x n x m`-dimensional output tensor (the
                 outcome dimension `m` must be explicit, even if `m=1`).
+            num_outputs: The number of outputs `m`.
         """
         super().__init__()
         self._f = f
+        self._num_outputs = num_outputs

     def forward(self, X: Tensor) -> Tensor:
         r"""Compute the (deterministic) model output at X.
@@ -104,6 +111,7 @@ def __init__(self, a: Tensor, b: Union[Tensor, float] = 0.01) -> None:
         super().__init__()
         self.register_buffer("a", a)
         self.register_buffer("b", b.expand(a.size(-1)))
+        self._num_outputs = a.size(-1)

     def forward(self, X: Tensor) -> Tensor:
         return self.b + torch.einsum("...d,dm", X, self.a)
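A short sketch, grounded in test_deterministic.py below, of the two constructor paths this hunk touches: `GenericDeterministicModel` now takes an explicit `num_outputs` argument, while `AffineDeterministicModel` infers it from the trailing dimension of `a`.

import torch
from botorch.models.deterministic import (
    AffineDeterministicModel,
    GenericDeterministicModel,
)

# explicit: the callable maps `batch_shape x n x d` to `batch_shape x n x 2`
generic = GenericDeterministicModel(lambda X: X[..., :2] ** 2, num_outputs=2)
generic.num_outputs  # -> 2

# inferred: a `3 x 2` weight matrix implies m = a.size(-1) = 2 outputs
affine = AffineDeterministicModel(torch.rand(3, 2))
affine.num_outputs  # -> 2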

botorch/models/gpytorch.py

Lines changed: 6 additions & 6 deletions
@@ -12,7 +12,7 @@
 """

 import warnings
-from abc import ABC, abstractproperty
+from abc import ABC
 from typing import Any, Iterator, List, Optional, Tuple, Union

 import torch
@@ -89,6 +89,11 @@ def _validate_tensor_args(
                 f" {Yvar.shape})."
             )

+    @property
+    def num_outputs(self) -> int:
+        r"""The number of outputs of the model."""
+        return self._num_outputs
+
     def posterior(
         self, X: Tensor, observation_noise: Union[bool, Tensor] = False, **kwargs: Any
     ) -> GPyTorchPosterior:
@@ -361,11 +366,6 @@ class ModelListGPyTorchModel(GPyTorchModel, ABC):
     evaluation of submodels.
     """

-    @abstractproperty
-    def num_outputs(self) -> int:
-        r"""The number of outputs of the model."""
-        pass  # pragma: no cover
-
     def posterior(
         self,
         X: Tensor,
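With the property now defined once on `GPyTorchModel` (and the `abstractproperty` on `ModelListGPyTorchModel` removed), a custom GP only needs to set `self._num_outputs`. The sketch below is modeled on `SimpleGPyTorchModel` from test_gpytorch.py further down; `MinimalGPModel` is a hypothetical name.

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

from botorch.models.gpytorch import GPyTorchModel


class MinimalGPModel(GPyTorchModel, ExactGP):
    def __init__(self, train_X, train_Y):
        likelihood = GaussianLikelihood()
        super().__init__(train_X, train_Y.squeeze(-1), likelihood)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())
        self._num_outputs = 1  # picked up by GPyTorchModel.num_outputs above
        self.to(train_X)

    def forward(self, x):
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))


train_X = torch.rand(5, 2)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
MinimalGPModel(train_X, train_Y).num_outputs  # -> 1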

botorch/models/model.py

Lines changed: 6 additions & 0 deletions
@@ -49,6 +49,12 @@ def posterior(
         """
         pass  # pragma: no cover

+    @property
+    def num_outputs(self) -> int:
+        r"""The number of outputs of the model."""
+        cls_name = self.__class__.__name__
+        raise NotImplementedError(f"{cls_name} does not define num_outputs property")
+
     def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> "Model":
         r"""Condition the model on new observations.
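As test_model.py below exercises, the base class now fails loudly rather than silently when a subclass never sets `_num_outputs`. A hedged sketch, with `NoOutputsModel` a hypothetical subclass:

from typing import Any, List, Optional

from torch import Tensor

from botorch.models.model import Model


class NoOutputsModel(Model):
    def posterior(self, X: Tensor, output_indices: Optional[List[int]] = None, **kwargs: Any):
        raise NotImplementedError  # irrelevant for this example


model = NoOutputsModel()
model.num_outputs  # raises NotImplementedError: "NoOutputsModel does not define num_outputs property"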

botorch/models/multitask.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ def __init__(
         if any(t not in all_tasks for t in output_tasks):
             raise RuntimeError("All output tasks must be present in input data.")
         self._output_tasks = output_tasks
+        self._num_outputs = len(output_tasks)

         # TODO (T41270962): Support task-specific noise levels in likelihood
         likelihood = GaussianLikelihood(noise_prior=GammaPrior(1.1, 0.05))

test/models/test_cost.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ def test_affine_fidelity_cost_model(self):
             X = torch.rand(*batch_shape, 3, 4, device=self.device, dtype=dtype)
             # test default parameters
             model = AffineFidelityCostModel()
+            self.assertEqual(model.num_outputs, 1)
             self.assertEqual(model.fidelity_dims, [-1])
             self.assertEqual(model.fixed_cost, 0.01)
             cost = model(X)

test/models/test_deterministic.py

Lines changed: 5 additions & 1 deletion
@@ -25,6 +25,7 @@ def f(X):
             return X.mean(dim=-1, keepdim=True)

         model = GenericDeterministicModel(f)
+        self.assertEqual(model.num_outputs, 1)
         X = torch.rand(3, 2)
         # basic test
         p = model.posterior(X)
@@ -34,7 +35,8 @@ def f(X):
         with self.assertRaises(UnsupportedError):
             model.posterior(X, observation_noise=True)
         # check output indices
-        model = GenericDeterministicModel(lambda X: X)
+        model = GenericDeterministicModel(lambda X: X, num_outputs=2)
+        self.assertEqual(model.num_outputs, 2)
         p = model.posterior(X, output_indices=[0])
         self.assertTrue(torch.equal(p.mean, X[..., [0]]))

@@ -48,6 +50,7 @@ def test_AffineDeterministicModel(self):
         # test one-dim output
         a = torch.rand(3, 1)
         model = AffineDeterministicModel(a)
+        self.assertEqual(model.num_outputs, 1)
         for shape in ((4, 3), (1, 4, 3)):
             X = torch.rand(*shape)
             p = model.posterior(X)
@@ -56,6 +59,7 @@ def test_AffineDeterministicModel(self):
         # # test two-dim output
         a = torch.rand(3, 2)
         model = AffineDeterministicModel(a)
+        self.assertEqual(model.num_outputs, 2)
         for shape in ((4, 3), (1, 4, 3)):
             X = torch.rand(*shape)
             p = model.posterior(X)

test/models/test_gpytorch.py

Lines changed: 4 additions & 0 deletions
@@ -40,6 +40,7 @@ def __init__(self, train_X, train_Y, outcome_transform=None):
         self.covar_module = ScaleKernel(RBFKernel())
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
+        self._num_outputs = 1
         self.to(train_X)

     def forward(self, x):
@@ -84,6 +85,7 @@ def test_gpytorch_model(self):
         train_Y = torch.sin(train_X)
         # basic test
         model = SimpleGPyTorchModel(train_X, train_Y, octf)
+        self.assertEqual(model.num_outputs, 1)
         test_X = torch.rand(2, 1, **tkwargs)
         posterior = model.posterior(test_X)
         self.assertIsInstance(posterior, GPyTorchPosterior)
@@ -175,6 +177,7 @@ def test_batched_multi_output_gpytorch_model(self):
         train_Y = torch.cat([torch.sin(train_X), torch.cos(train_X)], dim=-1)
         # basic test
         model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y)
+        self.assertEqual(model.num_outputs, 2)
         test_X = torch.rand(2, 1, **tkwargs)
         posterior = model.posterior(test_X)
         self.assertIsInstance(posterior, GPyTorchPosterior)
@@ -226,6 +229,7 @@ def test_model_list_gpytorch_model(self):
         m1 = SimpleGPyTorchModel(train_X1, train_Y1)
         m2 = SimpleGPyTorchModel(train_X2, train_Y2)
         model = SimpleModelListGPyTorchModel(m1, m2)
+        self.assertEqual(model.num_outputs, 2)
         test_X = torch.rand(2, 1, **tkwargs)
         posterior = model.posterior(test_X)
         self.assertIsInstance(posterior, GPyTorchPosterior)

test/models/test_model.py

Lines changed: 2 additions & 0 deletions
@@ -22,3 +22,5 @@ def test_not_so_abstract_base_model(self):
         model = NotSoAbstractBaseModel()
         with self.assertRaises(NotImplementedError):
             model.condition_on_observations(None, None)
+        with self.assertRaises(NotImplementedError):
+            model.num_outputs

test/models/test_multitask.py

Lines changed: 12 additions & 20 deletions
@@ -65,13 +65,11 @@ def _get_fixed_noise_model_single_output(**tkwargs):

 class TestMultiTaskGP(BotorchTestCase):
     def test_MultiTaskGP(self):
-        for double in (False, True):
-            tkwargs = {
-                "device": self.device,
-                "dtype": torch.double if double else torch.float,
-            }
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
             model = _get_model(**tkwargs)
             self.assertIsInstance(model, MultiTaskGP)
+            self.assertEqual(model.num_outputs, 2)
             self.assertIsInstance(model.likelihood, GaussianLikelihood)
             self.assertIsInstance(model.mean_module, ConstantMean)
             self.assertIsInstance(model.covar_module, ScaleKernel)
@@ -140,13 +138,11 @@ def test_MultiTaskGP(self):
                 model.posterior(test_x)

     def test_MultiTaskGP_single_output(self):
-        for double in (False, True):
-            tkwargs = {
-                "device": self.device,
-                "dtype": torch.double if double else torch.float,
-            }
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
             model = _get_model_single_output(**tkwargs)
             self.assertIsInstance(model, MultiTaskGP)
+            self.assertEqual(model.num_outputs, 1)
             self.assertIsInstance(model.likelihood, GaussianLikelihood)
             self.assertIsInstance(model.mean_module, ConstantMean)
             self.assertIsInstance(model.covar_module, ScaleKernel)
@@ -180,13 +176,11 @@ def test_MultiTaskGP_single_output(self):

 class TestFixedNoiseMultiTaskGP(BotorchTestCase):
     def test_FixedNoiseMultiTaskGP(self):
-        for double in (False, True):
-            tkwargs = {
-                "device": self.device,
-                "dtype": torch.double if double else torch.float,
-            }
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
             model = _get_fixed_noise_model(**tkwargs)
             self.assertIsInstance(model, FixedNoiseMultiTaskGP)
+            self.assertEqual(model.num_outputs, 2)
             self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
             self.assertIsInstance(model.mean_module, ConstantMean)
             self.assertIsInstance(model.covar_module, ScaleKernel)
@@ -253,13 +247,11 @@ def test_FixedNoiseMultiTaskGP(self):
             FixedNoiseMultiTaskGP(train_X, train_Y, train_Yvar, 0, output_tasks=[2])

     def test_FixedNoiseMultiTaskGP_single_output(self):
-        for double in (False, True):
-            tkwargs = {
-                "device": self.device,
-                "dtype": torch.double if double else torch.float,
-            }
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
             model = _get_fixed_noise_model_single_output(**tkwargs)
             self.assertIsInstance(model, FixedNoiseMultiTaskGP)
+            self.assertEqual(model.num_outputs, 1)
             self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
             self.assertIsInstance(model.mean_module, ConstantMean)
             self.assertIsInstance(model.covar_module, ScaleKernel)
