1919from ..functions ._matmul import Matmul
2020from ..functions ._root_decomposition import RootDecomposition
2121from ..functions ._sqrt_inv_matmul import SqrtInvMatmul
22- from ..utils .broadcasting import (
23- _matmul_broadcast_shape ,
24- _mul_broadcast_shape ,
25- _to_helper ,
26- )
22+ from ..utils .broadcasting import _matmul_broadcast_shape , _mul_broadcast_shape , _to_helper
2723from ..utils .cholesky import psd_safe_cholesky
2824from ..utils .deprecation import _deprecate_renamed_methods
2925from ..utils .errors import CachingError
30- from ..utils .getitem import (
31- _compute_getitem_size ,
32- _convert_indices_to_tensors ,
33- _is_noop_index ,
34- _noop_index ,
35- )
26+ from ..utils .getitem import _compute_getitem_size , _convert_indices_to_tensors , _is_noop_index , _noop_index
3627from ..utils .lanczos import _postprocess_lanczos_root_inv_decomp
37- from ..utils .memoize import (
38- _is_in_cache_ignore_all_args ,
39- _is_in_cache_ignore_args ,
40- add_to_cache ,
41- cached ,
42- pop_from_cache ,
43- )
28+ from ..utils .memoize import _is_in_cache_ignore_all_args , _is_in_cache_ignore_args , add_to_cache , cached , pop_from_cache
4429from ..utils .pinverse import stable_pinverse
4530from ..utils .pivoted_cholesky import pivoted_cholesky
4631from ..utils .warnings import NumericalWarning
@@ -261,11 +246,7 @@ def _getitem(self, row_index, col_index, *batch_indices):
261246 from . import InterpolatedLazyTensor
262247
263248 res = InterpolatedLazyTensor (
264- self ,
265- row_interp_indices ,
266- row_interp_values ,
267- col_interp_indices ,
268- col_interp_values ,
249+ self , row_interp_indices , row_interp_values , col_interp_indices , col_interp_values ,
269250 )
270251 return res ._getitem (row_index , col_index , * batch_indices )
271252
@@ -339,11 +320,7 @@ def _get_indices(self, row_index, col_index, *batch_indices):
339320
340321 res = (
341322 InterpolatedLazyTensor (
342- base_lazy_tensor ,
343- row_interp_indices ,
344- row_interp_values ,
345- col_interp_indices ,
346- col_interp_values ,
323+ base_lazy_tensor , row_interp_indices , row_interp_values , col_interp_indices , col_interp_values ,
347324 )
348325 .evaluate ()
349326 .squeeze (- 2 )
@@ -543,10 +520,7 @@ def _mul_matrix(self, other):
543520 else :
544521 left_lazy_tensor = self if self ._root_decomposition_size () < other ._root_decomposition_size () else other
545522 right_lazy_tensor = other if left_lazy_tensor is self else self
546- return MulLazyTensor (
547- left_lazy_tensor .root_decomposition (),
548- right_lazy_tensor .root_decomposition (),
549- )
523+ return MulLazyTensor (left_lazy_tensor .root_decomposition (), right_lazy_tensor .root_decomposition (),)
550524
551525 def _preconditioner (self ):
552526 """
@@ -587,10 +561,7 @@ def _prod_batch(self, dim):
587561 shape = list (roots .shape )
588562 shape [dim ] = 1
589563 extra_root = torch .full (
590- shape ,
591- dtype = self .dtype ,
592- device = self .device ,
593- fill_value = (1.0 / math .sqrt (self .size (- 2 ))),
564+ shape , dtype = self .dtype , device = self .device , fill_value = (1.0 / math .sqrt (self .size (- 2 ))),
594565 )
595566 roots = torch .cat ([roots , extra_root ], dim )
596567 num_batch += 1
@@ -767,12 +738,7 @@ def add_jitter(self, jitter_val=1e-3):
767738 return self .add_diag (diag )
768739
769740 def cat_rows (
770- self ,
771- cross_mat ,
772- new_mat ,
773- generate_roots = True ,
774- generate_inv_roots = True ,
775- ** root_decomp_kwargs ,
741+ self , cross_mat , new_mat , generate_roots = True , generate_inv_roots = True , ** root_decomp_kwargs ,
776742 ):
777743 """
778744 Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g.
@@ -815,8 +781,7 @@ def cat_rows(
815781
816782 if not generate_roots and generate_inv_roots :
817783 warnings .warn (
818- "root_inv_decomposition is only generated when " "root_decomposition is generated." ,
819- UserWarning ,
784+ "root_inv_decomposition is only generated when " "root_decomposition is generated." , UserWarning ,
820785 )
821786 B_ , B = cross_mat , lazify (cross_mat )
822787 D = lazify (new_mat )
@@ -835,11 +800,7 @@ def cat_rows(
835800 # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
836801 # don't create one
837802 has_roots = any (
838- _is_in_cache_ignore_args (self , key )
839- for key in (
840- "root_decomposition" ,
841- "root_inv_decomposition" ,
842- )
803+ _is_in_cache_ignore_args (self , key ) for key in ("root_decomposition" , "root_inv_decomposition" ,)
843804 )
844805 if not generate_roots and not has_roots :
845806 return new_lazy_tensor
@@ -870,9 +831,7 @@ def cat_rows(
870831 # otherwise we use the pseudo-inverse of Z as new inv root
871832 new_inv_root = stable_pinverse (new_root ).transpose (- 2 , - 1 )
872833 add_to_cache (
873- new_lazy_tensor ,
874- "root_inv_decomposition" ,
875- RootLazyTensor (lazify (new_inv_root )),
834+ new_lazy_tensor , "root_inv_decomposition" , RootLazyTensor (lazify (new_inv_root )),
876835 )
877836
878837 add_to_cache (new_lazy_tensor , "root_decomposition" , RootLazyTensor (lazify (new_root )))
@@ -917,8 +876,7 @@ def add_low_rank(
917876 new_lazy_tensor = self + lazify (low_rank_mat .matmul (low_rank_mat .transpose (- 1 , - 2 )))
918877 else :
919878 new_lazy_tensor = SumLazyTensor (
920- * self .lazy_tensors ,
921- lazify (low_rank_mat .matmul (low_rank_mat .transpose (- 1 , - 2 ))),
879+ * self .lazy_tensors , lazify (low_rank_mat .matmul (low_rank_mat .transpose (- 1 , - 2 ))),
922880 )
923881
924882 # return as a nonlazy tensor if small enough to reduce memory overhead
@@ -966,12 +924,7 @@ def add_low_rank(
966924 updated_root = torch .cat (
967925 (
968926 current_root .evaluate (),
969- torch .zeros (
970- * current_root .shape [:- 1 ],
971- 1 ,
972- device = current_root .device ,
973- dtype = current_root .dtype ,
974- ),
927+ torch .zeros (* current_root .shape [:- 1 ], 1 , device = current_root .device , dtype = current_root .dtype ,),
975928 ),
976929 dim = - 1 ,
977930 )
@@ -1231,13 +1184,7 @@ def inv_matmul(self, right_tensor, left_tensor=None):
12311184 if left_tensor is None :
12321185 return func .apply (self .representation_tree (), False , right_tensor , * self .representation ())
12331186 else :
1234- return func .apply (
1235- self .representation_tree (),
1236- True ,
1237- left_tensor ,
1238- right_tensor ,
1239- * self .representation (),
1240- )
1187+ return func .apply (self .representation_tree (), True , left_tensor , right_tensor , * self .representation (),)
12411188
12421189 def inv_quad (self , tensor , reduce_inv_quad = True ):
12431190 """
@@ -1304,11 +1251,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
13041251 will_need_cholesky = False
13051252 if will_need_cholesky :
13061253 cholesky = CholLazyTensor (TriangularLazyTensor (self .cholesky ()))
1307- return cholesky .inv_quad_logdet (
1308- inv_quad_rhs = inv_quad_rhs ,
1309- logdet = logdet ,
1310- reduce_inv_quad = reduce_inv_quad ,
1311- )
1254+ return cholesky .inv_quad_logdet (inv_quad_rhs = inv_quad_rhs , logdet = logdet , reduce_inv_quad = reduce_inv_quad ,)
13121255
13131256 # Default: use modified batch conjugate gradients to compute these terms
13141257 # See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165
@@ -1700,8 +1643,7 @@ def root_decomposition(self, method: Optional[str] = None):
17001643 return CholLazyTensor (res )
17011644 except RuntimeError as e :
17021645 warnings .warn (
1703- f"Runtime Error when computing Cholesky decomposition: { e } . Using symeig method." ,
1704- NumericalWarning ,
1646+ f"Runtime Error when computing Cholesky decomposition: { e } . Using symeig method." , NumericalWarning ,
17051647 )
17061648 method = "symeig"
17071649
@@ -2056,11 +1998,7 @@ def zero_mean_mvn_samples(self, num_samples):
20561998
20571999 if settings .ciq_samples .on ():
20582000 base_samples = torch .randn (
2059- * self .batch_shape ,
2060- self .size (- 1 ),
2061- num_samples ,
2062- dtype = self .dtype ,
2063- device = self .device ,
2001+ * self .batch_shape , self .size (- 1 ), num_samples , dtype = self .dtype , device = self .device ,
20642002 )
20652003 base_samples = base_samples .permute (- 1 , * range (self .dim () - 1 )).contiguous ()
20662004 base_samples = base_samples .unsqueeze (- 1 )
@@ -2080,11 +2018,7 @@ def zero_mean_mvn_samples(self, num_samples):
20802018 covar_root = self .root_decomposition ().root
20812019
20822020 base_samples = torch .randn (
2083- * self .batch_shape ,
2084- covar_root .size (- 1 ),
2085- num_samples ,
2086- dtype = self .dtype ,
2087- device = self .device ,
2021+ * self .batch_shape , covar_root .size (- 1 ), num_samples , dtype = self .dtype , device = self .device ,
20882022 )
20892023 samples = covar_root .matmul (base_samples ).permute (- 1 , * range (self .dim () - 1 )).contiguous ()
20902024
@@ -2205,11 +2139,9 @@ def __getitem__(self, index):
22052139 # Alternatively, if we're using tensor indices and losing dimensions, use self._get_indices
22062140 if row_col_are_absorbed :
22072141 # Convert all indices into tensor indices
2208- (
2209- * batch_indices ,
2210- row_index ,
2211- col_index ,
2212- ) = _convert_indices_to_tensors (self , (* batch_indices , row_index , col_index ))
2142+ (* batch_indices , row_index , col_index ,) = _convert_indices_to_tensors (
2143+ self , (* batch_indices , row_index , col_index )
2144+ )
22132145 res = self ._get_indices (row_index , col_index , * batch_indices )
22142146 else :
22152147 res = self ._getitem (row_index , col_index , * batch_indices )
0 commit comments