@@ -735,7 +735,7 @@ def add_jitter(self, jitter_val=1e-3):
735735 diag = torch .tensor (jitter_val , dtype = self .dtype , device = self .device )
736736 return self .add_diag (diag )
737737
738- def cat_rows (self , cross_mat , new_mat , generate_roots = True , ** root_decomp_kwargs ):
738+ def cat_rows (self , cross_mat , new_mat , generate_roots = True , generate_inv_roots = True , ** root_decomp_kwargs ):
739739 """
740740 Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g.
741741 C = [A B^T; B D]. where A is the existing lazy tensor, and B (cross_mat) and D (new_mat)
@@ -762,8 +762,8 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
762762 If :math:`A` is n x n, then this matrix should be n x k.
763763 new_mat (:obj:`torch.tensor`): the matrix :math:`D` we are appending to the matrix :math:`A`.
764764 If :math:`B` is n x k, then this matrix should be k x k.
765- generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it
766- has not been created yet.
765+ generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it has not been created yet.
766+ generate_inv_roots (:obj:`bool`): whether to generate the root inv decomposition of :math:`A` even if it has not been created yet.
767767
768768 Returns:
769769 :obj:`LazyTensor`: concatenated lazy tensor with the new rows and columns.
@@ -809,20 +809,20 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
809809 new_root [..., :m , :n ] = E .evaluate ()
810810 new_root [..., m :, : lower_left .shape [- 1 ]] = lower_left
811811 new_root [..., m :, n : (n + schur_root .shape [- 1 ])] = schur_root
812-
813- if isinstance (E , TriangularLazyTensor ) and isinstance (schur_root , TriangularLazyTensor ):
814- # make sure these are actually upper triangular
815- if getattr (E , "upper" , False ) or getattr (schur_root , "upper" , False ):
816- raise NotImplementedError
817- # in this case we know new_root is triangular as well
818- new_root = TriangularLazyTensor (new_root )
819- new_inv_root = new_root .inverse ().transpose (- 1 , - 2 )
820- else :
821- # otherwise we use the pseudo-inverse of Z as new inv root
822- new_inv_root = stable_pinverse (new_root ).transpose (- 2 , - 1 )
812+ if generate_inv_roots :
813+ if isinstance (E , TriangularLazyTensor ) and isinstance (schur_root , TriangularLazyTensor ):
814+ # make sure these are actually upper triangular
815+ if getattr (E , "upper" , False ) or getattr (schur_root , "upper" , False ):
816+ raise NotImplementedError
817+ # in this case we know new_root is triangular as well
818+ new_root = TriangularLazyTensor (new_root )
819+ new_inv_root = new_root .inverse ().transpose (- 1 , - 2 )
820+ else :
821+ # otherwise we use the pseudo-inverse of Z as new inv root
822+ new_inv_root = stable_pinverse (new_root ).transpose (- 2 , - 1 )
823+ add_to_cache (new_lazy_tensor , "root_inv_decomposition" , RootLazyTensor (lazify (new_inv_root )))
823824
824825 add_to_cache (new_lazy_tensor , "root_decomposition" , RootLazyTensor (lazify (new_root )))
825- add_to_cache (new_lazy_tensor , "root_inv_decomposition" , RootLazyTensor (lazify (new_inv_root )))
826826
827827 return new_lazy_tensor
828828
0 commit comments