Commit e36ca9b

pre commit formatting
Parent: 0351a1a

1 file changed: 21 additions, 89 deletions

gpytorch/lazy/lazy_tensor.py

@@ -19,28 +19,13 @@
 from ..functions._matmul import Matmul
 from ..functions._root_decomposition import RootDecomposition
 from ..functions._sqrt_inv_matmul import SqrtInvMatmul
-from ..utils.broadcasting import (
-    _matmul_broadcast_shape,
-    _mul_broadcast_shape,
-    _to_helper,
-)
+from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape, _to_helper
 from ..utils.cholesky import psd_safe_cholesky
 from ..utils.deprecation import _deprecate_renamed_methods
 from ..utils.errors import CachingError
-from ..utils.getitem import (
-    _compute_getitem_size,
-    _convert_indices_to_tensors,
-    _is_noop_index,
-    _noop_index,
-)
+from ..utils.getitem import _compute_getitem_size, _convert_indices_to_tensors, _is_noop_index, _noop_index
 from ..utils.lanczos import _postprocess_lanczos_root_inv_decomp
-from ..utils.memoize import (
-    _is_in_cache_ignore_all_args,
-    _is_in_cache_ignore_args,
-    add_to_cache,
-    cached,
-    pop_from_cache,
-)
+from ..utils.memoize import _is_in_cache_ignore_all_args, _is_in_cache_ignore_args, add_to_cache, cached, pop_from_cache
 from ..utils.pinverse import stable_pinverse
 from ..utils.pivoted_cholesky import pivoted_cholesky
 from ..utils.warnings import NumericalWarning
@@ -261,11 +246,7 @@ def _getitem(self, row_index, col_index, *batch_indices):
         from . import InterpolatedLazyTensor

         res = InterpolatedLazyTensor(
-            self,
-            row_interp_indices,
-            row_interp_values,
-            col_interp_indices,
-            col_interp_values,
+            self, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
         )
         return res._getitem(row_index, col_index, *batch_indices)

@@ -339,11 +320,7 @@ def _get_indices(self, row_index, col_index, *batch_indices):

         res = (
             InterpolatedLazyTensor(
-                base_lazy_tensor,
-                row_interp_indices,
-                row_interp_values,
-                col_interp_indices,
-                col_interp_values,
+                base_lazy_tensor, row_interp_indices, row_interp_values, col_interp_indices, col_interp_values,
             )
             .evaluate()
             .squeeze(-2)
@@ -543,10 +520,7 @@ def _mul_matrix(self, other):
         else:
             left_lazy_tensor = self if self._root_decomposition_size() < other._root_decomposition_size() else other
             right_lazy_tensor = other if left_lazy_tensor is self else self
-            return MulLazyTensor(
-                left_lazy_tensor.root_decomposition(),
-                right_lazy_tensor.root_decomposition(),
-            )
+            return MulLazyTensor(left_lazy_tensor.root_decomposition(), right_lazy_tensor.root_decomposition(),)

     def _preconditioner(self):
         """
@@ -587,10 +561,7 @@ def _prod_batch(self, dim):
                 shape = list(roots.shape)
                 shape[dim] = 1
                 extra_root = torch.full(
-                    shape,
-                    dtype=self.dtype,
-                    device=self.device,
-                    fill_value=(1.0 / math.sqrt(self.size(-2))),
+                    shape, dtype=self.dtype, device=self.device, fill_value=(1.0 / math.sqrt(self.size(-2))),
                 )
                 roots = torch.cat([roots, extra_root], dim)
                 num_batch += 1
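
The _prod_batch hunk above only reflows the torch.full call; the logic still pads an odd batch with an extra root so the batch can be reduced pairwise. For orientation, a minimal sketch of the public prod call that dispatches here; the batch sizes and values are illustrative assumptions, not part of this commit:

import torch
from gpytorch.lazy import lazify

# A batch of four 5x5 PSD matrices (made-up values)
a = torch.randn(4, 5, 5)
lt = lazify(a @ a.transpose(-1, -2) + torch.eye(5))

# prod reduces a batch dimension by elementwise matrix products,
# which _prod_batch implements via root decompositions
p = lt.prod(dim=-3)
print(p.shape)  # torch.Size([5, 5])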
@@ -767,12 +738,7 @@ def add_jitter(self, jitter_val=1e-3):
         return self.add_diag(diag)

     def cat_rows(
-        self,
-        cross_mat,
-        new_mat,
-        generate_roots=True,
-        generate_inv_roots=True,
-        **root_decomp_kwargs,
+        self, cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs,
     ):
         """
         Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g.
@@ -815,8 +781,7 @@ def cat_rows(

         if not generate_roots and generate_inv_roots:
             warnings.warn(
-                "root_inv_decomposition is only generated when " "root_decomposition is generated.",
-                UserWarning,
+                "root_inv_decomposition is only generated when " "root_decomposition is generated.", UserWarning,
             )
         B_, B = cross_mat, lazify(cross_mat)
         D = lazify(new_mat)
@@ -835,11 +800,7 @@ def cat_rows(
         # if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
         # don't create one
         has_roots = any(
-            _is_in_cache_ignore_args(self, key)
-            for key in (
-                "root_decomposition",
-                "root_inv_decomposition",
-            )
+            _is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition",)
         )
         if not generate_roots and not has_roots:
             return new_lazy_tensor
@@ -870,9 +831,7 @@ def cat_rows(
             # otherwise we use the pseudo-inverse of Z as new inv root
             new_inv_root = stable_pinverse(new_root).transpose(-2, -1)
             add_to_cache(
-                new_lazy_tensor,
-                "root_inv_decomposition",
-                RootLazyTensor(lazify(new_inv_root)),
+                new_lazy_tensor, "root_inv_decomposition", RootLazyTensor(lazify(new_inv_root)),
             )

         add_to_cache(new_lazy_tensor, "root_decomposition", RootLazyTensor(lazify(new_root)))
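
All three cat_rows hunks are pure call-site reflows; the method still appends new rows and columns while updating any cached root decompositions. A hedged sketch of the call, where the n x k orientation of cross_mat and all values are our assumptions for illustration:

import torch
from gpytorch.lazy import lazify

# Existing 5x5 PSD matrix A (values made up for illustration)
a = torch.randn(5, 5)
A = lazify(a @ a.transpose(-1, -2) + torch.eye(5))

# Append one row/column: cross_mat holds the n x k cross terms B,
# new_mat the k x k block D (orientation assumed from the docstring)
B = torch.randn(5, 1)
D = 2.0 * torch.ones(1, 1)

C = A.cat_rows(B, D)
print(C.shape)  # torch.Size([6, 6])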
@@ -917,8 +876,7 @@ def add_low_rank(
             new_lazy_tensor = self + lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2)))
         else:
             new_lazy_tensor = SumLazyTensor(
-                *self.lazy_tensors,
-                lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))),
+                *self.lazy_tensors, lazify(low_rank_mat.matmul(low_rank_mat.transpose(-1, -2))),
             )

         # return as a nonlazy tensor if small enough to reduce memory overhead
@@ -966,12 +924,7 @@ def add_low_rank(
         updated_root = torch.cat(
             (
                 current_root.evaluate(),
-                torch.zeros(
-                    *current_root.shape[:-1],
-                    1,
-                    device=current_root.device,
-                    dtype=current_root.dtype,
-                ),
+                torch.zeros(*current_root.shape[:-1], 1, device=current_root.device, dtype=current_root.dtype,),
             ),
             dim=-1,
         )
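
The add_low_rank hunks shown here, including the zero column concatenated onto the current root so shapes stay consistent, maintain cached roots for the sum self + V V^T. A brief sketch of the call with invented sizes:

import torch
from gpytorch.lazy import lazify

a = torch.randn(6, 6)
K = lazify(a @ a.transpose(-1, -2) + torch.eye(6))

v = torch.randn(6, 2)     # rank-2 factor V (illustrative)
K_up = K.add_low_rank(v)  # represents K + V V^T

rhs = torch.randn(6, 3)
print(K_up.matmul(rhs).shape)  # torch.Size([6, 3])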
@@ -1231,13 +1184,7 @@ def inv_matmul(self, right_tensor, left_tensor=None):
         if left_tensor is None:
             return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
         else:
-            return func.apply(
-                self.representation_tree(),
-                True,
-                left_tensor,
-                right_tensor,
-                *self.representation(),
-            )
+            return func.apply(self.representation_tree(), True, left_tensor, right_tensor, *self.representation(),)

     def inv_quad(self, tensor, reduce_inv_quad=True):
         """
@@ -1304,11 +1251,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
             will_need_cholesky = False
         if will_need_cholesky:
             cholesky = CholLazyTensor(TriangularLazyTensor(self.cholesky()))
-            return cholesky.inv_quad_logdet(
-                inv_quad_rhs=inv_quad_rhs,
-                logdet=logdet,
-                reduce_inv_quad=reduce_inv_quad,
-            )
+            return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad,)

         # Default: use modified batch conjugate gradients to compute these terms
         # See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165
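
The inv_matmul and inv_quad_logdet hunks above are again only reflows, and the retained comment states the important part: unless a Cholesky factor is cheap or already cached, both solve-type quantities come from modified batch conjugate gradients (the linked NeurIPS 2018 paper) rather than a dense factorization. A hedged sketch of both calls with made-up values:

import torch
from gpytorch.lazy import lazify

a = torch.randn(4, 4)
K = lazify(a @ a.transpose(-1, -2) + torch.eye(4))
rhs = torch.randn(4, 2)
y = torch.randn(4, 1)

solve = K.inv_matmul(rhs)                    # K^{-1} @ rhs
left = torch.randn(3, 4)
fused = K.inv_matmul(rhs, left_tensor=left)  # left @ K^{-1} @ rhs in one fused call
print(torch.allclose(fused, left @ solve, atol=1e-4))  # True

# y^T K^{-1} y and log|K| together, as a Gaussian log-likelihood needs
inv_quad, logdet = K.inv_quad_logdet(inv_quad_rhs=y, logdet=True)
dense = K.evaluate()
print(torch.allclose(logdet, torch.logdet(dense), atol=1e-3))  # True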
@@ -1700,8 +1643,7 @@ def root_decomposition(self, method: Optional[str] = None):
                 return CholLazyTensor(res)
             except RuntimeError as e:
                 warnings.warn(
-                    f"Runtime Error when computing Cholesky decomposition: {e}. Using symeig method.",
-                    NumericalWarning,
+                    f"Runtime Error when computing Cholesky decomposition: {e}. Using symeig method.", NumericalWarning,
                 )
                 method = "symeig"

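The reformatted warning documents the fallback chain: Cholesky is attempted first, and a RuntimeError demotes the method to "symeig". Either path yields a root R with R R^T approximately equal to the original matrix, which a quick, illustrative check confirms:

import torch
from gpytorch.lazy import lazify

a = torch.randn(5, 5)
K = lazify(a @ a.transpose(-1, -2) + torch.eye(5))

# .root is the lazy R; evaluate() materializes it as a dense tensor
R = K.root_decomposition().root.evaluate()
print(torch.allclose(R @ R.transpose(-1, -2), K.evaluate(), atol=1e-4))  # True
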
@@ -2056,11 +1998,7 @@ def zero_mean_mvn_samples(self, num_samples):

         if settings.ciq_samples.on():
             base_samples = torch.randn(
-                *self.batch_shape,
-                self.size(-1),
-                num_samples,
-                dtype=self.dtype,
-                device=self.device,
+                *self.batch_shape, self.size(-1), num_samples, dtype=self.dtype, device=self.device,
             )
             base_samples = base_samples.permute(-1, *range(self.dim() - 1)).contiguous()
             base_samples = base_samples.unsqueeze(-1)
@@ -2080,11 +2018,7 @@ def zero_mean_mvn_samples(self, num_samples):
             covar_root = self.root_decomposition().root

             base_samples = torch.randn(
-                *self.batch_shape,
-                covar_root.size(-1),
-                num_samples,
-                dtype=self.dtype,
-                device=self.device,
+                *self.batch_shape, covar_root.size(-1), num_samples, dtype=self.dtype, device=self.device,
             )
             samples = covar_root.matmul(base_samples).permute(-1, *range(self.dim() - 1)).contiguous()

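Both torch.randn call sites draw base samples z ~ N(0, I) that are then pushed through a root of the covariance, so samples = R z are draws from N(0, K); the permute moves the sample dimension to the front. A sketch with an illustrative sample count:

import torch
from gpytorch.lazy import lazify

a = torch.randn(5, 5)
K = lazify(a @ a.transpose(-1, -2) + torch.eye(5))

samples = K.zero_mean_mvn_samples(10000)
print(samples.shape)  # torch.Size([10000, 5])

# The empirical covariance should approach K as the sample count grows
emp = samples.transpose(-1, -2) @ samples / 10000
print((emp - K.evaluate()).abs().max() < 0.5)  # expected: tensor(True)
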
@@ -2205,11 +2139,9 @@ def __getitem__(self, index):
         # Alternatively, if we're using tensor indices and losing dimensions, use self._get_indices
         if row_col_are_absorbed:
             # Convert all indices into tensor indices
-            (
-                *batch_indices,
-                row_index,
-                col_index,
-            ) = _convert_indices_to_tensors(self, (*batch_indices, row_index, col_index))
+            (*batch_indices, row_index, col_index,) = _convert_indices_to_tensors(
+                self, (*batch_indices, row_index, col_index)
+            )
             res = self._get_indices(row_index, col_index, *batch_indices)
         else:
             res = self._getitem(row_index, col_index, *batch_indices)
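
This last hunk reflows the tuple unpacking in __getitem__, which picks between the two paths named in the comments: slice-like indices keep the row and column dimensions and stay lazy (_getitem), while tensor indices absorb them and return evaluated entries (_get_indices). A sketch assuming torch-style advanced indexing semantics:

import torch
from gpytorch.lazy import lazify

lt = lazify(torch.randn(3, 6, 6))

# Slices keep both matrix dims: the result is still a LazyTensor
sub = lt[..., :4, :4]
print(sub.shape)  # torch.Size([3, 4, 4])

# Paired tensor indices absorb the row/col dims: entries come back evaluated
rows = torch.tensor([0, 1])
cols = torch.tensor([2, 3])
print(lt[:, rows, cols].shape)  # torch.Size([3, 2])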
