Skip to content

Commit c72d9f5

Browse files
authored
Upgrade to use py3.10 features (#2712)
1 parent 67f59e4 commit c72d9f5

File tree

197 files changed

+729
-509
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

197 files changed

+729
-509
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ repos:
1111
args: [--fix=lf]
1212
- id: trailing-whitespace
1313
- id: debug-statements
14+
- repo: https://github.com/asottile/pyupgrade
15+
rev: v3.19.1
16+
hooks:
17+
- id: pyupgrade
18+
args: [--py310-plus]
19+
exclude: ^(examples/.*)|(docs/.*)
1420
- repo: https://github.com/pycqa/flake8
1521
rev: 7.3.0
1622
hooks:

gpytorch/__init__.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
from typing import Optional, Tuple, Union
3+
from __future__ import annotations
44

55
import linear_operator
66
import torch
@@ -28,7 +28,7 @@
2828
from .mlls import ExactMarginalLogLikelihood
2929
from .module import Module
3030

31-
Anysor = Union[LinearOperator, Tensor]
31+
Anysor = LinearOperator | Tensor
3232

3333

3434
def add_diagonal(input: Anysor, diag: Tensor) -> LinearOperator:
@@ -58,7 +58,7 @@ def add_jitter(input: Anysor, jitter_val: float = 1e-3) -> Anysor:
5858
return linear_operator.add_jitter(input=input, jitter_val=jitter_val)
5959

6060

61-
def diagonalization(input: Anysor, method: Optional[str] = None) -> Tuple[Tensor, Tensor]:
61+
def diagonalization(input: Anysor, method: str | None = None) -> tuple[Tensor, Tensor]:
6262
r"""
6363
Returns a (usually partial) diagonalization of a symmetric positive definite matrix (or batch of matrices).
6464
:math:`\mathbf A`.
@@ -74,7 +74,7 @@ def diagonalization(input: Anysor, method: Optional[str] = None) -> Tuple[Tensor
7474

7575

7676
def dsmm(
77-
sparse_mat: Union[torch.sparse.HalfTensor, torch.sparse.FloatTensor, torch.sparse.DoubleTensor],
77+
sparse_mat: torch.sparse.HalfTensor | torch.sparse.FloatTensor | torch.sparse.DoubleTensor,
7878
dense_mat: Tensor,
7979
) -> Tensor:
8080
r"""
@@ -117,10 +117,10 @@ def inv_quad(input: Anysor, inv_quad_rhs: Tensor, reduce_inv_quad: bool = True)
117117

118118
def inv_quad_logdet(
119119
input: Anysor,
120-
inv_quad_rhs: Optional[Tensor] = None,
120+
inv_quad_rhs: Tensor | None = None,
121121
logdet: bool = False,
122122
reduce_inv_quad: bool = True,
123-
) -> Tuple[Tensor, Tensor]:
123+
) -> tuple[Tensor, Tensor]:
124124
r"""
125125
Calls both :func:`inv_quad_logdet` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`.
126126
However, calling this method is far more efficient and stable than calling each method independently.
@@ -146,9 +146,9 @@ def inv_quad_logdet(
146146
def pivoted_cholesky(
147147
input: Anysor,
148148
rank: int,
149-
error_tol: Optional[float] = None,
149+
error_tol: float | None = None,
150150
return_pivots: bool = False,
151-
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
151+
) -> Tensor | tuple[Tensor, Tensor]:
152152
r"""
153153
Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices).
154154
:math:`\mathbf L \mathbf L^\top = \mathbf A`.
@@ -173,7 +173,7 @@ def pivoted_cholesky(
173173
return linear_operator.pivoted_cholesky(input=input, rank=rank, return_pivots=return_pivots)
174174

175175

176-
def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOperator:
176+
def root_decomposition(input: Anysor, method: str | None = None) -> LinearOperator:
177177
r"""
178178
Returns a (usually low-rank) root decomposition linear operator of the
179179
positive definite matrix (or batch of matrices) :math:`\mathbf A`.
@@ -190,9 +190,9 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe
190190

191191
def root_inv_decomposition(
192192
input: Anysor,
193-
initial_vectors: Optional[Tensor] = None,
194-
test_vectors: Optional[Tensor] = None,
195-
method: Optional[str] = None,
193+
initial_vectors: Tensor | None = None,
194+
test_vectors: Tensor | None = None,
195+
method: str | None = None,
196196
) -> LinearOperator:
197197
r"""
198198
Returns a (usually low-rank) inverse root decomposition linear operator
@@ -217,7 +217,7 @@ def root_inv_decomposition(
217217
)
218218

219219

220-
def solve(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) -> Tensor:
220+
def solve(input: Anysor, rhs: Tensor, lhs: Tensor | None = None) -> Tensor:
221221
r"""
222222
Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`,
223223
computes a linear solve with right hand side :math:`\mathbf R`:
@@ -249,7 +249,7 @@ def solve(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) -> Tensor:
249249
return linear_operator.solve(input=input, rhs=rhs, lhs=lhs)
250250

251251

252-
def sqrt_inv_matmul(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) -> Tensor:
252+
def sqrt_inv_matmul(input: Anysor, rhs: Tensor, lhs: Tensor | None = None) -> Tensor:
253253
r"""
254254
Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`
255255
and a right hand size :math:`\mathbf R`,

gpytorch/beta_features.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
#!/usr/bin/env python3
22

3+
from __future__ import annotations
4+
35
import warnings
46

57
from .settings import _feature_flag, _value_context
68

79

8-
class _moved_beta_feature(object):
10+
class _moved_beta_feature:
911
def __init__(self, new_cls, orig_name=None):
1012
self.new_cls = new_cls
11-
self.orig_name = orig_name if orig_name is not None else "gpytorch.settings.{}".format(new_cls.__name__)
13+
self.orig_name = orig_name if orig_name is not None else f"gpytorch.settings.{new_cls.__name__}"
1214

1315
def __call__(self, *args, **kwargs):
1416
warnings.warn(
15-
"`{}` has moved to `gpytorch.settings.{}`.".format(self.orig_name, self.new_cls.__name__),
17+
f"`{self.orig_name}` has moved to `gpytorch.settings.{self.new_cls.__name__}`.",
1618
DeprecationWarning,
1719
)
1820
return self.new_cls(*args, **kwargs)
@@ -55,7 +57,5 @@ class default_preconditioner(_feature_flag):
5557
Add a diagonal correction to scalable inducing point methods
5658
"""
5759

58-
pass
59-
6060

6161
__all__ = ["checkpoint_kernel", "default_preconditioner"]

gpytorch/constraints/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from .constraints import GreaterThan, Interval, LessThan, Positive
24

35
__all__ = ["GreaterThan", "Interval", "LessThan", "Positive"]

gpytorch/constraints/constraints.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import math
6-
from typing import Optional
76

87
import torch
98
from torch import sigmoid, Tensor
@@ -87,7 +86,7 @@ def check(self, tensor) -> bool:
8786

8887
def check_raw(self, tensor) -> bool:
8988
return bool(
90-
torch.all((self.transform(tensor) <= self.upper_bound))
89+
torch.all(self.transform(tensor) <= self.upper_bound)
9190
and torch.all(self.transform(tensor) >= self.lower_bound)
9291
)
9392

@@ -137,7 +136,7 @@ def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
137136
return tensor
138137

139138
@property
140-
def initial_value(self) -> Optional[Tensor]:
139+
def initial_value(self) -> Tensor | None:
141140
"""
142141
The initial parameter value (if specified, None otherwise)
143142
"""

gpytorch/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
from __future__ import annotations
4+
35
from .delta import Delta
46
from .distribution import Distribution
57
from .multitask_multivariate_normal import MultitaskMultivariateNormal

gpytorch/distributions/delta.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
from __future__ import annotations
4+
35
import numbers
46

57
import torch
@@ -34,14 +36,14 @@ class Delta(Distribution):
3436

3537
def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None):
3638
if event_dim > v.dim():
37-
raise ValueError("Expected event_dim <= v.dim(), actual {} vs {}".format(event_dim, v.dim()))
39+
raise ValueError(f"Expected event_dim <= v.dim(), actual {event_dim} vs {v.dim()}")
3840
batch_dim = v.dim() - event_dim
3941
batch_shape = v.shape[:batch_dim]
4042
event_shape = v.shape[batch_dim:]
4143
if isinstance(log_density, numbers.Number):
4244
log_density = torch.full(batch_shape, log_density, dtype=v.dtype, device=v.device)
4345
elif validate_args and log_density.shape != batch_shape:
44-
raise ValueError("Expected log_density.shape = {}, actual {}".format(log_density.shape, batch_shape))
46+
raise ValueError(f"Expected log_density.shape = {log_density.shape}, actual {batch_shape}")
4547
self.v = v
4648
self.log_density = log_density
4749
super().__init__(batch_shape, event_shape, validate_args=validate_args)

gpytorch/distributions/distribution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
from __future__ import annotations
4+
35
from torch.distributions import Distribution as TDistribution
46

57

gpytorch/distributions/multitask_multivariate_normal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
from __future__ import annotations
4+
35
import torch
46
from linear_operator import LinearOperator, to_linear_operator
57
from linear_operator.operators import (

gpytorch/distributions/multivariate_normal.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import math
66
import warnings
77
from numbers import Number
8-
from typing import Optional, Tuple, Union
98

109
import torch
1110
from linear_operator import to_dense, to_linear_operator
@@ -42,7 +41,7 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
4241
:ivar torch.Tensor variance: The variance.
4342
"""
4443

45-
def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
44+
def __init__(self, mean: Tensor, covariance_matrix: Tensor | LinearOperator, validate_args: bool = False):
4645
self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
4746
if self._islazy:
4847
if validate_args:
@@ -78,7 +77,7 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size
7877
return sample_shape + self._batch_shape + self.base_sample_shape
7978

8079
@staticmethod
81-
def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
80+
def _repr_sizes(mean: Tensor, covariance_matrix: Tensor | LinearOperator) -> str:
8281
return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"
8382

8483
@property
@@ -119,7 +118,7 @@ def covariance_matrix(self) -> Tensor:
119118
else:
120119
return super().covariance_matrix
121120

122-
def confidence_region(self) -> Tuple[Tensor, Tensor]:
121+
def confidence_region(self) -> tuple[Tensor, Tensor]:
123122
"""
124123
Returns 2 standard deviations above and below the mean.
125124
@@ -252,7 +251,7 @@ def log_prob(self, value: Tensor) -> Tensor:
252251
res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
253252
return res
254253

255-
def rsample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional[Tensor] = None) -> Tensor:
254+
def rsample(self, sample_shape: torch.Size = torch.Size(), base_samples: Tensor | None = None) -> Tensor:
256255
r"""
257256
Generates a `sample_shape` shaped reparameterized sample or `sample_shape`
258257
shaped batch of reparameterized samples if the distribution parameters
@@ -320,7 +319,7 @@ def rsample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optiona
320319

321320
return res
322321

323-
def sample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional[Tensor] = None) -> Tensor:
322+
def sample(self, sample_shape: torch.Size = torch.Size(), base_samples: Tensor | None = None) -> Tensor:
324323
r"""
325324
Generates a `sample_shape` shaped sample or `sample_shape`
326325
shaped batch of samples if the distribution parameters
@@ -391,7 +390,7 @@ def __add__(self, other: MultivariateNormal) -> MultivariateNormal:
391390
elif isinstance(other, int) or isinstance(other, float):
392391
return self.__class__(self.mean + other, self.lazy_covariance_matrix)
393392
else:
394-
raise RuntimeError("Unsupported type {} for addition w/ MultivariateNormal".format(type(other)))
393+
raise RuntimeError(f"Unsupported type {type(other)} for addition w/ MultivariateNormal")
395394

396395
def __getitem__(self, idx) -> MultivariateNormal:
397396
r"""

0 commit comments

Comments
 (0)