12 changes: 10 additions & 2 deletions gpytorch/distributions/multivariate_normal.py
@@ -15,6 +15,7 @@
from torch.distributions.kl import register_kl
from torch.distributions.utils import _standard_normal, lazy_property

from gpytorch.functions import TensorInvQuadLogdet
from .. import settings
from ..utils.warnings import NumericalWarning
from .distribution import Distribution
@@ -245,9 +246,16 @@ def log_prob(self, value: Tensor) -> Tensor:
1,
)

# Get log determinant and first part of quadratic form
covar = covar.evaluate_kernel()
inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)

if (
settings.fast_computations.log_prob.off() or covar.size(-1) <= settings.max_cholesky_size.value()
) and settings.use_torch_tensors.on():
# If we are to use Cholesky decomposition for inference, and we are allowed to use torch tensors as opposed
# to linear operators, then do so.
inv_quad, logdet = TensorInvQuadLogdet.apply(covar.to_dense(), diff.unsqueeze(-1))
else:
inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)

res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
return res
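For context, a minimal sketch of exercising the two code paths in `log_prob` against each other, assuming this PR is installed; the toy data, kernel choice, and jitter value are illustrative only and mirror the new unit tests below:

import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel

train_x = torch.linspace(0, 1, 10).unsqueeze(-1)
train_y = torch.sin(6 * train_x).squeeze(-1)

# Build a lazily evaluated kernel matrix and a prior over the targets.
covar = RBFKernel()(train_x).evaluate_kernel().add_jitter(1e-4)
mvn = MultivariateNormal(torch.zeros(10), covar)

# With the flag on (and N below max_cholesky_size), log_prob takes the new
# dense-tensor path through TensorInvQuadLogdet.
with gpytorch.settings.use_torch_tensors(True):
    lp_dense = mvn.log_prob(train_y)

# With the flag off, the existing linear-operator path is used.
with gpytorch.settings.use_torch_tensors(False):
    lp_linop = mvn.log_prob(train_y)

assert torch.allclose(lp_dense, lp_linop, atol=1e-4)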
2 changes: 2 additions & 0 deletions gpytorch/functions/__init__.py
@@ -9,6 +9,7 @@
import torch

from ._log_normal_cdf import LogNormalCDF
from .inv_quad_logdet import TensorInvQuadLogdet
from .matern_covariance import MaternCovariance
from .rbf_covariance import RBFCovariance

@@ -39,6 +40,7 @@ def inv_matmul(mat, right_tensor, left_tensor=None):


__all__ = [
"TensorInvQuadLogdet",
"MaternCovariance",
"RBFCovariance",
"inv_matmul",
58 changes: 58 additions & 0 deletions gpytorch/functions/inv_quad_logdet.py
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

import torch

from torch import Tensor


class TensorInvQuadLogdet(torch.autograd.Function):
r"""This function computes the inverse quadratic form and the log determinant of a positive semi-definite matrix.
This is a lightweight implementation of `LinearOperator.inv_quad_logdet`. The main motivation is to avoid the
overhead of linear operators for dense kernel matrices by doing linear algebra operations directly on torch tensors.
"""

@staticmethod
def forward(
ctx,
matrix: Tensor,
inv_quad_rhs: Tensor,
) -> tuple[Tensor, Tensor]:
r"""Compute the inverse quadratic form and the log determinant.

:param matrix: A positive semi-definite matrix of size `(..., N, N)`.
:param inv_quad_rhs: The right-hand side vector of size `(..., N, 1)`.
:return: The inverse quadratic form and the log determinant, both of size `(...)`.
"""
chol = torch.linalg.cholesky(matrix)

# The inverse quadratic term
inv_quad_solves = torch.cholesky_solve(inv_quad_rhs, chol)
inv_quad_term = (inv_quad_solves * inv_quad_rhs).sum(-2)
inv_quad_term = inv_quad_term.squeeze(-1)

# The log determinant term
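# (log det(matrix) = 2 * sum(log(diag(chol))) since matrix = chol @ chol.T)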
logdet_term = 2.0 * chol.diagonal(dim1=-1, dim2=-2).log().sum(-1)

ctx.save_for_backward(chol, inv_quad_solves)

return inv_quad_term, logdet_term

@staticmethod
def backward(ctx, d_inv_quad: Tensor, d_logdet: Tensor) -> tuple[Tensor, Tensor]:
r"""Compute the backward pass for the inverse quadratic form and the log determinant.

:param d_inv_quad: The gradient of the inverse quadratic form of size `(...)`.
:param d_logdet: The gradient of the log determinant of size `(...)`.
:return: The gradients with respect to the input matrix and the right-hand side vector.
"""
chol, inv_quad_solves = ctx.saved_tensors
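
# Gradients w.r.t. the matrix and the right-hand side:
#   d/dK [v^T K^{-1} v] = -(K^{-1} v)(K^{-1} v)^T,  d/dK [log det K] = K^{-1},  d/dv [v^T K^{-1} v] = 2 K^{-1} v,
# where K^{-1} v is the saved `inv_quad_solves`; each term is scaled by the incoming gradient below.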

d_matrix_one = (
-1.0 * inv_quad_solves @ inv_quad_solves.transpose(-2, -1) * d_inv_quad.unsqueeze(-1).unsqueeze(-1)
)
d_matrix_two = torch.cholesky_inverse(chol) * d_logdet.unsqueeze(-1).unsqueeze(-1)
d_matrix = d_matrix_one + d_matrix_two

d_inv_quad_rhs = 2.0 * inv_quad_solves * d_inv_quad.unsqueeze(-1).unsqueeze(-1)

return d_matrix, d_inv_quad_rhs
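
A quick sanity check of the custom Function against plain torch.linalg on a small SPD matrix; the matrix and right-hand side below are illustrative, and the import assumes this PR is installed:

import torch
from gpytorch.functions import TensorInvQuadLogdet

K = torch.eye(4) + 0.1 * torch.ones(4, 4)  # small SPD matrix
rhs = torch.randn(4, 1)

inv_quad, logdet = TensorInvQuadLogdet.apply(K, rhs)

# Same quantities computed without the custom Function:
expected_inv_quad = (rhs.transpose(-2, -1) @ torch.linalg.solve(K, rhs)).squeeze()
expected_logdet = torch.logdet(K)

assert torch.allclose(inv_quad, expected_inv_quad)
assert torch.allclose(logdet, expected_logdet)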
12 changes: 12 additions & 0 deletions gpytorch/settings.py
@@ -461,6 +461,17 @@ class use_keops(_feature_flag):
_default = True


class use_torch_tensors(_feature_flag):
Member: What do we think about making this on by default up to some N?

Collaborator (Author): Indeed, the first version of this PR turned this flag on up to some N, as you suggested. But the benchmark shows a speedup even for N=1000 (whereas the default threshold for Cholesky decomposition is N=800), so I decided to turn this on whenever Cholesky decomposition is used for training and inference.

I think the design here is intertwined with your comments below about what would happen for larger N. I'll circle back on this once we have benchmark results for larger N.

"""
Whether or not to use torch tensors instead of linear operators. If true, then we will use torch tensors as much as
possible to avoid the overhead of linear operators for dense kernel matrices.

(Default: False)
"""

_default = False


__all__ = [
"_linalg_dtype_symeig",
"_linalg_dtype_cholesky",
@@ -502,6 +513,7 @@ class use_keops(_feature_flag):
"tridiagonal_jitter",
"use_keops",
"use_toeplitz",
"use_torch_tensors",
"variational_cholesky_jitter",
"verbose_linalg",
]
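
Like the other feature flags in this module, the new flag can be used as a context manager or queried via `on()`/`off()`, as `MultivariateNormal.log_prob` now does; a brief sketch, assuming this PR is installed:

import gpytorch

# Enable the dense-tensor code path for a block of code.
with gpytorch.settings.use_torch_tensors(True):
    pass  # e.g. evaluate the marginal log likelihood during training

# Query the current state, as MultivariateNormal.log_prob does.
print(gpytorch.settings.use_torch_tensors.on())  # False outside the context manager (the default)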
79 changes: 79 additions & 0 deletions test/functions/test_inv_quad_logdet.py
@@ -0,0 +1,79 @@
#!/usr/bin/env python3

import unittest

import torch

from gpytorch.functions import TensorInvQuadLogdet
from gpytorch.kernels import RBFKernel


class TestInvQuadLogdet(unittest.TestCase):
def test_inv_quad_logdet(self):
# NOTE: Use small matrices here to avoid flakiness since we are testing in `float32`.
num_data = 3
jitter = 1e-4

train_x = torch.linspace(0, 1, num_data).view(num_data, 1)

# Forward and backward using `TensorInvQuadLogdet`
covar_module = RBFKernel()
covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense()

inv_quad_rhs = torch.linspace(0, 1, num_data).requires_grad_(True)

inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1))
inv_quad_logdet = inv_quad + logdet
inv_quad_logdet.backward()

# Forward and backward using linear operators
covar_module_linop = RBFKernel()
covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter)

inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True)

inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True)
inv_quad_logdet_linop = inv_quad_linop + logdet_linop
inv_quad_logdet_linop.backward()

self.assertTrue(torch.allclose(inv_quad, inv_quad_linop))
self.assertTrue(torch.allclose(logdet, logdet_linop))
self.assertTrue(torch.allclose(inv_quad_logdet, inv_quad_logdet_linop))
self.assertTrue(torch.allclose(covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad))
self.assertTrue(torch.allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad))

def test_batch_inv_quad_logdet(self):
num_data = 3
jitter = 1e-4

train_x = torch.linspace(0, 1, 2 * num_data).view(2, num_data, 1)

# Forward and backward using `TensorInvQuadLogdet`
covar_module = RBFKernel(batch_shape=torch.Size([2]))
covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense()

inv_quad_rhs = torch.linspace(0, 1, 2 * num_data).view(2, num_data).requires_grad_(True)

inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1))
inv_quad_logdet = torch.sum(inv_quad + logdet)
inv_quad_logdet.backward()

# Forward and backward using linear operators
covar_module_linop = RBFKernel(batch_shape=torch.Size([2]))
covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter)

inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True)

inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True)
inv_quad_logdet_linop = torch.sum(inv_quad_linop + logdet_linop)
inv_quad_logdet_linop.backward()

self.assertTrue(torch.allclose(inv_quad, inv_quad_linop))
self.assertTrue(torch.allclose(logdet, logdet_linop))
self.assertTrue(torch.allclose(inv_quad_logdet, inv_quad_logdet_linop))
self.assertTrue(torch.allclose(covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad))
self.assertTrue(torch.allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad))


if __name__ == "__main__":
unittest.main()