Skip to content

Commit 5d2671b

Browse files
add generate_inv_roots option
1 parent 05ebbf2 commit 5d2671b

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

gpytorch/lazy/lazy_tensor.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def add_jitter(self, jitter_val=1e-3):
735735
diag = torch.tensor(jitter_val, dtype=self.dtype, device=self.device)
736736
return self.add_diag(diag)
737737

738-
def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs):
738+
def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs):
739739
"""
740740
Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g.
741741
C = [A B^T; B D]. where A is the existing lazy tensor, and B (cross_mat) and D (new_mat)
@@ -762,8 +762,8 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
762762
If :math:`A` is n x n, then this matrix should be n x k.
763763
new_mat (:obj:`torch.tensor`): the matrix :math:`D` we are appending to the matrix :math:`A`.
764764
If :math:`B` is n x k, then this matrix should be k x k.
765-
generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it
766-
has not been created yet.
765+
generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it has not been created yet.
766+
generate_inv_roots (:obj:`bool`): whether to generate the root inv decomposition of :math:`A` even if it has not been created yet.
767767
768768
Returns:
769769
:obj:`LazyTensor`: concatenated lazy tensor with the new rows and columns.
@@ -809,20 +809,20 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
809809
new_root[..., :m, :n] = E.evaluate()
810810
new_root[..., m:, : lower_left.shape[-1]] = lower_left
811811
new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root
812-
813-
if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor):
814-
# make sure these are actually upper triangular
815-
if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
816-
raise NotImplementedError
817-
# in this case we know new_root is triangular as well
818-
new_root = TriangularLazyTensor(new_root)
819-
new_inv_root = new_root.inverse().transpose(-1, -2)
820-
else:
821-
# otherwise we use the pseudo-inverse of Z as new inv root
822-
new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
812+
if generate_inv_roots:
813+
if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor):
814+
# make sure these are actually upper triangular
815+
if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
816+
raise NotImplementedError
817+
# in this case we know new_root is triangular as well
818+
new_root = TriangularLazyTensor(new_root)
819+
new_inv_root = new_root.inverse().transpose(-1, -2)
820+
else:
821+
# otherwise we use the pseudo-inverse of Z as new inv root
822+
new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
823+
add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)))
823824

824825
add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root)))
825-
add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)))
826826

827827
return new_lazy_tensor
828828

gpytorch/test/lazy_tensor_test_case.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import gpytorch
1212
from gpytorch.settings import linalg_dtypes
1313
from gpytorch.utils.cholesky import CHOLESKY_METHOD
14-
14+
from gpytorch.utils.memoize import get_from_cache
15+
from gpytorch.utils.errors import CachingError
1516
from .base_test_case import BaseTestCase
1617

1718

@@ -457,13 +458,19 @@ def test_cat_rows(self):
457458
root_rhs = new_lt.root_decomposition().matmul(rhs)
458459
self.assertAllClose(root_rhs, concat_rhs, **self.tolerances["root_decomposition"])
459460

461+
# check that root inv is cached
462+
root_inv = get_from_cache(new_lt, "root_inv_decomposition")
460463
# check that the inverse root decomposition is close
461464
concat_solve = torch.linalg.solve(concatenated_lt, rhs.unsqueeze(-1)).squeeze(-1)
462-
root_inv_solve = new_lt.root_inv_decomposition().matmul(rhs)
465+
root_inv_solve = root_inv.matmul(rhs)
463466
self.assertLess(
464467
(root_inv_solve - concat_solve).norm() / concat_solve.norm(),
465468
self.tolerances["root_inv_decomposition"]["rtol"],
466469
)
470+
# test generate_inv_roots=False
471+
new_lt = lazy_tensor.cat_rows(new_rows, new_point, generate_inv_roots=False)
472+
with self.assertRaises(CachingError):
473+
get_from_cache(new_lt, "root_inv_decomposition")
467474

468475
def test_cholesky(self):
469476
lazy_tensor = self.create_lazy_tensor()

0 commit comments

Comments (0)