Skip to content

Commit 1242234

Browse files
sdaultonfacebook-github-bot
authored andcommitted
static method for getting batch shapes for batched MO models (#346)
Summary: Pull Request resolved: #346 This method makes it easy to figure out the `aug_batch_shape`, which is useful when specifying custom `covar_module`s and `likelihood`s. Reviewed By: Balandat Differential Revision: D19252473 fbshipit-source-id: 7ab7bb1a77016840fa14994810c09fa03f1821bc
1 parent 13936e5 commit 1242234

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

botorch/models/gpytorch.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,31 @@ class BatchedMultiOutputGPyTorchModel(GPyTorchModel):
194194
_input_batch_shape: torch.Size
195195
_aug_batch_shape: torch.Size
196196

197+
@staticmethod
198+
def get_batch_dimensions(
199+
train_X: Tensor, train_Y: Tensor
200+
) -> Tuple[torch.Size, torch.Size]:
201+
r"""Get the raw batch shape and output-augmented batch shape of the inputs.
202+
203+
Args:
204+
train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of training
205+
features.
206+
train_Y: A `n x m` or `batch_shape x n x m` (batch mode) tensor of
207+
training observations.
208+
209+
Returns:
210+
2-element tuple containing
211+
212+
- The `input_batch_shape`
213+
- The output-augmented batch shape: `input_batch_shape x (m)`
214+
"""
215+
input_batch_shape = train_X.shape[:-2]
216+
aug_batch_shape = input_batch_shape
217+
num_outputs = train_Y.shape[-1]
218+
if num_outputs > 1:
219+
aug_batch_shape += torch.Size([num_outputs])
220+
return input_batch_shape, aug_batch_shape
221+
197222
def _set_dimensions(self, train_X: Tensor, train_Y: Tensor) -> None:
198223
r"""Store the number of outputs and the batch shape.
199224
@@ -204,10 +229,9 @@ def _set_dimensions(self, train_X: Tensor, train_Y: Tensor) -> None:
204229
training observations.
205230
"""
206231
self._num_outputs = train_Y.shape[-1]
207-
self._input_batch_shape = train_X.shape[:-2]
208-
self._aug_batch_shape = self._input_batch_shape
209-
if self._num_outputs > 1:
210-
self._aug_batch_shape += torch.Size([self._num_outputs])
232+
self._input_batch_shape, self._aug_batch_shape = self.get_batch_dimensions(
233+
train_X=train_X, train_Y=train_Y
234+
)
211235

212236
def _transform_tensor_args(
213237
self, X: Tensor, Y: Tensor, Yvar: Optional[Tensor] = None

test/models/test_gpytorch.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,27 @@ def test_batched_multi_output_gpytorch_model(self):
221221
self.assertIsInstance(cm, SimpleBatchedMultiOutputGPyTorchModel)
222222
self.assertEqual(cm.train_targets.shape, torch.Size([2, 2, 7]))
223223

224+
# test get_batch_dimensions
225+
get_batch_dims = SimpleBatchedMultiOutputGPyTorchModel.get_batch_dimensions
226+
for input_batch_dim in (0, 3):
227+
for num_outputs in (1, 2):
228+
input_batch_shape, aug_batch_shape = get_batch_dims(
229+
train_X=train_X.unsqueeze(0).expand(3, 5, 1)
230+
if input_batch_dim == 3
231+
else train_X,
232+
train_Y=train_Y[:, 0:1] if num_outputs == 1 else train_Y,
233+
)
234+
expected_input_batch_shape = (
235+
torch.Size([3]) if input_batch_dim == 3 else torch.Size([])
236+
)
237+
self.assertEqual(input_batch_shape, expected_input_batch_shape)
238+
self.assertEqual(
239+
aug_batch_shape,
240+
expected_input_batch_shape + torch.Size([])
241+
if num_outputs == 1
242+
else expected_input_batch_shape + torch.Size([2]),
243+
)
244+
224245

225246
class TestModelListGPyTorchModel(BotorchTestCase):
226247
def test_model_list_gpytorch_model(self):

0 commit comments

Comments
 (0)