@@ -774,7 +774,12 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T
774774 from .cat_lazy_tensor import CatLazyTensor
775775 from .root_lazy_tensor import RootLazyTensor
776776 from .triangular_lazy_tensor import TriangularLazyTensor
777-
777+ if not generate_roots and generate_inv_roots :
778+ warnings .warn (
779+ "root_inv_decomposition is only generated when "
780+ "root_decomposition is generated." ,
781+ UserWarning ,
782+ )
778783 B_ , B = cross_mat , lazify (cross_mat )
779784 D = lazify (new_mat )
780785 batch_shape = B .shape [:- 2 ]
@@ -791,13 +796,13 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=T
791796
792797 # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
793798 # don't create one
794- does_not_have_roots = any (
795- _is_in_cache_ignore_args (self , key ) for key in ("root_inv_decomposition " , "root_inv_decomposition" )
799+ has_roots = any (
800+ _is_in_cache_ignore_args (self , key ) for key in ("root_decomposition " , "root_inv_decomposition" )
796801 )
797- if not generate_roots and not does_not_have_roots :
802+ if not generate_roots and not has_roots :
798803 return new_lazy_tensor
799804
800- # Get compomnents for new root Z = [E 0; F G]
805+ # Get components for new root Z = [E 0; F G]
801806 E = self .root_decomposition (** root_decomp_kwargs ).root # E = L, LL^T = A
802807 m , n = E .shape [- 2 :]
803808 R = self .root_inv_decomposition ().root .evaluate () # RR^T = A^{-1} (this is fast if L is triangular)
@@ -875,10 +880,10 @@ def add_low_rank(
875880
876881 # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
877882 # don't create one
878- does_not_have_roots = any (
883+ has_roots = any (
879884 _is_in_cache_ignore_args (self , key ) for key in ("root_decomposition" , "root_inv_decomposition" )
880885 )
881- if not generate_roots and not does_not_have_roots :
886+ if not generate_roots and not has_roots :
882887 return new_lazy_tensor
883888
884889 # we are going to compute the following
0 commit comments