Skip to content

Commit 519b18b

Browse files
Balandatfacebook-github-bot
authored andcommitted
Add batch_shape property to models (#588)
Summary: A consistent API like this will be useful and avoid ad-hoc inferring of batch shapes. See #587 for more context. Pull Request resolved: #588 Reviewed By: qingfeng10 Differential Revision: D24622409 Pulled By: Balandat fbshipit-source-id: 627c102cdf3f98637c30c96ef1766f907a951797
1 parent 6fc80e5 commit 519b18b

File tree

7 files changed

+78
-8
lines changed

7 files changed

+78
-8
lines changed

botorch/models/gpytorch.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,18 @@ def _validate_tensor_args(
9797
f" {Yvar.shape})."
9898
)
9999

100+
@property
101+
def batch_shape(self) -> torch.Size:
102+
r"""The batch shape of the model.
103+
104+
This is a batch shape from an I/O perspective, independent of the internal
105+
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
106+
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
107+
to the `posterior` method returns a Posterior object over an output of
108+
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
109+
"""
110+
return self.train_inputs[0].shape[:-2]
111+
100112
@property
101113
def num_outputs(self) -> int:
102114
r"""The number of outputs of the model."""
@@ -234,6 +246,18 @@ def _set_dimensions(self, train_X: Tensor, train_Y: Tensor) -> None:
234246
train_X=train_X, train_Y=train_Y
235247
)
236248

249+
@property
250+
def batch_shape(self) -> torch.Size:
251+
r"""The batch shape of the model.
252+
253+
This is a batch shape from an I/O perspective, independent of the internal
254+
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
255+
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
256+
to the `posterior` method returns a Posterior object over an output of
257+
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
258+
"""
259+
return self._input_batch_shape
260+
237261
def _transform_tensor_args(
238262
self, X: Tensor, Y: Tensor, Yvar: Optional[Tensor] = None
239263
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
@@ -442,6 +466,19 @@ class ModelListGPyTorchModel(GPyTorchModel, ABC):
442466
evaluation of submodels.
443467
"""
444468

469+
@property
470+
def batch_shape(self) -> torch.Size:
471+
r"""The batch shape of the model.
472+
473+
This is a batch shape from an I/O perspective, independent of the internal
474+
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
475+
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
476+
to the `posterior` method returns a Posterior object over an output of
477+
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
478+
"""
479+
# TODO: Either check that batch shapes match across models, or broadcast them
480+
raise NotImplementedError
481+
445482
def posterior(
446483
self,
447484
X: Tensor,

botorch/models/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from abc import ABC, abstractmethod
1414
from typing import Any, Dict, List, Optional
1515

16+
import torch
1617
from botorch import settings
1718
from botorch.posteriors import Posterior
1819
from botorch.sampling.samplers import MCSampler
@@ -51,6 +52,19 @@ def posterior(
5152
"""
5253
pass # pragma: no cover
5354

55+
@property
56+
def batch_shape(self) -> torch.Size:
57+
r"""The batch shape of the model.
58+
59+
This is a batch shape from an I/O perspective, independent of the internal
60+
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
61+
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
62+
to the `posterior` method returns a Posterior object over an output of
63+
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
64+
"""
65+
cls_name = self.__class__.__name__
66+
raise NotImplementedError(f"{cls_name} does not define batch_shape property")
67+
5468
@property
5569
def num_outputs(self) -> int:
5670
r"""The number of outputs of the model."""

botorch/models/pairwise_gp.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,6 @@ def __deepcopy__(self, memo) -> PairwiseGP:
190190
self.__deepcopy__ = dcp
191191
return new_model
192192

193-
@property
194-
def num_outputs(self) -> int:
195-
r"""The number of outputs of the model."""
196-
return self._num_outputs
197-
198193
def _has_no_data(self):
199194
r"""Return true if the model does not have both datapoints and comparisons"""
200195
return (
@@ -646,6 +641,23 @@ def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor:
646641

647642
# ============== public APIs ==============
648643

644+
@property
645+
def num_outputs(self) -> int:
646+
r"""The number of outputs of the model."""
647+
return self._num_outputs
648+
649+
@property
650+
def batch_shape(self) -> torch.Size:
651+
r"""The batch shape of the model.
652+
653+
This is a batch shape from an I/O perspective, independent of the internal
654+
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
655+
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
656+
to the `posterior` method returns a Posterior object over an output of
657+
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
658+
"""
659+
return self.datapoints.shape[:-2]
660+
649661
def set_train_data(
650662
self, datapoints: Tensor, comparisons: Tensor, update_model: bool = True
651663
) -> None:

botorch/optim/initializers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ def gen_value_function_initial_conditions(
279279
280280
Returns:
281281
A `num_restarts x batch_shape x q x d` tensor that can be used as initial
282-
conditions for `optimize_acqf()`. Here `batch_shape` is the
283-
`_input_batch_shape` of value function model.
282+
conditions for `optimize_acqf()`. Here `batch_shape` is the batch shape
283+
of value function model.
284284
285285
Example:
286286
>>> fant_X = torch.rand(5, 1, 2)
@@ -325,7 +325,7 @@ def gen_value_function_initial_conditions(
325325
},
326326
)
327327

328-
batch_shape = acq_function.model._input_batch_shape
328+
batch_shape = acq_function.model.batch_shape
329329
# sampling from the optimizers
330330
n_value = int((1 - frac_random) * raw_samples) # number of non-random ICs
331331
if n_value > 0:

test/models/test_gpytorch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def test_gpytorch_model(self):
8686
# basic test
8787
model = SimpleGPyTorchModel(train_X, train_Y, octf)
8888
self.assertEqual(model.num_outputs, 1)
89+
self.assertEqual(model.batch_shape, torch.Size())
8990
test_X = torch.rand(2, 1, **tkwargs)
9091
posterior = model.posterior(test_X)
9192
self.assertIsInstance(posterior, GPyTorchPosterior)
@@ -181,6 +182,7 @@ def test_batched_multi_output_gpytorch_model(self):
181182
# basic test
182183
model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y)
183184
self.assertEqual(model.num_outputs, 2)
185+
self.assertEqual(model.batch_shape, torch.Size())
184186
test_X = torch.rand(2, 1, **tkwargs)
185187
posterior = model.posterior(test_X)
186188
self.assertIsInstance(posterior, GPyTorchPosterior)
@@ -257,6 +259,8 @@ def test_model_list_gpytorch_model(self):
257259
m2 = SimpleGPyTorchModel(train_X2, train_Y2)
258260
model = SimpleModelListGPyTorchModel(m1, m2)
259261
self.assertEqual(model.num_outputs, 2)
262+
with self.assertRaises(NotImplementedError):
263+
model.batch_shape
260264
test_X = torch.rand(2, 1, **tkwargs)
261265
posterior = model.posterior(test_X)
262266
self.assertIsInstance(posterior, GPyTorchPosterior)

test/models/test_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def test_not_so_abstract_base_model(self):
2424
model.condition_on_observations(None, None)
2525
with self.assertRaises(NotImplementedError):
2626
model.num_outputs
27+
with self.assertRaises(NotImplementedError):
28+
model.batch_shape
2729
with self.assertRaises(NotImplementedError):
2830
model.subset_output([0])
2931
with self.assertRaises(NotImplementedError):

test/models/test_pairwise_gp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def test_pairwise_gp(self):
9090
model.covar_module.outputscale_prior, SmoothedBoxPrior
9191
)
9292
self.assertEqual(model.num_outputs, 1)
93+
self.assertEqual(model.batch_shape, batch_shape)
9394

9495
# test custom models
9596
custom_m = PairwiseGP(**model_kwargs, covar_module=LinearKernel())

0 commit comments

Comments
 (0)