@@ -137,14 +137,13 @@
"class bGPLVM(BayesianGPLVM):\n",
" def __init__(self, n, data_dim, latent_dim, n_inducing, pca=False):\n",
" self.n = n\n",
" self.batch_shape = torch.Size([data_dim])\n",
" \n",
" # Locations Z_{d} corresponding to u_{d}, they can be randomly initialized or \n",
" # regularly placed with shape (D x n_inducing x latent_dim).\n",
" self.inducing_inputs = torch.randn(data_dim, n_inducing, latent_dim)\n",
" \n",
" # Sparse Variational Formulation (inducing variables initialised as randn)\n",
" q_u = CholeskyVariationalDistribution(n_inducing, batch_shape=self.batch_shape) \n",
" q_u = CholeskyVariationalDistribution(n_inducing, batch_shape=torch.Size([data_dim])) \n",
Review comment from a repository member: Why aren't we using the self.batch_shape property here?
" q_f = VariationalStrategy(self, self.inducing_inputs, q_u, learn_inducing_locations=True)\n",
" \n",
" # Define prior for X\n",
18 changes: 18 additions & 0 deletions gpytorch/models/approximate_gp.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

import torch

from .gp import GP
from .pyro import _PyroMixin # This will only contain functions if Pyro is installed

@@ -40,6 +42,11 @@ class ApproximateGP(GP, _PyroMixin):
>>> # test_x = ...;
>>> model(test_x) # Returns the approximate GP latent function at test_x
>>> likelihood(model(test_x)) # Returns the (approximate) predictive posterior distribution at test_x

:ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
independent of the internal representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""

def __init__(self, variational_strategy):
@@ -49,6 +56,17 @@ def __init__(self, variational_strategy):
def forward(self, x):
raise NotImplementedError

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.

This is a batch shape from an I/O perspective, independent of the internal
representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""
return self.variational_strategy.batch_shape

def pyro_guide(self, input, beta=1.0, name_prefix=""):
r"""
(For Pyro integration only). The component of a `pyro.guide` that
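To illustrate the I/O semantics documented above, here is a minimal sketch (a hypothetical model class and tensor sizes, not part of this diff) of reading `batch_shape` off an approximate GP whose inducing points carry a batch dimension:

```python
import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class BatchSVGP(ApproximateGP):
    """Hypothetical batched SVGP used only to illustrate the new property."""

    def __init__(self, inducing_points):
        batch_shape = inducing_points.shape[:-2]
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=batch_shape
        )
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.RBFKernel(batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


inducing_points = torch.randn(3, 16, 2)  # batch of 3 independent GPs, 16 inducing points, 2 input dims
model = BatchSVGP(inducing_points)
print(model.batch_shape)  # torch.Size([3]) -- delegated to variational_strategy.batch_shape
```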
30 changes: 22 additions & 8 deletions gpytorch/models/exact_gp.py
@@ -50,6 +50,11 @@ class ExactGP(GP):
>>> # test_x = ...;
>>> model(test_x) # Returns the GP latent function at test_x
>>> likelihood(model(test_x)) # Returns the (approximate) predictive posterior distribution at test_x

:ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
independent of the internal representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""

def __init__(self, train_inputs, train_targets, likelihood):
@@ -71,6 +76,17 @@ def __init__(self, train_inputs, train_targets, likelihood):

self.prediction_strategy = None

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.

This is a batch shape from an I/O perspective, independent of the internal
representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""
return self.train_inputs[0].shape[:-2]

@property
def train_targets(self):
return self._train_targets
@@ -160,8 +176,6 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
"all test independent caches exist. Call the model on some data first!"
)

model_batch_shape = self.train_inputs[0].shape[:-2]

if not isinstance(inputs, list):
inputs = [inputs]

@@ -184,17 +198,17 @@

# Check whether we can properly broadcast batch dimensions
try:
torch.broadcast_shapes(model_batch_shape, target_batch_shape)
torch.broadcast_shapes(self.batch_shape, target_batch_shape)
except RuntimeError:
raise RuntimeError(
f"Model batch shape ({model_batch_shape}) and target batch shape "
f"Model batch shape ({self.batch_shape}) and target batch shape "
f"({target_batch_shape}) are not broadcastable."
)

if len(model_batch_shape) > len(input_batch_shape):
input_batch_shape = model_batch_shape
if len(model_batch_shape) > len(target_batch_shape):
target_batch_shape = model_batch_shape
if len(self.batch_shape) > len(input_batch_shape):
input_batch_shape = self.batch_shape
if len(self.batch_shape) > len(target_batch_shape):
target_batch_shape = self.batch_shape

# If input has no fantasy batch dimension but target does, we can save memory and computation by not
# computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
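For exact GPs the same I/O batch shape comes from the training inputs instead; a minimal sketch (hypothetical model and data shapes, not part of this diff):

```python
import torch
import gpytorch


class BatchExactGP(gpytorch.models.ExactGP):
    """Hypothetical batched exact GP used only to illustrate the new property."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        batch_shape = train_x.shape[:-2]
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.RBFKernel(batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


train_x = torch.randn(2, 3, 10, 1)  # batch shape (2, 3), 10 training points, 1 input dim
train_y = torch.randn(2, 3, 10)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=train_x.shape[:-2])
model = BatchExactGP(train_x, train_y, likelihood)
print(model.batch_shape)  # torch.Size([2, 3]) == train_x.shape[:-2]
```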
14 changes: 13 additions & 1 deletion gpytorch/models/gp.py
@@ -1,7 +1,19 @@
#!/usr/bin/env python3

import torch

from ..module import Module


class GP(Module):
pass
@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.

This is a batch shape from an I/O perspective, independent of the internal
representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""
cls_name = self.__class__.__name__
raise NotImplementedError(f"{cls_name} does not define batch_shape property")
23 changes: 23 additions & 0 deletions gpytorch/models/model_list.py
@@ -30,6 +30,29 @@ def __init__(self, *models):
"IndependentModelList currently only supports models that have a likelihood (e.g. ExactGPs)"
)
self.likelihood = LikelihoodList(*[m.likelihood for m in models])
try:
batch_shapes = [m.batch_shape for m in self.models]
except RuntimeError:
# Some models may not have a batch shape, e.g. those for which it is only known
# after fitting the model. In this case, we skip the batch shape validation.
return
# If we know the batch shapes of the models, we can validate that the batch shapes
# are compatible (i.e., that we can broadcast the batch dimensions).
try:
self._batch_shape = torch.broadcast_shapes(*batch_shapes)
except RuntimeError:
raise RuntimeError(f"Model batch shapes are not broadcastable: {batch_shapes}.")

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.

This is a batch shape from an I/O perspective, independent of the internal
representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""
return self._batch_shape

def forward_i(self, i, *args, **kwargs):
return self.models[i].forward(*args, **kwargs)
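The broadcasting rule used in `__init__` can be checked in isolation with plain torch (hypothetical per-model batch shapes):

```python
import torch

# Compatible per-model batch shapes broadcast to a single model-list batch shape ...
print(torch.broadcast_shapes(torch.Size([]), torch.Size([2])))  # torch.Size([2])

# ... while incompatible shapes raise the RuntimeError that IndependentModelList
# re-raises with a more descriptive message.
try:
    torch.broadcast_shapes(torch.Size([2]), torch.Size([3]))
except RuntimeError as exc:
    print(f"not broadcastable: {exc}")
```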
3 changes: 3 additions & 0 deletions gpytorch/test/model_test_case.py
@@ -32,6 +32,7 @@ def test_forward_train(self):
data = self.create_test_data()
likelihood, labels = self.create_likelihood_and_labels()
model = self.create_model(data, labels, likelihood)
self.assertEqual(model.batch_shape, data.shape[:-2]) # test batch_shape property
model.train()
output = model(data)
self.assertTrue(output.lazy_covariance_matrix.dim() == 2)
@@ -42,6 +43,7 @@ def test_batch_forward_train(self):
batch_data = self.create_batch_test_data()
likelihood, labels = self.create_batch_likelihood_and_labels()
model = self.create_model(batch_data, labels, likelihood)
self.assertEqual(model.batch_shape, batch_data.shape[:-2]) # test batch_shape property
model.train()
output = model(batch_data)
self.assertTrue(output.lazy_covariance_matrix.dim() == 3)
@@ -52,6 +54,7 @@ def test_multi_batch_forward_train(self):
batch_data = self.create_batch_test_data(batch_shape=torch.Size([2, 3]))
likelihood, labels = self.create_batch_likelihood_and_labels(batch_shape=torch.Size([2, 3]))
model = self.create_model(batch_data, labels, likelihood)
self.assertEqual(model.batch_shape, batch_data.shape[:-2]) # test batch_shape property
model.train()
output = model(batch_data)
self.assertTrue(output.lazy_covariance_matrix.dim() == 4)
7 changes: 6 additions & 1 deletion gpytorch/variational/_variational_strategy.py
@@ -90,11 +90,16 @@ def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Te
"""
Pre-processing step in __call__ to make x the same batch_shape as the inducing points
"""
batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
batch_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-2])
inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
x = x.expand(*batch_shape, *x.shape[-2:])
return x, inducing_points

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the variational strategy."""
return self.inducing_points.shape[:-2]

@property
def jitter_val(self) -> float:
if self._jitter_val is None:
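Since the strategy's batch shape is just `inducing_points.shape[:-2]`, the broadcast performed by `_expand_inputs` can be sketched standalone (hypothetical shapes):

```python
import torch

inducing_points = torch.randn(3, 16, 2)  # (batch=3, n_inducing=16, d=2)
x = torch.randn(8, 2)                    # un-batched test inputs

batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
x = x.expand(*batch_shape, *x.shape[-2:])
inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
print(batch_shape, x.shape, inducing_points.shape)
# torch.Size([3]) torch.Size([3, 8, 2]) torch.Size([3, 16, 2])
```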
17 changes: 10 additions & 7 deletions gpytorch/variational/lmc_variational_strategy.py
@@ -116,26 +116,24 @@ def __init__(
Module.__init__(self)
self.base_variational_strategy = base_variational_strategy
self.num_tasks = num_tasks
batch_shape = self.base_variational_strategy._variational_distribution.batch_shape
vdist_batch_shape = self.base_variational_strategy._variational_distribution.batch_shape

# Check if no functions
if latent_dim >= 0:
raise RuntimeError(f"latent_dim must be a negative indexed batch dimension: got {latent_dim}.")
if not (batch_shape[latent_dim] == num_latents or batch_shape[latent_dim] == 1):
if not (vdist_batch_shape[latent_dim] == num_latents or vdist_batch_shape[latent_dim] == 1):
raise RuntimeError(
f"Mismatch in num_latents: got a variational distribution of batch shape {batch_shape}, "
f"Mismatch in num_latents: got a variational distribution of batch shape {vdist_batch_shape}, "
f"expected the function dim {latent_dim} to be {num_latents}."
)
self.num_latents = num_latents
self.latent_dim = latent_dim

# Make the batch_shape
self.batch_shape = list(batch_shape)
del self.batch_shape[self.latent_dim]
self.batch_shape = torch.Size(self.batch_shape)
batch_shape = list(vdist_batch_shape)
del batch_shape[self.latent_dim]  # drop the latent-function dimension (also correct for the default latent_dim=-1)
self._batch_shape = torch.Size(batch_shape)

# LCM coefficients
lmc_coefficients = torch.randn(*batch_shape, self.num_tasks)
lmc_coefficients = torch.randn(*vdist_batch_shape, self.num_tasks)
self.register_parameter("lmc_coefficients", torch.nn.Parameter(lmc_coefficients))

if jitter_val is None:
@@ -145,6 +143,11 @@
else:
self.jitter_val = jitter_val

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the variational strategy."""
return self._batch_shape

@property
def prior_distribution(self) -> MultivariateNormal:
return self.base_variational_strategy.prior_distribution
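The LMC strategy reports its I/O batch shape with the latent-function dimension removed; a short sketch with hypothetical shapes (an extra batch dimension of 4 and 5 latent functions at the default `latent_dim=-1`):

```python
import torch

vdist_batch_shape = torch.Size([4, 5])  # extra batch dim of 4, 5 latent functions
latent_dim = -1

# Drop the latent-function dimension, as in __init__ above.
batch_shape = list(vdist_batch_shape)
del batch_shape[latent_dim]
print(torch.Size(batch_shape))  # torch.Size([4])
```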
4 changes: 4 additions & 0 deletions test/models/test_exact_gp.py
@@ -106,6 +106,10 @@ def test_batch_forward_then_nonbatch_forward_eval(self):
batch_data = self.create_batch_test_data()
likelihood, labels = self.create_batch_likelihood_and_labels()
model = self.create_model(batch_data, labels, likelihood)

# test batch_shape property
self.assertEqual(model.batch_shape, batch_data.shape[:-2])

model.eval()
output = model(batch_data)

18 changes: 18 additions & 0 deletions test/models/test_model_list.py
@@ -19,6 +19,24 @@ def create_model(self, fixed_noise=False):
likelihood = FixedNoiseGaussianLikelihood(noise)
return TestExactGP.create_model(self, data, labels, likelihood)

def create_batch_model(self, batch_shape, fixed_noise=False):
data = TestExactGP.create_batch_test_data(self, batch_shape=batch_shape)
likelihood, labels = TestExactGP.create_batch_likelihood_and_labels(self, batch_shape=batch_shape)
if fixed_noise:
noise = 0.1 + 0.2 * torch.rand_like(labels)
likelihood = FixedNoiseGaussianLikelihood(noise)
return TestExactGP.create_model(self, data, labels, likelihood)

def test_batch_shape(self):
model = self.create_model()
self.assertEqual(model.batch_shape, torch.Size([]))
model_batch = self.create_batch_model(batch_shape=torch.Size([2]))
model_list = IndependentModelList(model, model_batch)
self.assertEqual(model_list.batch_shape, torch.Size([2]))
model_batch_2 = self.create_batch_model(batch_shape=torch.Size([3]))
with self.assertRaisesRegex(RuntimeError, "Model batch shapes are not broadcastable"):
IndependentModelList(model_batch, model_batch_2)

def test_forward_eval(self):
models = [self.create_model() for _ in range(2)]
model = IndependentModelList(*models)
5 changes: 3 additions & 2 deletions test/models/test_variational_gp.py
@@ -12,8 +12,9 @@

class GPClassificationModel(ApproximateGP):
def __init__(self, train_x, use_inducing=False):
variational_distribution = CholeskyVariationalDistribution(train_x.size(-2), batch_shape=train_x.shape[:-2])
inducing_points = torch.randn(50, train_x.size(-1)) if use_inducing else train_x
batch_shape = train_x.shape[:-2]
variational_distribution = CholeskyVariationalDistribution(train_x.size(-2), batch_shape=batch_shape)
inducing_points = torch.randn(*batch_shape, 50, train_x.size(-1)) if use_inducing else train_x
strategy_cls = VariationalStrategy
variational_strategy = strategy_cls(
self, inducing_points, variational_distribution, learn_inducing_locations=use_inducing