diff --git a/botorch/models/approximate_gp.py b/botorch/models/approximate_gp.py index cdec687452..e20079a87b 100644 --- a/botorch/models/approximate_gp.py +++ b/botorch/models/approximate_gp.py @@ -426,6 +426,17 @@ def __init__( self.to(train_X) + @property + def batch_shape(self) -> torch.Size: + r"""The batch shape of the model. + + This is a batch shape from an I/O perspective. For a model with `m` + outputs, a `test_batch_shape x q x d`-shaped input `X` to the `posterior` + method returns a Posterior object over an output of shape + `broadcast(test_batch_shape, model.batch_shape) x q x m`. + """ + return self._input_batch_shape + def init_inducing_points( self, inputs: Tensor, diff --git a/test/models/test_approximate_gp.py b/test/models/test_approximate_gp.py index 0e013c6d6a..9318ccbd45 100644 --- a/test/models/test_approximate_gp.py +++ b/test/models/test_approximate_gp.py @@ -97,6 +97,8 @@ def test_posterior(self): model = SingleTaskVariationalGP(tx, ty, inducing_points=tx) posterior = model.posterior(test) self.assertIsInstance(posterior, GPyTorchPosterior) + # test batch_shape property + self.assertEqual(model.batch_shape, tx.shape[:-2]) def test_variational_setUp(self): for dtype in [torch.float, torch.double]: