
Commit f06004e

Merge pull request #1752 from sdaulton/cat_row_skip_root_inv
add generate_inv_roots option to cat_rows
2 parents 05ebbf2 + e36ca9b commit f06004e

File tree

2 files changed (+52 additions, -34 deletions)


gpytorch/lazy/lazy_tensor.py

Lines changed: 43 additions & 33 deletions
@@ -245,7 +245,9 @@ def _getitem(self, row_index, col_index, *batch_indices):
         # Construct interpolated LazyTensor
         from . import InterpolatedLazyTensor

-        res = InterpolatedLazyTensor(self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values)
+        res = InterpolatedLazyTensor(
+            self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
+        )
         return res._getitem(row_index, col_index, *batch_indices)

     def _unsqueeze_batch(self, dim):
@@ -318,7 +320,7 @@ def _get_indices(self, row_index, col_index, *batch_indices):

         res = (
             InterpolatedLazyTensor(
-                base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values
+                base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
             )
             .evaluate()
             .squeeze(-2)
@@ -518,7 +520,7 @@ def _mul_matrix(self, other):
         else:
             left_lazy_tensor = self if self._root_decomposition_size() < other._root_decomposition_size() else other
             right_lazy_tensor = other if left_lazy_tensor is self else self
-            return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition())
+            return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition(),)

     def _preconditioner(self):
         """
@@ -559,7 +561,7 @@ def _prod_batch(self, dim):
             shape = list(roots.shape)
             shape[dim] = 1
             extra_root = torch.full(
-                shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2)))
+                shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2))),
             )
             roots = torch.cat([roots, extra_root], dim)
             num_batch += 1
@@ -735,7 +737,9 @@ def add_jitter(self, jitter_val=1e-3):
         diag = torch.tensor(jitter_val, dtype=self.dtype, device=self.device)
         return self.add_diag(diag)

-    def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs):
+    def cat_rows(
+        self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs,
+    ):
         """
         Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g.
         C = [A B^T; B D]. where A is the existing lazy tensor, and B (cross_mat) and D (new_mat)
@@ -762,8 +766,10 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
                 If :math:`A` is n x n, then this matrix should be n x k.
             new_mat (:obj:`torch.tensor`): the matrix :math:`D` we are appending to the matrix :math:`A`.
                 If :math:`B` is n x k, then this matrix should be k x k.
-            generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it
-                has not been created yet.
+            generate_roots (:obj:`bool`): whether to generate the root
+                decomposition of :math:`A` even if it has not been created yet.
+            generate_inv_roots (:obj:`bool`): whether to generate the root inv
+                decomposition of :math:`A` even if it has not been created yet.

         Returns:
             :obj:`LazyTensor`: concatenated lazy tensor with the new rows and columns.
@@ -773,6 +779,10 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
         from .root_lazy_tensor import RootLazyTensor
         from .triangular_lazy_tensor import TriangularLazyTensor

+        if not generate_roots and generate_inv_roots:
+            warnings.warn(
+                "root_inv_decomposition is only generated when " "root_decomposition is generated.", UserWarning,
+            )
         B_, B = cross_mat, lazify(cross_mat)
         D = lazify(new_mat)
         batch_shape = B.shape[:-2]
@@ -789,13 +799,13 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs

         # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
         # don't create one
-        does_not_have_roots = any(
-            _is_in_cache_ignore_args(self, key) for key in ("root_inv_decomposition", "root_inv_decomposition")
+        has_roots = any(
+            _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition",)
         )
-        if not generate_roots and not does_not_have_roots:
+        if not generate_roots and not has_roots:
             return new_lazy_tensor

-        # Get compomnents for new root Z = [E 0; F G]
+        # Get components for new root Z = [E 0; F G]
         E = self.root_decomposition(**root_decomp_kwargs).root  # E = L, LL^T = A
         m, n = E.shape[-2:]
         R = self.root_inv_decomposition().root.evaluate()  # RR^T = A^{-1} (this is fast if L is triangular)
@@ -809,20 +819,22 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
         new_root[..., :m, :n] = E.evaluate()
         new_root[..., m:, : lower_left.shape[-1]] = lower_left
         new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root
-
-        if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor):
-            # make sure these are actually upper triangular
-            if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
-                raise NotImplementedError
-            # in this case we know new_root is triangular as well
-            new_root = TriangularLazyTensor(new_root)
-            new_inv_root = new_root.inverse().transpose(-1, -2)
-        else:
-            # otherwise we use the pseudo-inverse of Z as new inv root
-            new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
+        if generate_inv_roots:
+            if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor):
+                # make sure these are actually upper triangular
+                if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
+                    raise NotImplementedError
+                # in this case we know new_root is triangular as well
+                new_root = TriangularLazyTensor(new_root)
+                new_inv_root = new_root.inverse().transpose(-1, -2)
+            else:
+                # otherwise we use the pseudo-inverse of Z as new inv root
+                new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
+            add_to_cache(
+                new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)),
+            )

         add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root)))
-        add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)))

         return new_lazy_tensor
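For context on the cat_rows hunks above, a short sketch of the block factorization that the "Z = [E 0; F G]" comment refers to. This is a reconstruction from the code comments, not part of the diff: it assumes lower_left and schur_root play the roles of F and G below, and that R is an exact inverse root, R = E^{-T}; with approximate (e.g. Lanczos) roots the identity holds only approximately.

    Z = \begin{bmatrix} E & 0 \\ F & G \end{bmatrix},
    \qquad F = B R,
    \qquad G G^\top = D - F F^\top = D - B A^{-1} B^\top,

    Z Z^\top
      = \begin{bmatrix} E E^\top & E F^\top \\ F E^\top & F F^\top + G G^\top \end{bmatrix}
      = \begin{bmatrix} A & B^\top \\ B & D \end{bmatrix} = C .

So Z is a root of the concatenated matrix C, and when E and schur_root are both lower triangular, Z is block lower triangular as well, which is why the new code can invert it directly instead of falling back to stable_pinverse.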

@@ -864,7 +876,7 @@ def add_low_rank(
             new_lazy_tensor = self + lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)))
         else:
             new_lazy_tensor = SumLazyTensor(
-                *self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)))
+                *self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))),
             )

         # return as a nonlazy tensor if small enough to reduce memory overhead
@@ -873,10 +885,8 @@ def add_low_rank(

         # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
         # don't create one
-        does_not_have_roots = any(
-            _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition")
-        )
-        if not generate_roots and not does_not_have_roots:
+        has_roots = any(_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition"))
+        if not generate_roots and not has_roots:
             return new_lazy_tensor

         # we are going to compute the following
@@ -914,7 +924,7 @@ def add_low_rank(
             updated_root = torch.cat(
                 (
                     current_root.evaluate(),
-                    torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype),
+                    torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype,),
                 ),
                 dim=-1,
             )
@@ -1174,7 +1184,7 @@ def inv_matmul(self, right_tensor, left_tensor=None):
         if left_tensor is None:
             return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
         else:
-            return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation())
+            return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation(),)

     def inv_quad(self, tensor, reduce_inv_quad=True):
         """
@@ -1241,7 +1251,7 @@ def inv_quad_logdet(inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
                     will_need_cholesky = False
             if will_need_cholesky:
                 cholesky = CholLazyTensor(TriangularLazyTensor(self.cholesky()))
-            return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad)
+            return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad,)

         # Default: use modified batch conjugate gradients to compute these terms
         # See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165
@@ -1988,7 +1998,7 @@ def zero_mean_mvn_samples(self, num_samples):

         if settings.ciq_samples.on():
             base_samples = torch.randn(
-                *self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device
+                *self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device,
             )
             base_samples = base_samples.permute(-1, *range(self.dim() - 1)).contiguous()
             base_samples = base_samples.unsqueeze(-1)
@@ -2008,7 +2018,7 @@ def zero_mean_mvn_samples(self, num_samples):
             covar_root = self.root_decomposition().root

             base_samples = torch.randn(
-                *self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device
+                *self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device,
             )
             samples = covar_root.matmul(base_samples).permute(-1, *range(self.dim() - 1)).contiguous()
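As a minimal usage sketch of the updated cat_rows API: the shapes, variable names, and the way the PSD matrix is built below are illustrative assumptions patterned after the test case, not something this commit specifies.

import torch
from gpytorch.lazy import lazify

# Build a PSD matrix and treat its leading block as the existing lazy tensor A.
n, k = 5, 1
X = torch.randn(n + k, n + k)
full = X @ X.t() + torch.eye(n + k)   # (n + k) x (n + k) PSD matrix
A = lazify(full[:n, :n])              # existing n x n block

new_rows = full[n:, :n]               # cross block B (k x n layout, as in the tests)
new_point = full[n:, n:]              # new k x k block D

# Default: the root and inverse-root decompositions of the concatenated
# tensor are generated and cached eagerly.
C_lt = A.cat_rows(new_rows, new_point)

# New in this PR: skip pre-computing and caching the inverse root.
C_lt_no_inv = A.cat_rows(new_rows, new_point, generate_inv_roots=False)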

gpytorch/test/lazy_tensor_test_case.py

Lines changed: 9 additions & 1 deletion
@@ -11,6 +11,8 @@
 import gpytorch
 from gpytorch.settings import linalg_dtypes
 from gpytorch.utils.cholesky import CHOLESKY_METHOD
+from gpytorch.utils.errors import CachingError
+from gpytorch.utils.memoize import get_from_cache

 from .base_test_case import BaseTestCase

@@ -457,13 +459,19 @@ def test_cat_rows(self):
         root_rhs = new_lt.root_decomposition().matmul(rhs)
         self.assertAllClose(root_rhs, concat_rhs, **self.tolerances["root_decomposition"])

+        # check that root inv is cached
+        root_inv = get_from_cache(new_lt, "root_inv_decomposition")
         # check that the inverse root decomposition is close
         concat_solve = torch.linalg.solve(concatenated_lt, rhs.unsqueeze(-1)).squeeze(-1)
-        root_inv_solve = new_lt.root_inv_decomposition().matmul(rhs)
+        root_inv_solve = root_inv.matmul(rhs)
         self.assertLess(
             (root_inv_solve - concat_solve).norm() / concat_solve.norm(),
             self.tolerances["root_inv_decomposition"]["rtol"],
         )
+        # test generate_inv_roots=False
+        new_lt = lazy_tensor.cat_rows(new_rows, new_point, generate_inv_roots=False)
+        with self.assertRaises(CachingError):
+            get_from_cache(new_lt, "root_inv_decomposition")

     def test_cholesky(self):
         lazy_tensor = self.create_lazy_tensor()
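Continuing the illustrative sketch that follows the lazy_tensor.py diff above (same assumed setup and names), the cache behaviour the new test asserts can be checked directly; my understanding is that when the inverse root is not generated eagerly it is simply computed, and then cached, on first request.

from gpytorch.utils.errors import CachingError
from gpytorch.utils.memoize import get_from_cache

try:
    get_from_cache(C_lt_no_inv, "root_inv_decomposition")  # nothing was cached eagerly
except CachingError:
    # still available on demand: computed (and then cached) at this point
    inv_root = C_lt_no_inv.root_inv_decomposition().root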
