Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,17 @@ def __init__(

self.to(train_X)

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.

This is a batch shape from an I/O perspective. For a model with `m`
outputs, a `test_batch_shape x q x d`-shaped input `X` to the `posterior`
method returns a Posterior object over an output of shape
`broadcast(test_batch_shape, model.batch_shape) x q x m`.
"""
return self._input_batch_shape

def init_inducing_points(
self,
inputs: Tensor,
Expand Down
2 changes: 2 additions & 0 deletions test/models/test_approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def test_posterior(self):
model = SingleTaskVariationalGP(tx, ty, inducing_points=tx)
posterior = model.posterior(test)
self.assertIsInstance(posterior, GPyTorchPosterior)
# test batch_shape property
self.assertEqual(model.batch_shape, tx.shape[:-2])

def test_variational_setUp(self):
for dtype in [torch.float, torch.double]:
Expand Down