@@ -97,6 +97,18 @@ def _validate_tensor_args(
9797 f" { Yvar .shape } )."
9898 )
9999
100+ @property
101+ def batch_shape (self ) -> torch .Size :
102+ r"""The batch shape of the model.
103+
104+ This is a batch shape from an I/O perspective, independent of the internal
105+ representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
106+ For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
107+ to the `posterior` method returns a Posterior object over an output of
108+ shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
109+ """
110+ return self .train_inputs [0 ].shape [:- 2 ]
111+
100112 @property
101113 def num_outputs (self ) -> int :
102114 r"""The number of outputs of the model."""
@@ -234,6 +246,18 @@ def _set_dimensions(self, train_X: Tensor, train_Y: Tensor) -> None:
234246 train_X = train_X , train_Y = train_Y
235247 )
236248
249+ @property
250+ def batch_shape (self ) -> torch .Size :
251+ r"""The batch shape of the model.
252+
253+ This is a batch shape from an I/O perspective, independent of the internal
254+ representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
255+ For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
256+ to the `posterior` method returns a Posterior object over an output of
257+ shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
258+ """
259+ return self ._input_batch_shape
260+
237261 def _transform_tensor_args (
238262 self , X : Tensor , Y : Tensor , Yvar : Optional [Tensor ] = None
239263 ) -> Tuple [Tensor , Tensor , Optional [Tensor ]]:
@@ -442,6 +466,19 @@ class ModelListGPyTorchModel(GPyTorchModel, ABC):
442466 evaluation of submodels.
443467 """
444468
469+ @property
470+ def batch_shape (self ) -> torch .Size :
471+ r"""The batch shape of the model.
472+
473+ This is a batch shape from an I/O perspective, independent of the internal
474+ representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
475+ For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
476+ to the `posterior` method returns a Posterior object over an output of
477+ shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
478+ """
479+ # TODO: Either check that batch shapes match across models, or broadcast them
480+ raise NotImplementedError
481+
445482 def posterior (
446483 self ,
447484 X : Tensor ,
0 commit comments