diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py
index a752e9632..b6d47b24a 100644
--- a/gpytorch/distributions/multivariate_normal.py
+++ b/gpytorch/distributions/multivariate_normal.py
@@ -15,6 +15,7 @@
 from torch.distributions.kl import register_kl
 from torch.distributions.utils import _standard_normal, lazy_property
 
+from gpytorch.functions import TensorInvQuadLogdet
 from .. import settings
 from ..utils.warnings import NumericalWarning
 from .distribution import Distribution
@@ -245,9 +246,16 @@ def log_prob(self, value: Tensor) -> Tensor:
                     1,
                 )
 
-        # Get log determininant and first part of quadratic form
         covar = covar.evaluate_kernel()
-        inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
+
+        if (
+            settings.fast_computations.log_prob.off() or covar.size(-1) <= settings.max_cholesky_size.value()
+        ) and settings.use_torch_tensors.on():
+            # Cholesky-based inference is in effect and torch tensors are allowed (as opposed to linear
+            # operators), so operate directly on the dense kernel matrix.
+            inv_quad, logdet = TensorInvQuadLogdet.apply(covar.to_dense(), diff.unsqueeze(-1))
+        else:
+            inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
 
         res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
         return res
diff --git a/gpytorch/functions/__init__.py b/gpytorch/functions/__init__.py
index d3294a974..fba396de1 100644
--- a/gpytorch/functions/__init__.py
+++ b/gpytorch/functions/__init__.py
@@ -9,6 +9,7 @@
 import torch
 
 from ._log_normal_cdf import LogNormalCDF
+from .inv_quad_logdet import TensorInvQuadLogdet
 from .matern_covariance import MaternCovariance
 from .rbf_covariance import RBFCovariance
 
@@ -39,6 +40,7 @@ def inv_matmul(mat, right_tensor, left_tensor=None):
 
 
 __all__ = [
+    "TensorInvQuadLogdet",
     "MaternCovariance",
     "RBFCovariance",
     "inv_matmul",
diff --git a/gpytorch/functions/inv_quad_logdet.py b/gpytorch/functions/inv_quad_logdet.py
new file mode 100644
index 000000000..5c3b3a798
--- /dev/null
+++ b/gpytorch/functions/inv_quad_logdet.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+
+import torch
+
+from torch import Tensor
+
+
+class TensorInvQuadLogdet(torch.autograd.Function):
+    r"""Compute the inverse quadratic form and the log determinant of a symmetric positive-definite matrix.
+    This is a lightweight implementation of `LinearOperator.inv_quad_logdet`. The main motivation is to avoid
+    the overhead of linear operators for dense kernel matrices by doing linear algebra directly on torch tensors.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        matrix: Tensor,
+        inv_quad_rhs: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        r"""Compute the inverse quadratic form and the log determinant.
+
+        :param matrix: A symmetric positive-definite matrix of size `(..., N, N)`.
+        :param inv_quad_rhs: The right-hand side vector of size `(..., N, 1)`.
+        :return: The inverse quadratic form and the log determinant, both of size `(...)`.
+ """ + chol = torch.linalg.cholesky(matrix) + + # The inverse quadratic term + inv_quad_solves = torch.cholesky_solve(inv_quad_rhs, chol) + inv_quad_term = (inv_quad_solves * inv_quad_rhs).sum(-2) + inv_quad_term = inv_quad_term.squeeze(-1) + + # The log determinant term + logdet_term = 2.0 * chol.diagonal(dim1=-1, dim2=-2).log().sum(-1) + + ctx.save_for_backward(chol, inv_quad_solves) + + return inv_quad_term, logdet_term + + @staticmethod + def backward(ctx, d_inv_quad: Tensor, d_logdet: Tensor) -> tuple[Tensor, Tensor]: + r"""Compute the backward pass for the inverse quadratic form and the log determinant. + + :param d_inv_quad: The gradient of the inverse quadratic form of size `(...)`. + :param d_logdet: The gradient of the log determinant of size `(...)`. + :return: The gradients with respect to the input matrix and the right-hand side vector. + """ + chol, inv_quad_solves = ctx.saved_tensors + + d_matrix_one = ( + -1.0 * inv_quad_solves @ inv_quad_solves.transpose(-2, -1) * d_inv_quad.unsqueeze(-1).unsqueeze(-1) + ) + d_matrix_two = torch.cholesky_inverse(chol) * d_logdet.unsqueeze(-1).unsqueeze(-1) + d_matrix = d_matrix_one + d_matrix_two + + d_inv_quad_rhs = 2.0 * inv_quad_solves * d_inv_quad.unsqueeze(-1).unsqueeze(-1) + + return d_matrix, d_inv_quad_rhs diff --git a/gpytorch/settings.py b/gpytorch/settings.py index 99528c419..f22e81bfa 100644 --- a/gpytorch/settings.py +++ b/gpytorch/settings.py @@ -461,6 +461,17 @@ class use_keops(_feature_flag): _default = True +class use_torch_tensors(_feature_flag): + """ + Whether or not to use torch tensors instead of linear operators. If true, then we will use torch tensors as much as + possible to avoid the overhead of linear operators for dense kernel matrices. + + (Default: False) + """ + + _default = False + + __all__ = [ "_linalg_dtype_symeig", "_linalg_dtype_cholesky", @@ -502,6 +513,7 @@ class use_keops(_feature_flag): "tridiagonal_jitter", "use_keops", "use_toeplitz", + "use_torch_tensors", "variational_cholesky_jitter", "verbose_linalg", ] diff --git a/test/functions/test_inv_quad_logdet.py b/test/functions/test_inv_quad_logdet.py new file mode 100644 index 000000000..b9ebbc1b8 --- /dev/null +++ b/test/functions/test_inv_quad_logdet.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from gpytorch.functions import TensorInvQuadLogdet +from gpytorch.kernels import RBFKernel + + +class TestInvQuadLogdet(unittest.TestCase): + def test_inv_quad_logdet(self): + # NOTE: Use small matrics here to avoid flakiness since we are testing in `float32`. 
+        num_data = 3
+        jitter = 1e-4
+
+        train_x = torch.linspace(0, 1, num_data).view(num_data, 1)
+
+        # Forward and backward using `TensorInvQuadLogdet`
+        covar_module = RBFKernel()
+        covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense()
+
+        inv_quad_rhs = torch.linspace(0, 1, num_data).requires_grad_(True)
+
+        inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1))
+        inv_quad_logdet = inv_quad + logdet
+        inv_quad_logdet.backward()
+
+        # Forward and backward using linear operators
+        covar_module_linop = RBFKernel()
+        covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter)
+
+        inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True)
+
+        inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True)
+        inv_quad_logdet_linop = inv_quad_linop + logdet_linop
+        inv_quad_logdet_linop.backward()
+
+        self.assertTrue(torch.allclose(inv_quad, inv_quad_linop))
+        self.assertTrue(torch.allclose(logdet, logdet_linop))
+        self.assertTrue(torch.allclose(inv_quad_logdet, inv_quad_logdet_linop))
+        self.assertTrue(torch.allclose(covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad))
+        self.assertTrue(torch.allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad))
+
+    def test_batch_inv_quad_logdet(self):
+        num_data = 3
+        jitter = 1e-4
+
+        train_x = torch.linspace(0, 1, 2 * num_data).view(2, num_data, 1)
+
+        # Forward and backward using `TensorInvQuadLogdet`
+        covar_module = RBFKernel(batch_shape=torch.Size([2]))
+        covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense()
+
+        inv_quad_rhs = torch.linspace(0, 1, 2 * num_data).view(2, num_data).requires_grad_(True)
+
+        inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1))
+        inv_quad_logdet = torch.sum(inv_quad + logdet)
+        inv_quad_logdet.backward()
+
+        # Forward and backward using linear operators
+        covar_module_linop = RBFKernel(batch_shape=torch.Size([2]))
+        covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter)
+
+        inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True)
+
+        inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True)
+        inv_quad_logdet_linop = torch.sum(inv_quad_linop + logdet_linop)
+        inv_quad_logdet_linop.backward()
+
+        self.assertTrue(torch.allclose(inv_quad, inv_quad_linop))
+        self.assertTrue(torch.allclose(logdet, logdet_linop))
+        self.assertTrue(torch.allclose(inv_quad_logdet, inv_quad_logdet_linop))
+        self.assertTrue(torch.allclose(covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad))
+        self.assertTrue(torch.allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad))
+
+
+if __name__ == "__main__":
+    unittest.main()
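Usage sketch (not part of the patch): a minimal illustration of how the new `use_torch_tensors` flag is meant to be exercised. With the flag enabled, and with the kernel matrix small enough for Cholesky (`N <= settings.max_cholesky_size`) under default `fast_computations`, `MultivariateNormal.log_prob` dispatches to `TensorInvQuadLogdet` instead of `LinearOperator.inv_quad_logdet`. The model and training loop below are standard exact-GP GPyTorch boilerplate and are only illustrative; names such as `ExactGPModel` are placeholders.

import torch
import gpytorch


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


# Toy data: 100 points is well below the default max_cholesky_size.
train_x = torch.linspace(0, 1, 100).unsqueeze(-1)
train_y = torch.sin(6 * train_x).squeeze(-1) + 0.1 * torch.randn(100)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Enabling the flag routes MultivariateNormal.log_prob (and hence the marginal log likelihood)
# through the dense-tensor path added by this patch.
with gpytorch.settings.use_torch_tensors(True):
    for _ in range(50):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()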