@@ -245,7 +245,9 @@ def _getitem(self, row_index, col_index, *batch_indices):
245245        # Construct interpolated LazyTensor 
246246        from  . import  InterpolatedLazyTensor 
247247
248-         res  =  InterpolatedLazyTensor (self , row_interp_indices , row_interp_values , col_interp_indices , col_interp_values )
248+         res  =  InterpolatedLazyTensor (
249+             self , row_interp_indices , row_interp_values , col_interp_indices , col_interp_values ,
250+         )
249251        return  res ._getitem (row_index , col_index , * batch_indices )
250252
251253    def  _unsqueeze_batch (self , dim ):
@@ -318,7 +320,7 @@ def _get_indices(self, row_index, col_index, *batch_indices):
318320
319321        res  =  (
320322            InterpolatedLazyTensor (
321-                 base_lazy_tensor , row_interp_indices , row_interp_values , col_interp_indices , col_interp_values 
323+                 base_lazy_tensor , row_interp_indices , row_interp_values , col_interp_indices , col_interp_values , 
322324            )
323325            .evaluate ()
324326            .squeeze (- 2 )
@@ -518,7 +520,7 @@ def _mul_matrix(self, other):
518520        else :
519521            left_lazy_tensor  =  self  if  self ._root_decomposition_size () <  other ._root_decomposition_size () else  other 
520522            right_lazy_tensor  =  other  if  left_lazy_tensor  is  self  else  self 
521-             return  MulLazyTensor (left_lazy_tensor .root_decomposition (), right_lazy_tensor .root_decomposition ())
523+             return  MulLazyTensor (left_lazy_tensor .root_decomposition (), right_lazy_tensor .root_decomposition (), )
522524
523525    def  _preconditioner (self ):
524526        """ 
@@ -559,7 +561,7 @@ def _prod_batch(self, dim):
559561                shape  =  list (roots .shape )
560562                shape [dim ] =  1 
561563                extra_root  =  torch .full (
562-                     shape , dtype = self .dtype , device = self .device , fill_value = (1.0  /  math .sqrt (self .size (- 2 )))
564+                     shape , dtype = self .dtype , device = self .device , fill_value = (1.0  /  math .sqrt (self .size (- 2 ))), 
563565                )
564566                roots  =  torch .cat ([roots , extra_root ], dim )
565567                num_batch  +=  1 
@@ -735,7 +737,9 @@ def add_jitter(self, jitter_val=1e-3):
735737        diag  =  torch .tensor (jitter_val , dtype = self .dtype , device = self .device )
736738        return  self .add_diag (diag )
737739
738-     def  cat_rows (self , cross_mat , new_mat , generate_roots = True , ** root_decomp_kwargs ):
740+     def  cat_rows (
741+         self , cross_mat , new_mat , generate_roots = True , generate_inv_roots = True , ** root_decomp_kwargs ,
742+     ):
739743        """ 
740744        Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g. 
741745        C = [A B^T; B D]. where A is the existing lazy tensor, and B (cross_mat) and D (new_mat) 
@@ -762,8 +766,10 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
762766                If :math:`A` is n x n, then this matrix should be n x k. 
763767            new_mat (:obj:`torch.tensor`): the matrix :math:`D` we are appending to the matrix :math:`A`. 
764768                If :math:`B` is n x k, then this matrix should be k x k. 
765-             generate_roots (:obj:`bool`): whether to generate the root decomposition of :math:`A` even if it 
766-                 has not been created yet. 
769+             generate_roots (:obj:`bool`): whether to generate the root 
770+                 decomposition of :math:`A` even if it has not been created yet. 
771+             generate_inv_roots (:obj:`bool`): whether to generate the root inv 
772+                 decomposition of :math:`A` even if it has not been created yet. 
767773
768774        Returns: 
769775            :obj:`LazyTensor`: concatenated lazy tensor with the new rows and columns. 
@@ -773,6 +779,10 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
773779        from  .root_lazy_tensor  import  RootLazyTensor 
774780        from  .triangular_lazy_tensor  import  TriangularLazyTensor 
775781
782+         if  not  generate_roots  and  generate_inv_roots :
783+             warnings .warn (
784+                 "root_inv_decomposition is only generated when "  "root_decomposition is generated." , UserWarning ,
785+             )
776786        B_ , B  =  cross_mat , lazify (cross_mat )
777787        D  =  lazify (new_mat )
778788        batch_shape  =  B .shape [:- 2 ]
@@ -789,13 +799,13 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
789799
790800        # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition 
791801        # don't create one 
792-         does_not_have_roots  =  any (
793-             _is_in_cache_ignore_args (self , key ) for  key  in  ("root_inv_decomposition " , "root_inv_decomposition" )
802+         has_roots  =  any (
803+             _is_in_cache_ignore_args (self , key ) for  key  in  ("root_decomposition " , "root_inv_decomposition" , )
794804        )
795-         if  not  generate_roots  and  not  does_not_have_roots :
805+         if  not  generate_roots  and  not  has_roots :
796806            return  new_lazy_tensor 
797807
798-         # Get compomnents  for new root Z = [E 0; F G] 
808+         # Get components  for new root Z = [E 0; F G] 
799809        E  =  self .root_decomposition (** root_decomp_kwargs ).root   # E = L, LL^T = A 
800810        m , n  =  E .shape [- 2 :]
801811        R  =  self .root_inv_decomposition ().root .evaluate ()  # RR^T = A^{-1} (this is fast if L is triangular) 
@@ -809,20 +819,22 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
809819        new_root [..., :m , :n ] =  E .evaluate ()
810820        new_root [..., m :, : lower_left .shape [- 1 ]] =  lower_left 
811821        new_root [..., m :, n  : (n  +  schur_root .shape [- 1 ])] =  schur_root 
812- 
813-         if  isinstance (E , TriangularLazyTensor ) and  isinstance (schur_root , TriangularLazyTensor ):
814-             # make sure these are actually upper triangular 
815-             if  getattr (E , "upper" , False ) or  getattr (schur_root , "upper" , False ):
816-                 raise  NotImplementedError 
817-             # in this case we know new_root is triangular as well 
818-             new_root  =  TriangularLazyTensor (new_root )
819-             new_inv_root  =  new_root .inverse ().transpose (- 1 , - 2 )
820-         else :
821-             # otherwise we use the pseudo-inverse of Z as new inv root 
822-             new_inv_root  =  stable_pinverse (new_root ).transpose (- 2 , - 1 )
822+         if  generate_inv_roots :
823+             if  isinstance (E , TriangularLazyTensor ) and  isinstance (schur_root , TriangularLazyTensor ):
824+                 # make sure these are actually upper triangular 
825+                 if  getattr (E , "upper" , False ) or  getattr (schur_root , "upper" , False ):
826+                     raise  NotImplementedError 
827+                 # in this case we know new_root is triangular as well 
828+                 new_root  =  TriangularLazyTensor (new_root )
829+                 new_inv_root  =  new_root .inverse ().transpose (- 1 , - 2 )
830+             else :
831+                 # otherwise we use the pseudo-inverse of Z as new inv root 
832+                 new_inv_root  =  stable_pinverse (new_root ).transpose (- 2 , - 1 )
833+             add_to_cache (
834+                 new_lazy_tensor , "root_inv_decomposition" , RootLazyTensor (lazify (new_inv_root )),
835+             )
823836
824837        add_to_cache (new_lazy_tensor , "root_decomposition" , RootLazyTensor (lazify (new_root )))
825-         add_to_cache (new_lazy_tensor , "root_inv_decomposition" , RootLazyTensor (lazify (new_inv_root )))
826838
827839        return  new_lazy_tensor 
828840
@@ -864,7 +876,7 @@ def add_low_rank(
864876            new_lazy_tensor  =  self  +  lazify (low_rank_mat .matmul (low_rank_mat .transpose (- 1 , - 2 )))
865877        else :
866878            new_lazy_tensor  =  SumLazyTensor (
867-                 * self .lazy_tensors , lazify (low_rank_mat .matmul (low_rank_mat .transpose (- 1 , - 2 )))
879+                 * self .lazy_tensors , lazify (low_rank_mat .matmul (low_rank_mat .transpose (- 1 , - 2 ))), 
868880            )
869881
870882            # return as a nonlazy tensor if small enough to reduce memory overhead 
@@ -873,10 +885,8 @@ def add_low_rank(
873885
874886        # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition 
875887        # don't create one 
876-         does_not_have_roots  =  any (
877-             _is_in_cache_ignore_args (self , key ) for  key  in  ("root_decomposition" , "root_inv_decomposition" )
878-         )
879-         if  not  generate_roots  and  not  does_not_have_roots :
888+         has_roots  =  any (_is_in_cache_ignore_args (self , key ) for  key  in  ("root_decomposition" , "root_inv_decomposition" ))
889+         if  not  generate_roots  and  not  has_roots :
880890            return  new_lazy_tensor 
881891
882892        # we are going to compute the following 
@@ -914,7 +924,7 @@ def add_low_rank(
914924            updated_root  =  torch .cat (
915925                (
916926                    current_root .evaluate (),
917-                     torch .zeros (* current_root .shape [:- 1 ], 1 , device = current_root .device , dtype = current_root .dtype ),
927+                     torch .zeros (* current_root .shape [:- 1 ], 1 , device = current_root .device , dtype = current_root .dtype , ),
918928                ),
919929                dim = - 1 ,
920930            )
@@ -1174,7 +1184,7 @@ def inv_matmul(self, right_tensor, left_tensor=None):
11741184        if  left_tensor  is  None :
11751185            return  func .apply (self .representation_tree (), False , right_tensor , * self .representation ())
11761186        else :
1177-             return  func .apply (self .representation_tree (), True , left_tensor , right_tensor , * self .representation ())
1187+             return  func .apply (self .representation_tree (), True , left_tensor , right_tensor , * self .representation (), )
11781188
11791189    def  inv_quad (self , tensor , reduce_inv_quad = True ):
11801190        """ 
@@ -1241,7 +1251,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
12411251                    will_need_cholesky  =  False 
12421252            if  will_need_cholesky :
12431253                cholesky  =  CholLazyTensor (TriangularLazyTensor (self .cholesky ()))
1244-             return  cholesky .inv_quad_logdet (inv_quad_rhs = inv_quad_rhs , logdet = logdet , reduce_inv_quad = reduce_inv_quad )
1254+             return  cholesky .inv_quad_logdet (inv_quad_rhs = inv_quad_rhs , logdet = logdet , reduce_inv_quad = reduce_inv_quad , )
12451255
12461256        # Default: use modified batch conjugate gradients to compute these terms 
12471257        # See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165 
@@ -1988,7 +1998,7 @@ def zero_mean_mvn_samples(self, num_samples):
19881998
19891999        if  settings .ciq_samples .on ():
19902000            base_samples  =  torch .randn (
1991-                 * self .batch_shape , self .size (- 1 ), num_samples , dtype = self .dtype , device = self .device 
2001+                 * self .batch_shape , self .size (- 1 ), num_samples , dtype = self .dtype , device = self .device , 
19922002            )
19932003            base_samples  =  base_samples .permute (- 1 , * range (self .dim () -  1 )).contiguous ()
19942004            base_samples  =  base_samples .unsqueeze (- 1 )
@@ -2008,7 +2018,7 @@ def zero_mean_mvn_samples(self, num_samples):
20082018                covar_root  =  self .root_decomposition ().root 
20092019
20102020            base_samples  =  torch .randn (
2011-                 * self .batch_shape , covar_root .size (- 1 ), num_samples , dtype = self .dtype , device = self .device 
2021+                 * self .batch_shape , covar_root .size (- 1 ), num_samples , dtype = self .dtype , device = self .device , 
20122022            )
20132023            samples  =  covar_root .matmul (base_samples ).permute (- 1 , * range (self .dim () -  1 )).contiguous ()
20142024
0 commit comments