Skip to content

Commit 1cd7c9d

Browse files
add warning and update variable names
1 parent 32924e2 commit 1cd7c9d

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

gpytorch/lazy/lazy_tensor.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,12 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T
774774
from .cat_lazy_tensor import CatLazyTensor
775775
from .root_lazy_tensor import RootLazyTensor
776776
from .triangular_lazy_tensor import TriangularLazyTensor
777-
777+
if not generate_roots and generate_inv_roots:
778+
warnings.warn(
779+
"root_inv_decomposition is only generated when "
780+
"root_decomposition is generated.",
781+
UserWarning,
782+
)
778783
B_, B = cross_mat, lazify(cross_mat)
779784
D = lazify(new_mat)
780785
batch_shape = B.shape[:-2]
@@ -791,13 +796,13 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T
791796

792797
# if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
793798
# don't create one
794-
does_not_have_roots = any(
795-
_is_in_cache_ignore_args(self, key) for key in ("root_inv_decomposition", "root_inv_decomposition")
799+
has_roots = any(
800+
_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition")
796801
)
797-
if not generate_roots and not does_not_have_roots:
802+
if not generate_roots and not has_roots:
798803
return new_lazy_tensor
799804

800-
# Get compomnents for new root Z = [E 0; F G]
805+
# Get components for new root Z = [E 0; F G]
801806
E = self.root_decomposition(**root_decomp_kwargs).root # E = L, LL^T = A
802807
m, n = E.shape[-2:]
803808
R = self.root_inv_decomposition().root.evaluate() # RR^T = A^{-1} (this is fast if L is triangular)
@@ -875,10 +880,10 @@ def add_low_rank(
875880

876881
# if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
877882
# don't create one
878-
does_not_have_roots = any(
883+
has_roots = any(
879884
_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition")
880885
)
881-
if not generate_roots and not does_not_have_roots:
886+
if not generate_roots and not has_roots:
882887
return new_lazy_tensor
883888

884889
# we are going to compute the following

0 commit comments

Comments
 (0)