diff --git a/examples/045_GPLVM/Gaussian_Process_Latent_Variable_Models_with_Stochastic_Variational_Inference.ipynb b/examples/045_GPLVM/Gaussian_Process_Latent_Variable_Models_with_Stochastic_Variational_Inference.ipynb
index 6059b8166..b4a28327c 100644
--- a/examples/045_GPLVM/Gaussian_Process_Latent_Variable_Models_with_Stochastic_Variational_Inference.ipynb
+++ b/examples/045_GPLVM/Gaussian_Process_Latent_Variable_Models_with_Stochastic_Variational_Inference.ipynb
@@ -137,14 +137,13 @@
     "class bGPLVM(BayesianGPLVM):\n",
     "    def __init__(self, n, data_dim, latent_dim, n_inducing, pca=False):\n",
     "        self.n = n\n",
-    "        self.batch_shape = torch.Size([data_dim])\n",
     "        \n",
     "        # Locations Z_{d} corresponding to u_{d}, they can be randomly initialized or \n",
    "        # regularly placed with shape (D x n_inducing x latent_dim).\n",
     "        self.inducing_inputs = torch.randn(data_dim, n_inducing, latent_dim)\n",
     "        \n",
     "        # Sparse Variational Formulation (inducing variables initialised as randn)\n",
-    "        q_u = CholeskyVariationalDistribution(n_inducing, batch_shape=self.batch_shape) \n",
+    "        q_u = CholeskyVariationalDistribution(n_inducing, batch_shape=torch.Size([data_dim])) \n",
     "        q_f = VariationalStrategy(self, self.inducing_inputs, q_u, learn_inducing_locations=True)\n",
     "        \n",
     "        # Define prior for X\n",
diff --git a/gpytorch/models/approximate_gp.py b/gpytorch/models/approximate_gp.py
index 85e2674f4..dc7a8406d 100644
--- a/gpytorch/models/approximate_gp.py
+++ b/gpytorch/models/approximate_gp.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+import torch
+
 from .gp import GP
 from .pyro import _PyroMixin  # This will only contain functions if Pyro is installed
 
@@ -40,6 +42,11 @@ class ApproximateGP(GP, _PyroMixin):
         >>> # test_x = ...;
         >>> model(test_x)  # Returns the approximate GP latent function at test_x
         >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
+
+    :ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
+        independent of the internal representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
     """
 
     def __init__(self, variational_strategy):
@@ -49,6 +56,17 @@ def __init__(self, variational_strategy):
     def forward(self, x):
         raise NotImplementedError
 
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        return self.variational_strategy.batch_shape
+
     def pyro_guide(self, input, beta=1.0, name_prefix=""):
         r"""
         (For Pyro integration only). The component of a `pyro.guide` that
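Concretely, an ApproximateGP now picks up its batch shape from the variational strategy, which reads it off the inducing points. Below is a minimal sketch of the resulting behaviour, assuming this patch is applied on top of a recent GPyTorch; the SVGPModel class and all shapes are illustrative, not part of the patch:

    import torch
    import gpytorch
    from gpytorch.models import ApproximateGP
    from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


    class SVGPModel(ApproximateGP):
        # A batch of independent sparse GPs, one per leading batch entry.
        def __init__(self, inducing_points):
            batch_shape = inducing_points.shape[:-2]
            variational_distribution = CholeskyVariationalDistribution(
                inducing_points.size(-2), batch_shape=batch_shape
            )
            variational_strategy = VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            )
            super().__init__(variational_strategy)
            self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape
            )

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


    # Three independent GPs, each with 20 inducing points in 1d.
    model = SVGPModel(inducing_points=torch.randn(3, 20, 1))
    print(model.batch_shape)                       # torch.Size([3])
    print(model.variational_strategy.batch_shape)  # torch.Size([3]) == inducing_points.shape[:-2]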
diff --git a/gpytorch/models/exact_gp.py b/gpytorch/models/exact_gp.py
index 318df70a3..f7545ba66 100644
--- a/gpytorch/models/exact_gp.py
+++ b/gpytorch/models/exact_gp.py
@@ -50,6 +50,11 @@ class ExactGP(GP):
         >>> # test_x = ...;
         >>> model(test_x)  # Returns the GP latent function at test_x
         >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
+
+    :ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
+        independent of the internal representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
     """
 
     def __init__(self, train_inputs, train_targets, likelihood):
@@ -71,6 +76,17 @@ def __init__(self, train_inputs, train_targets, likelihood):
 
         self.prediction_strategy = None
 
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        return self.train_inputs[0].shape[:-2]
+
     @property
     def train_targets(self):
         return self._train_targets
@@ -160,8 +176,6 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
                 "all test independent caches exist. Call the model on some data first!"
             )
 
-        model_batch_shape = self.train_inputs[0].shape[:-2]
-
         if not isinstance(inputs, list):
             inputs = [inputs]
 
@@ -184,17 +198,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
 
         # Check whether we can properly broadcast batch dimensions
         try:
-            torch.broadcast_shapes(model_batch_shape, target_batch_shape)
+            torch.broadcast_shapes(self.batch_shape, target_batch_shape)
         except RuntimeError:
             raise RuntimeError(
-                f"Model batch shape ({model_batch_shape}) and target batch shape "
+                f"Model batch shape ({self.batch_shape}) and target batch shape "
                 f"({target_batch_shape}) are not broadcastable."
             )
 
-        if len(model_batch_shape) > len(input_batch_shape):
-            input_batch_shape = model_batch_shape
-        if len(model_batch_shape) > len(target_batch_shape):
-            target_batch_shape = model_batch_shape
+        if len(self.batch_shape) > len(input_batch_shape):
+            input_batch_shape = self.batch_shape
+        if len(self.batch_shape) > len(target_batch_shape):
+            target_batch_shape = self.batch_shape
 
         # If input has no fantasy batch dimension but target does, we can save memory and computation by not
         # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
diff --git a/gpytorch/models/gp.py b/gpytorch/models/gp.py
index 922e981b0..f2c514333 100644
--- a/gpytorch/models/gp.py
+++ b/gpytorch/models/gp.py
@@ -1,7 +1,19 @@
 #!/usr/bin/env python3
 
+import torch
+
 from ..module import Module
 
 
 class GP(Module):
-    pass
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        cls_name = self.__class__.__name__
+        raise NotImplementedError(f"{cls_name} does not define batch_shape property")
diff --git a/gpytorch/models/model_list.py b/gpytorch/models/model_list.py
index 66b27d320..6fc3875e5 100644
--- a/gpytorch/models/model_list.py
+++ b/gpytorch/models/model_list.py
@@ -30,6 +30,29 @@ def __init__(self, *models):
                 "IndependentModelList currently only supports models that have a likelihood (e.g. ExactGPs)"
             )
         self.likelihood = LikelihoodList(*[m.likelihood for m in models])
+        try:
+            batch_shapes = [m.batch_shape for m in self.models]
+        except RuntimeError:
+            # Some models may not have a batch shape, e.g. those for which it is only known
+            # after fitting the model. In this case, we skip the batch shape validation.
+            return
+        # If we know the batch shapes of the models, we can validate that the batch shapes
+        # are compatible (i.e., that we can broadcast the batch dimensions).
+        try:
+            self._batch_shape = torch.broadcast_shapes(*batch_shapes)
+        except RuntimeError:
+            raise RuntimeError(f"Model batch shapes are not broadcastable: {batch_shapes}.")
+
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        return self._batch_shape
 
     def forward_i(self, i, *args, **kwargs):
         return self.models[i].forward(*args, **kwargs)
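For an ExactGP, the property is simply train_inputs[0].shape[:-2], and IndependentModelList broadcasts its submodels' batch shapes. A minimal sketch under the same assumptions as above; the BatchGPModel class, the make_model helper, and all shapes are illustrative:

    import torch
    import gpytorch
    from gpytorch.models import ExactGP, IndependentModelList


    class BatchGPModel(ExactGP):
        # A standard exact GP that supports a leading batch of independent models.
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            batch_shape = train_x.shape[:-2]
            self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape
            )

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


    def make_model(batch_shape=torch.Size([])):
        train_x = torch.randn(*batch_shape, 10, 2)
        train_y = torch.randn(*batch_shape, 10)
        likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=batch_shape)
        return BatchGPModel(train_x, train_y, likelihood)


    batched = make_model(torch.Size([3]))
    print(batched.batch_shape)  # torch.Size([3]) == train_inputs[0].shape[:-2]

    # In eval mode, test inputs broadcast against model.batch_shape:
    batched.eval()
    with torch.no_grad():
        print(batched(torch.randn(5, 2)).mean.shape)  # torch.Size([3, 5])

    # IndependentModelList broadcasts submodel batch shapes: () and (3,) -> (3,)
    model_list = IndependentModelList(make_model(), batched)
    print(model_list.batch_shape)  # torch.Size([3])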
diff --git a/gpytorch/test/model_test_case.py b/gpytorch/test/model_test_case.py
index ac2aed788..734e718b2 100644
--- a/gpytorch/test/model_test_case.py
+++ b/gpytorch/test/model_test_case.py
@@ -32,6 +32,7 @@ def test_forward_train(self):
         data = self.create_test_data()
         likelihood, labels = self.create_likelihood_and_labels()
         model = self.create_model(data, labels, likelihood)
+        self.assertEqual(model.batch_shape, data.shape[:-2])  # test batch_shape property
         model.train()
         output = model(data)
         self.assertTrue(output.lazy_covariance_matrix.dim() == 2)
@@ -42,6 +43,7 @@ def test_batch_forward_train(self):
         batch_data = self.create_batch_test_data()
         likelihood, labels = self.create_batch_likelihood_and_labels()
         model = self.create_model(batch_data, labels, likelihood)
+        self.assertEqual(model.batch_shape, batch_data.shape[:-2])  # test batch_shape property
         model.train()
         output = model(batch_data)
         self.assertTrue(output.lazy_covariance_matrix.dim() == 3)
@@ -52,6 +54,7 @@ def test_multi_batch_forward_train(self):
         batch_data = self.create_batch_test_data(batch_shape=torch.Size([2, 3]))
         likelihood, labels = self.create_batch_likelihood_and_labels(batch_shape=torch.Size([2, 3]))
         model = self.create_model(batch_data, labels, likelihood)
+        self.assertEqual(model.batch_shape, batch_data.shape[:-2])  # test batch_shape property
         model.train()
         output = model(batch_data)
         self.assertTrue(output.lazy_covariance_matrix.dim() == 4)
diff --git a/gpytorch/variational/_variational_strategy.py b/gpytorch/variational/_variational_strategy.py
index f5d3a1ed1..c7c6fe563 100644
--- a/gpytorch/variational/_variational_strategy.py
+++ b/gpytorch/variational/_variational_strategy.py
@@ -90,11 +90,16 @@ def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Te
         """
         Pre-processing step in __call__ to make x the same batch_shape as the inducing points
         """
-        batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
+        batch_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-2])
         inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
         x = x.expand(*batch_shape, *x.shape[-2:])
         return x, inducing_points
 
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the variational strategy."""
+        return self.inducing_points.shape[:-2]
+
     @property
     def jitter_val(self) -> float:
         if self._jitter_val is None:
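The _expand_inputs change is behaviour-preserving: the strategy's new batch_shape is exactly self.inducing_points.shape[:-2], so the broadcast it computes is unchanged for the standard strategies. In plain torch terms (shapes illustrative), the pre-processing does the following:

    import torch

    inducing_points = torch.randn(4, 1, 32, 2)  # strategy batch_shape: (4, 1)
    x = torch.randn(7, 5, 2)                    # input batch_shape: (7,)

    batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
    x = x.expand(*batch_shape, *x.shape[-2:])
    inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
    print(batch_shape, x.shape, inducing_points.shape)
    # torch.Size([4, 7]) torch.Size([4, 7, 5, 2]) torch.Size([4, 7, 32, 2])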
diff --git a/gpytorch/variational/lmc_variational_strategy.py b/gpytorch/variational/lmc_variational_strategy.py
index 2e0f249ba..c71be7a04 100644
--- a/gpytorch/variational/lmc_variational_strategy.py
+++ b/gpytorch/variational/lmc_variational_strategy.py
@@ -116,26 +116,28 @@ def __init__(
         Module.__init__(self)
         self.base_variational_strategy = base_variational_strategy
         self.num_tasks = num_tasks
-        batch_shape = self.base_variational_strategy._variational_distribution.batch_shape
+        vdist_batch_shape = self.base_variational_strategy._variational_distribution.batch_shape
 
         # Check if no functions
         if latent_dim >= 0:
             raise RuntimeError(f"latent_dim must be a negative indexed batch dimension: got {latent_dim}.")
-        if not (batch_shape[latent_dim] == num_latents or batch_shape[latent_dim] == 1):
+        if not (vdist_batch_shape[latent_dim] == num_latents or vdist_batch_shape[latent_dim] == 1):
             raise RuntimeError(
-                f"Mismatch in num_latents: got a variational distribution of batch shape {batch_shape}, "
+                f"Mismatch in num_latents: got a variational distribution of batch shape {vdist_batch_shape}, "
                 f"expected the function dim {latent_dim} to be {num_latents}."
             )
         self.num_latents = num_latents
         self.latent_dim = latent_dim
 
         # Make the batch_shape
-        self.batch_shape = list(batch_shape)
-        del self.batch_shape[self.latent_dim]
-        self.batch_shape = torch.Size(self.batch_shape)
+        # (latent_dim is negatively indexed, so `del` handles latent_dim == -1 correctly,
+        # unlike the slice vdist_batch_shape[:latent_dim] + vdist_batch_shape[latent_dim + 1:])
+        batch_shape = list(vdist_batch_shape)
+        del batch_shape[self.latent_dim]
+        self._batch_shape = torch.Size(batch_shape)
 
         # LCM coefficients
-        lmc_coefficients = torch.randn(*batch_shape, self.num_tasks)
+        lmc_coefficients = torch.randn(*vdist_batch_shape, self.num_tasks)
         self.register_parameter("lmc_coefficients", torch.nn.Parameter(lmc_coefficients))
 
         if jitter_val is None:
@@ -145,6 +147,11 @@ def __init__(
         else:
             self.jitter_val = jitter_val
 
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the variational strategy."""
+        return self._batch_shape
+
     @property
     def prior_distribution(self) -> MultivariateNormal:
         return self.base_variational_strategy.prior_distribution
diff --git a/test/models/test_exact_gp.py b/test/models/test_exact_gp.py
index 6b431cef8..b899f3ef6 100644
--- a/test/models/test_exact_gp.py
+++ b/test/models/test_exact_gp.py
@@ -106,6 +106,10 @@ def test_batch_forward_then_nonbatch_forward_eval(self):
         batch_data = self.create_batch_test_data()
         likelihood, labels = self.create_batch_likelihood_and_labels()
         model = self.create_model(batch_data, labels, likelihood)
+
+        # test batch_shape property
+        self.assertEqual(model.batch_shape, batch_data.shape[:-2])
+
         model.eval()
         output = model(batch_data)
 
diff --git a/test/models/test_model_list.py b/test/models/test_model_list.py
index 23163c9cf..ce3c7d5fc 100644
--- a/test/models/test_model_list.py
+++ b/test/models/test_model_list.py
@@ -19,6 +19,24 @@ def create_model(self, fixed_noise=False):
             likelihood = FixedNoiseGaussianLikelihood(noise)
         return TestExactGP.create_model(self, data, labels, likelihood)
 
+    def create_batch_model(self, batch_shape, fixed_noise=False):
+        data = TestExactGP.create_batch_test_data(self, batch_shape=batch_shape)
+        likelihood, labels = TestExactGP.create_batch_likelihood_and_labels(self, batch_shape=batch_shape)
+        if fixed_noise:
+            noise = 0.1 + 0.2 * torch.rand_like(labels)
+            likelihood = FixedNoiseGaussianLikelihood(noise)
+        return TestExactGP.create_model(self, data, labels, likelihood)
+
+    def test_batch_shape(self):
+        model = self.create_model()
+        self.assertEqual(model.batch_shape, torch.Size([]))
+        model_batch = self.create_batch_model(batch_shape=torch.Size([2]))
+        model_list = IndependentModelList(model, model_batch)
+        self.assertEqual(model_list.batch_shape, torch.Size([2]))
+        model_batch_2 = self.create_batch_model(batch_shape=torch.Size([3]))
+        with self.assertRaisesRegex(RuntimeError, "Model batch shapes are not broadcastable"):
+            IndependentModelList(model_batch, model_batch_2)
+
     def test_forward_eval(self):
         models = [self.create_model() for _ in range(2)]
         model = IndependentModelList(*models)
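The batch-shape computation above keeps the original del-based formulation because latent_dim is negatively indexed: the tempting one-liner vdist_batch_shape[: latent_dim] + vdist_batch_shape[latent_dim + 1 :] is wrong for the default latent_dim=-1, where the second slice is the whole shape rather than the empty tail. A quick illustration (all shapes illustrative):

    import torch


    def expected_lmc_batch_shape(vdist_batch_shape, latent_dim):
        # Drop the (negatively indexed) latent dimension from the batch shape.
        batch_shape = list(vdist_batch_shape)
        del batch_shape[latent_dim]
        return torch.Size(batch_shape)


    print(expected_lmc_batch_shape(torch.Size([3]), latent_dim=-1))     # torch.Size([])
    print(expected_lmc_batch_shape(torch.Size([2, 3]), latent_dim=-1))  # torch.Size([2])
    print(expected_lmc_batch_shape(torch.Size([3, 2]), latent_dim=-2))  # torch.Size([2])

    # The slice-based one-liner fails only for latent_dim == -1:
    s = torch.Size([3])
    print(torch.Size(s[:-1] + s[-1 + 1 :]))  # torch.Size([3]) -- wrong, should be torch.Size([])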
diff --git a/test/models/test_variational_gp.py b/test/models/test_variational_gp.py
index 664fe7690..dacf5e431 100644
--- a/test/models/test_variational_gp.py
+++ b/test/models/test_variational_gp.py
@@ -12,8 +12,9 @@
 
 class GPClassificationModel(ApproximateGP):
     def __init__(self, train_x, use_inducing=False):
-        variational_distribution = CholeskyVariationalDistribution(train_x.size(-2), batch_shape=train_x.shape[:-2])
-        inducing_points = torch.randn(50, train_x.size(-1)) if use_inducing else train_x
+        batch_shape = train_x.shape[:-2]
+        variational_distribution = CholeskyVariationalDistribution(train_x.size(-2), batch_shape=batch_shape)
+        inducing_points = torch.randn(*batch_shape, 50, train_x.size(-1)) if use_inducing else train_x
         strategy_cls = VariationalStrategy
         variational_strategy = strategy_cls(
             self, inducing_points, variational_distribution, learn_inducing_locations=use_inducing
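This last test change matters because, with this patch, the model's batch_shape is read from the inducing points: unbatched inducing points would report torch.Size([]) even when the variational distribution is batched. A sketch of the relevant shapes (illustrative):

    import torch

    train_x = torch.randn(2, 10, 3)  # batch_shape (2,)
    batch_shape = train_x.shape[:-2]

    old_inducing = torch.randn(50, 3)                # strategy batch_shape: torch.Size([])
    new_inducing = torch.randn(*batch_shape, 50, 3)  # strategy batch_shape: torch.Size([2])
    print(old_inducing.shape[:-2], new_inducing.shape[:-2])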