@@ -245,7 +245,9 @@ def _getitem(self, row_index, col_index, *batch_indices):
         # Construct interpolated LazyTensor
         from . import InterpolatedLazyTensor

-        res = InterpolatedLazyTensor(self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values)
+        res = InterpolatedLazyTensor(
+            self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
+        )
         return res._getitem(row_index, col_index, *batch_indices)

     def _unsqueeze_batch(self, dim):
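
The fallback above reduces indexing to interpolation: each requested row and column becomes a one-hot row of an interpolation matrix, and the InterpolatedLazyTensor applies one on each side. A minimal plain-torch sketch of that identity (W_row and W_col are illustrative names, not from the codebase):

import torch

K = torch.arange(16.0).reshape(4, 4)
row_idx, col_idx = torch.tensor([0, 2]), torch.tensor([1, 3])
W_row = torch.eye(4)[row_idx]  # 2 x 4 one-hot "interpolation" matrix
W_col = torch.eye(4)[col_idx]
# Two-sided gather: W_row @ K @ W_col^T == K[row_idx][:, col_idx]
print(torch.allclose(W_row @ K @ W_col.t(), K[row_idx][:, col_idx]))  # True
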
@@ -318,7 +320,7 @@ def _get_indices(self, row_index, col_index, *batch_indices):

         res = (
             InterpolatedLazyTensor(
-                base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values
+                base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
             )
             .evaluate()
             .squeeze(-2)
@@ -518,7 +520,7 @@ def _mul_matrix(self, other):
         else:
             left_lazy_tensor = self if self._root_decomposition_size() < other._root_decomposition_size() else other
             right_lazy_tensor = other if left_lazy_tensor is self else self
-            return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition())
+            return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition(),)

     def _preconditioner(self):
         """
@@ -559,7 +561,7 @@ def _prod_batch(self, dim):
                 shape = list(roots.shape)
                 shape[dim] = 1
                 extra_root = torch.full(
-                    shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2)))
+                    shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2))),
                 )
                 roots = torch.cat([roots, extra_root], dim)
                 num_batch += 1
@@ -735,7 +737,9 @@ def add_jitter(self, jitter_val=1e-3):
         diag = torch.tensor(jitter_val, dtype=self.dtype, device=self.device)
         return self.add_diag(diag)

-    def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs):
+    def cat_rows(
+        self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs,
+    ):
         """
         Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g.
         C = [A B^T; B D], where A is the existing lazy tensor, and B (cross_mat) and D (new_mat)
@@ -762,8 +766,10 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
                 If :math:`A` is n x n, then this matrix should be n x k.
             new_mat (:obj:`torch.tensor`): the matrix :math:`D` we are appending to the matrix :math:`A`.
                 If :math:`B` is n x k, then this matrix should be k x k.
-            generate_roots (:obj:`bool`): whether to generate the root
-                decomposition of :math:`A` even if it has not been created yet.
+            generate_roots (:obj:`bool`): whether to generate the root
+                decomposition of :math:`A` even if it has not been created yet.
+            generate_inv_roots (:obj:`bool`): whether to generate the root inverse
+                decomposition of :math:`A` even if it has not been created yet.

         Returns:
             :obj:`LazyTensor`: concatenated lazy tensor with the new rows and columns.
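
For orientation, a hedged usage sketch of the new flag. It assumes a gpytorch version where lazify is importable from gpytorch.lazy and where cat_rows accepts generate_inv_roots as introduced in this diff; the tensors are made up for illustration:

import torch
from gpytorch.lazy import lazify

torch.manual_seed(0)
X = torch.randn(6, 3)
K = X @ X.t() + torch.eye(6)  # PSD matrix to split into blocks
A = lazify(K[:4, :4])         # existing n x n lazy tensor
B = K[:4, 4:]                 # cross term, n x k
D = K[4:, 4:]                 # new block, k x k

# Skip the pseudo-inverse work when only matmuls/samples are needed downstream.
C = A.cat_rows(B, D, generate_inv_roots=False)
print(C.shape)  # torch.Size([6, 6])
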
@@ -773,6 +779,10 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
         from .root_lazy_tensor import RootLazyTensor
         from .triangular_lazy_tensor import TriangularLazyTensor

+        if not generate_roots and generate_inv_roots:
+            warnings.warn(
+                "root_inv_decomposition is only generated when root_decomposition is generated.", UserWarning,
+            )
         B_, B = cross_mat, lazify(cross_mat)
         D = lazify(new_mat)
         batch_shape = B.shape[:-2]
@@ -789,13 +799,13 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs

         # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
         # don't create one
-        does_not_have_roots = any(
-            _is_in_cache_ignore_args(self, key) for key in ("root_inv_decomposition", "root_inv_decomposition")
+        has_roots = any(
+            _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition",)
         )
-        if not generate_roots and not does_not_have_roots:
+        if not generate_roots and not has_roots:
             return new_lazy_tensor

-        # Get compomnents for new root Z = [E 0; F G]
+        # Get components for new root Z = [E 0; F G]
         E = self.root_decomposition(**root_decomp_kwargs).root  # E = L, LL^T = A
         m, n = E.shape[-2:]
         R = self.root_inv_decomposition().root.evaluate()  # RR^T = A^{-1} (this is fast if L is triangular)
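
The block construction that follows can be checked densely. A plain-torch sketch under the same notation (E is the root of A, F plays the role of lower_left, G the Schur-complement root; E^{-T} stands in for the inverse root R):

import torch

torch.manual_seed(0)
n, k = 4, 2
X = torch.randn(n + k, n + k)
C = X @ X.t() + torch.eye(n + k)          # target [A B^T; B D]
A, Bt, D = C[:n, :n], C[:n, n:], C[n:, n:]

E = torch.linalg.cholesky(A)              # E E^T = A
F = Bt.t() @ torch.linalg.inv(E).t()      # F E^T = B, so F F^T = B A^{-1} B^T
G = torch.linalg.cholesky(D - F @ F.t())  # root of the Schur complement
Z = torch.zeros(n + k, n + k)
Z[:n, :n], Z[n:, :n], Z[n:, n:] = E, F, G
print(torch.allclose(Z @ Z.t(), C, atol=1e-5))  # True: Z Z^T = C
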
@@ -809,20 +819,22 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
         new_root[..., :m, :n] = E.evaluate()
         new_root[..., m:, : lower_left.shape[-1]] = lower_left
         new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root
-
-        if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor):
-            # make sure these are actually upper triangular
-            if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
-                raise NotImplementedError
-            # in this case we know new_root is triangular as well
-            new_root = TriangularLazyTensor(new_root)
-            new_inv_root = new_root.inverse().transpose(-1, -2)
-        else:
-            # otherwise we use the pseudo-inverse of Z as new inv root
-            new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
+        if generate_inv_roots:
+            if isinstance(E, TriangularLazyTensor) and isinstance(schur_root, TriangularLazyTensor):
+                # make sure these are actually upper triangular
+                if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
+                    raise NotImplementedError
+                # in this case we know new_root is triangular as well
+                new_root = TriangularLazyTensor(new_root)
+                new_inv_root = new_root.inverse().transpose(-1, -2)
+            else:
+                # otherwise we use the pseudo-inverse of Z as new inv root
+                new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
+            add_to_cache(
+                new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)),
+            )

         add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root)))
-        add_to_cache(new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)))

         return new_lazy_tensor

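
With the gating above, the inverse root is only computed and cached on request. A quick check, reusing A, B, D from the earlier sketch; the import path of the internal helper _is_in_cache_ignore_args (the same helper this diff calls) is an assumption:

from gpytorch.utils.memoize import _is_in_cache_ignore_args

C_full = A.cat_rows(B, D)  # defaults: generate_roots=True, generate_inv_roots=True
print(_is_in_cache_ignore_args(C_full, "root_decomposition"))      # True
print(_is_in_cache_ignore_args(C_full, "root_inv_decomposition"))  # True

C_lean = A.cat_rows(B, D, generate_inv_roots=False)
print(_is_in_cache_ignore_args(C_lean, "root_inv_decomposition"))  # False
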
@@ -864,7 +876,7 @@ def add_low_rank(
             new_lazy_tensor = self + lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)))
         else:
             new_lazy_tensor = SumLazyTensor(
-                *self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)))
+                *self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))),
             )

         # return as a nonlazy tensor if small enough to reduce memory overhead
@@ -873,10 +885,8 @@ def add_low_rank(

         # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
         # don't create one
-        does_not_have_roots = any(
-            _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition")
-        )
-        if not generate_roots and not does_not_have_roots:
+        has_roots = any(_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition"))
+        if not generate_roots and not has_roots:
             return new_lazy_tensor

         # we are going to compute the following
@@ -914,7 +924,7 @@ def add_low_rank(
             updated_root = torch.cat(
                 (
                     current_root.evaluate(),
-                    torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype),
+                    torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype,),
                 ),
                 dim=-1,
             )
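
The zero column appended above makes room for the added rank: if L L^T = A, then the concatenation [L v] is a root of A + v v^T, which is the identity behind updating an existing root instead of re-decomposing. Densely:

import torch

A = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
L = torch.linalg.cholesky(A)
v = torch.tensor([[0.3], [0.7]])
L_new = torch.cat([L, v], dim=-1)  # 2 x 3 root of the updated matrix
print(torch.allclose(L_new @ L_new.t(), A + v @ v.t()))  # True
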
@@ -1174,7 +1184,7 @@ def inv_matmul(self, right_tensor, left_tensor=None):
         if left_tensor is None:
             return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
         else:
-            return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation())
+            return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation(),)

     def inv_quad(self, tensor, reduce_inv_quad=True):
         """
@@ -1241,7 +1251,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
                 will_need_cholesky = False
         if will_need_cholesky:
             cholesky = CholLazyTensor(TriangularLazyTensor(self.cholesky()))
-            return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad)
+            return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad,)

         # Default: use modified batch conjugate gradients to compute these terms
         # See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165
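
Both branches (Cholesky below the size threshold, modified batch conjugate gradients above it) target the same two quantities: rhs^T K^{-1} rhs and log det K. A dense cross-check of the return values, again assuming lazify from gpytorch.lazy:

import torch
from gpytorch.lazy import lazify

K = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
rhs = torch.tensor([[1.0], [2.0]])
inv_quad, logdet = lazify(K).inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)
print(torch.allclose(inv_quad, (rhs.t() @ torch.linalg.solve(K, rhs)).squeeze(), atol=1e-4))
print(torch.allclose(logdet, torch.logdet(K), atol=1e-4))
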
@@ -1988,7 +1998,7 @@ def zero_mean_mvn_samples(self, num_samples):

         if settings.ciq_samples.on():
             base_samples = torch.randn(
-                *self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device
+                *self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device,
             )
             base_samples = base_samples.permute(-1, *range(self.dim() - 1)).contiguous()
             base_samples = base_samples.unsqueeze(-1)
@@ -2008,7 +2018,7 @@ def zero_mean_mvn_samples(self, num_samples):
             covar_root = self.root_decomposition().root

             base_samples = torch.randn(
-                *self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device
+                *self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device,
             )
             samples = covar_root.matmul(base_samples).permute(-1, *range(self.dim() - 1)).contiguous()

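
The trailing-comma change above is cosmetic; the sampling identity this method relies on is that if L L^T = K and eps ~ N(0, I), then L eps has covariance K. A dense check (torch.cov needs a reasonably recent torch):

import torch

torch.manual_seed(0)
K = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
L = torch.linalg.cholesky(K)
eps = torch.randn(2, 100000)  # (dim, num_samples), matching the diff's layout
samples = L @ eps             # each column is one zero-mean draw
print(torch.cov(samples))     # ~= K for large num_samples
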