Skip to content

Commit 1f09133

Browse files
committed
apply type hint convention for python >= 3.10
1 parent 11e6ce7 commit 1f09133

File tree

7 files changed

+33
-44
lines changed

7 files changed

+33
-44
lines changed

linopy/common.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def get_from_iterable(lst: DimsLike | None, index: int) -> Any | None:
117117
"""
118118
if lst is None:
119119
return None
120-
if isinstance(lst, (Sequence, Iterable)):
120+
if isinstance(lst, Sequence | Iterable):
121121
lst = list(lst)
122122
else:
123123
lst = [lst]
@@ -210,7 +210,7 @@ def numpy_to_dataarray(
210210
return DataArray(arr.item(), coords=coords, dims=dims, **kwargs)
211211

212212
ndim = max(arr.ndim, 0 if coords is None else len(coords))
213-
if isinstance(dims, (Iterable, Sequence)):
213+
if isinstance(dims, Iterable | Sequence):
214214
dims = list(dims)
215215
elif dims is not None:
216216
dims = [dims]
@@ -250,11 +250,11 @@ def as_dataarray(
250250
DataArray:
251251
The converted DataArray.
252252
"""
253-
if isinstance(arr, (pd.Series, pd.DataFrame)):
253+
if isinstance(arr, pd.Series | pd.DataFrame):
254254
arr = pandas_to_dataarray(arr, coords=coords, dims=dims, **kwargs)
255255
elif isinstance(arr, np.ndarray):
256256
arr = numpy_to_dataarray(arr, coords=coords, dims=dims, **kwargs)
257-
elif isinstance(arr, (np.number, int, float, str, bool, list)):
257+
elif isinstance(arr, np.number | int | float | str | bool | list):
258258
arr = DataArray(arr, coords=coords, dims=dims, **kwargs)
259259

260260
elif not isinstance(arr, DataArray):
@@ -493,7 +493,7 @@ def fill_missing_coords(
493493
494494
"""
495495
ds = ds.copy()
496-
if not isinstance(ds, (Dataset, DataArray)):
496+
if not isinstance(ds, Dataset | DataArray):
497497
raise TypeError(f"Expected xarray.DataArray or xarray.Dataset, got {type(ds)}.")
498498

499499
skip_dims = [] if fill_helper_dims else HELPER_DIMS
@@ -807,7 +807,7 @@ def print_coord(coord: dict[str, Any] | Iterable[Any]) -> str:
807807
# Convert each coordinate component to string
808808
formatted = []
809809
for value in values:
810-
if isinstance(value, (list, tuple)):
810+
if isinstance(value, list | tuple):
811811
formatted.append(f"({', '.join(str(x) for x in value)})")
812812
else:
813813
formatted.append(str(value))
@@ -946,11 +946,9 @@ def is_constant(func: Callable[..., Any]) -> Callable[..., Any]:
946946
def wrapper(self: Any, arg: Any) -> Any:
947947
if isinstance(
948948
arg,
949-
(
950-
variables.Variable,
951-
variables.ScalarVariable,
952-
expressions.LinearExpression,
953-
),
949+
variables.Variable
950+
| variables.ScalarVariable
951+
| expressions.LinearExpression,
954952
):
955953
raise TypeError(f"Assigned rhs must be a constant, got {type(arg)}).")
956954
return func(self, arg)
@@ -1061,7 +1059,7 @@ def align(
10611059
finisher: list[partial[Any] | Callable[[Any], Any]] = []
10621060
das: list[Any] = []
10631061
for obj in objects:
1064-
if isinstance(obj, (LinearExpression, QuadraticExpression)):
1062+
if isinstance(obj, LinearExpression | QuadraticExpression):
10651063
finisher.append(partial(obj.__class__, model=obj.model))
10661064
das.append(obj.data)
10671065
elif isinstance(obj, Variable):

linopy/constraints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ def get_name_by_label(self, label: int | float) -> str:
940940
name : str
941941
Name of the containing constraint.
942942
"""
943-
if not isinstance(label, (float, int)) or label < 0:
943+
if not isinstance(label, float | int) or label < 0:
944944
raise ValueError("Label must be a positive number.")
945945
for name, ds in self.items():
946946
if label in ds.labels:
@@ -1084,7 +1084,7 @@ def __init__(
10841084
"""
10851085
Initialize a anonymous scalar constraint.
10861086
"""
1087-
if not isinstance(rhs, (int, float, np.floating, np.integer)):
1087+
if not isinstance(rhs, int | float | np.floating | np.integer):
10881088
raise TypeError(f"Assigned rhs must be a constant, got {type(rhs)}).")
10891089
self._lhs = lhs
10901090
self._sign = sign

linopy/matrices.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def create_vector(
3232
"""Create a vector of a size equal to the maximum index plus one."""
3333
if shape is None:
3434
max_value = indices.max()
35-
if not isinstance(max_value, (np.integer, int)):
35+
if not isinstance(max_value, np.integer | int):
3636
raise ValueError("Indices must be integers.")
3737
shape = max_value + 1
3838
vector = np.full(shape, fill_value)

linopy/monkey_patch_xarray.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@ def __mul__(
2929
) -> DataArray | NotImplementedType:
3030
if isinstance(
3131
other,
32-
(
33-
variables.Variable,
34-
expressions.LinearExpression,
35-
expressions.QuadraticExpression,
36-
),
32+
variables.Variable
33+
| expressions.LinearExpression
34+
| expressions.QuadraticExpression,
3735
):
3836
return NotImplemented
3937
return unpatched_method(da, other)

linopy/objective.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,11 @@ def expression(
175175
"""
176176
Sets the expression of the objective.
177177
"""
178-
if isinstance(expr, (list, tuple)):
178+
if isinstance(expr, list | tuple):
179179
expr = self.model.linexpr(*expr)
180180

181181
if not isinstance(
182-
expr, (expressions.LinearExpression, expressions.QuadraticExpression)
182+
expr, expressions.LinearExpression | expressions.QuadraticExpression
183183
):
184184
raise ValueError(
185185
f"Invalid type of `expr` ({type(expr)})."
@@ -263,7 +263,7 @@ def __sub__(self, expr: LinearExpression | Objective) -> Objective:
263263

264264
def __mul__(self, expr: ConstantLike) -> Objective:
265265
# only allow scalar multiplication
266-
if not isinstance(expr, (int, float, np.floating, np.integer)):
266+
if not isinstance(expr, int | float | np.floating | np.integer):
267267
raise ValueError("Invalid type for multiplication.")
268268
return Objective(self.expression * expr, self.model, self.sense)
269269

@@ -272,6 +272,6 @@ def __neg__(self) -> Objective:
272272

273273
def __truediv__(self, expr: ConstantLike) -> Objective:
274274
# only allow scalar division
275-
if not isinstance(expr, (int, float, np.floating, np.integer)):
275+
if not isinstance(expr, int | float | np.floating | np.integer):
276276
raise ValueError("Invalid type for division.")
277277
return Objective(self.expression / expr, self.model, self.sense)

linopy/types.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import sys
43
from collections.abc import Hashable, Iterable, Mapping, Sequence
54
from pathlib import Path
65
from typing import TYPE_CHECKING, Union
@@ -11,12 +10,6 @@
1110
from xarray import DataArray
1211
from xarray.core.coordinates import DataArrayCoordinates, DatasetCoordinates
1312

14-
if sys.version_info >= (3, 10):
15-
from types import EllipsisType, NotImplementedType
16-
else:
17-
EllipsisType = type(Ellipsis)
18-
NotImplementedType = type(NotImplemented)
19-
2013
if TYPE_CHECKING:
2114
from linopy.constraints import AnonymousScalarConstraint, Constraint
2215
from linopy.expressions import (
@@ -27,15 +20,15 @@
2720
from linopy.variables import ScalarVariable, Variable
2821

2922
# Type aliases using Union for Python 3.9 compatibility
30-
CoordsLike = Union[
23+
CoordsLike = Union[ # noqa: UP007
3124
Sequence[Sequence | Index | DataArray],
3225
Mapping,
3326
DataArrayCoordinates,
3427
DatasetCoordinates,
3528
]
36-
DimsLike = Union[str, Iterable[Hashable]]
29+
DimsLike = Union[str, Iterable[Hashable]] # noqa: UP007
3730

38-
ConstantLike = Union[
31+
ConstantLike = Union[ # noqa: UP007
3932
int,
4033
float,
4134
numpy.floating,
@@ -45,14 +38,14 @@
4538
Series,
4639
DataFrame,
4740
]
48-
SignLike = Union[str, numpy.ndarray, DataArray, Series, DataFrame]
41+
SignLike = Union[str, numpy.ndarray, DataArray, Series, DataFrame] # noqa: UP007
4942
VariableLike = Union["ScalarVariable", "Variable"]
5043
ExpressionLike = Union[
5144
"ScalarLinearExpression",
5245
"LinearExpression",
5346
"QuadraticExpression",
5447
]
5548
ConstraintLike = Union["Constraint", "AnonymousScalarConstraint"]
56-
MaskLike = Union[numpy.ndarray, DataArray, Series, DataFrame]
57-
SideLike = Union[ConstantLike, VariableLike, ExpressionLike]
58-
PathLike = Union[str, Path]
49+
MaskLike = Union[numpy.ndarray, DataArray, Series, DataFrame] # noqa: UP007
50+
SideLike = Union[ConstantLike, VariableLike, ExpressionLike] # noqa: UP007
51+
PathLike = Union[str, Path] # noqa: UP007

linopy/variables.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def __mul__(self, other: SideLike) -> ExpressionLike:
400400
Multiply variables with a coefficient, variable, or expression.
401401
"""
402402
try:
403-
if isinstance(other, (Variable, ScalarVariable)):
403+
if isinstance(other, Variable | ScalarVariable):
404404
return self.to_linexpr() * other
405405

406406
return self.to_linexpr(other)
@@ -449,7 +449,7 @@ def __div__(
449449
"""
450450
Divide variables with a coefficient.
451451
"""
452-
if isinstance(other, (expressions.LinearExpression, Variable)):
452+
if isinstance(other, expressions.LinearExpression | Variable):
453453
raise TypeError(
454454
"unsupported operand type(s) for /: "
455455
f"{type(self)} and {type(other)}. "
@@ -1028,7 +1028,7 @@ def where(
10281028
_other = other.data
10291029
elif isinstance(other, ScalarVariable):
10301030
_other = {"labels": other.label, "lower": other.lower, "upper": other.upper}
1031-
elif isinstance(other, (dict, Dataset)):
1031+
elif isinstance(other, dict | Dataset):
10321032
_other = other
10331033
else:
10341034
raise ValueError(
@@ -1432,7 +1432,7 @@ def get_name_by_label(self, label: int) -> str:
14321432
name : str
14331433
Name of the containing variable.
14341434
"""
1435-
if not isinstance(label, (float, int, np.integer)) or label < 0:
1435+
if not isinstance(label, float | int | np.integer) or label < 0:
14361436
raise ValueError("Label must be a positive number.")
14371437
for name, labels in self.labels.items():
14381438
if label in labels:
@@ -1564,7 +1564,7 @@ def model(self) -> Model:
15641564
return self._model
15651565

15661566
def to_scalar_linexpr(self, coeff: int | float = 1) -> ScalarLinearExpression:
1567-
if not isinstance(coeff, (int, np.integer, float)):
1567+
if not isinstance(coeff, int | np.integer | float):
15681568
raise TypeError(f"Coefficient must be a numeric value, got {type(coeff)}.")
15691569
return expressions.ScalarLinearExpression((coeff,), (self.label,), self.model)
15701570

@@ -1588,7 +1588,7 @@ def __mul__(self, coeff: int | float) -> ScalarLinearExpression:
15881588
return self.to_scalar_linexpr(coeff)
15891589

15901590
def __rmul__(self, coeff: int | float) -> ScalarLinearExpression:
1591-
if isinstance(coeff, (Variable, ScalarVariable)):
1591+
if isinstance(coeff, Variable | ScalarVariable):
15921592
return NotImplemented
15931593
return self.to_scalar_linexpr(coeff)
15941594

0 commit comments

Comments
 (0)