Skip to content

Commit 8079de6

Browse files
committed
Use PEP 695 style generic types
1 parent 0aaf3fa commit 8079de6

File tree

7 files changed

+39
-45
lines changed

7 files changed

+39
-45
lines changed

src/emsarray/cli/commands/plot.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,30 @@
33
import logging
44
from collections.abc import Callable
55
from pathlib import Path
6-
from typing import Any, TypeVar
6+
from typing import Any, overload
77

88
import emsarray
99
from emsarray.cli import BaseCommand, CommandException
1010

11-
T = TypeVar('T')
12-
1311
logger = logging.getLogger(__name__)
1412

1513

16-
def key_value(arg: str, value_type: Callable = str) -> dict[str, T]:
14+
@overload
15+
def key_value(arg: str) -> dict[str, str]: ... # noqa: E704
16+
@overload
17+
def key_value[T](arg: str, value_type: Callable[[str], T]) -> dict[str, T]: ... # noqa: E704
18+
19+
20+
def key_value[T](arg: str, value_type: Callable[[str], T] | None = None) -> dict[str, T] | dict[str, str]:
1721
try:
1822
name, value = arg.split("=", 2)
1923
except ValueError:
2024
raise argparse.ArgumentTypeError(
2125
"Coordinate / dimension indexes must be given as `name=value` pairs")
22-
return {name: value_type(value)}
26+
if value_type is None:
27+
return {name: value}
28+
else:
29+
return {name: value_type(value)}
2330

2431

2532
class UpdateDict(argparse.Action):

src/emsarray/conventions/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
Refer to each Convention implementation for details.
1515
"""
1616
from ._base import (
17-
Convention, DimensionConvention, GridKind, Index, SpatialIndexItem,
18-
Specificity
17+
Convention, DimensionConvention, SpatialIndexItem, Specificity
1918
)
2019
from ._registry import get_dataset_convention, register_convention
2120
from ._utils import open_dataset
@@ -25,7 +24,7 @@
2524
from .ugrid import UGrid
2625

2726
__all__ = [
28-
"Convention", "DimensionConvention", "GridKind", "Index",
27+
"Convention", "DimensionConvention",
2928
"SpatialIndexItem", "Specificity",
3029
"get_dataset_convention", "register_convention",
3130
"open_dataset",

src/emsarray/conventions/_base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import warnings
77
from collections.abc import Callable, Hashable, Iterable, Sequence
88
from functools import cached_property
9-
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
9+
from typing import TYPE_CHECKING, Any, Literal, cast
1010

1111
import numpy
1212
import shapely
@@ -60,17 +60,17 @@
6060
#:
6161
#: More esoteric cases involving datasets with a potentially unbounded numbers of grids
6262
#: can use a type that supports this instead.
63-
GridKind = TypeVar("GridKind")
63+
# GridKind = TypeVar("GridKind")
6464

6565
#: An :ref:`index <indexing>` to a specific point on a grid in this convention.
6666
#: For conventions with :ref:`multiple grids <grids>` (e.g. cells, edges, and nodes),
6767
#: this should be a tuple whos first element is :data:`.GridKind`.
6868
#: For conventions with a single grid, :data:`.GridKind` is not required.
69-
Index = TypeVar("Index")
69+
# Index = TypeVar("Index")
7070

7171

7272
@dataclasses.dataclass
73-
class SpatialIndexItem(Generic[Index]):
73+
class SpatialIndexItem[Index]:
7474
"""Information about an item in the :class:`~shapely.strtree.STRtree`
7575
spatial index for a dataset.
7676
@@ -124,7 +124,7 @@ class Specificity(enum.IntEnum):
124124
HIGH = 30
125125

126126

127-
class Convention(abc.ABC, Generic[GridKind, Index]):
127+
class Convention[GridKind, Index](abc.ABC):
128128
"""
129129
Each supported geometry convention represents data differently.
130130
The :class:`Convention` class abstracts these differences away,
@@ -1749,7 +1749,7 @@ def hash_geometry(self, hash: "hashlib._Hash") -> None:
17491749
hash_attributes(hash, data_array.attrs)
17501750

17511751

1752-
class DimensionConvention(Convention[GridKind, Index]):
1752+
class DimensionConvention[GridKind, Index](Convention[GridKind, Index]):
17531753
"""
17541754
A Convention subclass where different grid kinds
17551755
are always defined on unique sets of dimension.

src/emsarray/conventions/grid.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from collections.abc import Hashable, Sequence
1010
from contextlib import suppress
1111
from functools import cached_property
12-
from typing import Generic, TypeVar, cast
12+
from typing import cast
1313

1414
import numpy
1515
import xarray
@@ -183,10 +183,7 @@ def size(self) -> int:
183183
return int(numpy.prod(self.shape))
184184

185185

186-
Topology = TypeVar('Topology', bound=CFGridTopology)
187-
188-
189-
class CFGrid(Generic[Topology], DimensionConvention[CFGridKind, CFGridIndex]):
186+
class CFGrid[Topology: CFGridTopology](DimensionConvention[CFGridKind, CFGridIndex]):
190187
"""
191188
A base class for CF grid datasets.
192189
There are two concrete subclasses: :class:`CFGrid1D` and :class:`CFGrid2D`.

src/emsarray/operations/geometry.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pathlib
88
from collections.abc import Generator, Iterable, Iterator
99
from contextlib import contextmanager
10-
from typing import IO, Any, Generic, TypeVar
10+
from typing import IO, Any
1111

1212
import geojson
1313
import shapefile
@@ -16,10 +16,8 @@
1616

1717
from emsarray.types import Pathish
1818

19-
T = TypeVar('T')
2019

21-
22-
class _dumpable_iterator(Generic[T], list):
20+
class _dumpable_iterator[T](list):
2321
"""
2422
Wrap an iterator / generator so it can be used in `json.dumps()`.
2523
No guarantees that it works for anything else!

src/emsarray/transect.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import dataclasses
22
from collections.abc import Callable, Iterable
33
from functools import cached_property
4-
from typing import Any, Generic, cast
4+
from typing import Any, cast
55

66
import cfunits
77
import numpy
@@ -16,7 +16,7 @@
1616
from matplotlib.figure import Figure
1717
from matplotlib.ticker import EngFormatter, Formatter
1818

19-
from emsarray.conventions import Convention, Index
19+
from emsarray.conventions import Convention
2020
from emsarray.plot import _requires_plot, make_plot_title
2121
from emsarray.types import DataArrayOrName, Landmark
2222
from emsarray.utils import move_dimensions_to_end, name_to_data_array
@@ -86,7 +86,7 @@ class TransectPoint:
8686

8787

8888
@dataclasses.dataclass
89-
class TransectSegment(Generic[Index]):
89+
class TransectSegment:
9090
"""
9191
A TransectSegment holds information about each intersecting segment of the
9292
transect path and the dataset cells.
@@ -96,7 +96,6 @@ class TransectSegment(Generic[Index]):
9696
intersection: shapely.LineString
9797
start_distance: float
9898
end_distance: float
99-
index: Index
10099
linear_index: int
101100
polygon: shapely.Polygon
102101

@@ -281,7 +280,7 @@ def points(
281280
return points
282281

283282
@cached_property
284-
def segments(self) -> list[TransectSegment[Index]]:
283+
def segments(self) -> list[TransectSegment]:
285284
"""
286285
A list of :class:`.TransectSegmens` for each intersecting segment of the transect line and the dataset geometry.
287286
Segments are listed in order from the start of the line to the end of the line.
@@ -293,7 +292,6 @@ def segments(self) -> list[TransectSegment[Index]]:
293292

294293
for linear_index in intersecting_indexes:
295294
polygon = self.convention.polygons[linear_index]
296-
index = self.convention.wind_index(linear_index)
297295
for intersection in self._intersect_polygon(polygon):
298296
# The line will have two ends.
299297
# The intersection starts and ends at these points.
@@ -314,7 +312,6 @@ def segments(self) -> list[TransectSegment[Index]]:
314312
intersection=intersection,
315313
start_distance=start[1],
316314
end_distance=end[1],
317-
index=index,
318315
linear_index=linear_index,
319316
polygon=polygon,
320317
))

src/emsarray/utils.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Callable, Hashable, Iterable, Mapping, MutableMapping, Sequence
1919
)
2020
from types import TracebackType
21-
from typing import Any, Literal, TypeVar, cast
21+
from typing import Any, Literal, cast
2222

2323
import cftime
2424
import netCDF4
@@ -35,10 +35,6 @@
3535
DEFAULT_CALENDAR = 'proleptic_gregorian'
3636

3737

38-
_T = TypeVar("_T")
39-
_Exception = TypeVar("_Exception", bound=BaseException)
40-
41-
4238
class PerfTimer:
4339
__slots__ = ('_start', '_stop', 'running')
4440

@@ -56,10 +52,10 @@ def __enter__(self) -> 'PerfTimer':
5652
self._start = time.perf_counter()
5753
return self
5854

59-
def __exit__(
55+
def __exit__[E: BaseException](
6056
self,
61-
exc_type: type[_Exception] | None,
62-
exc_value: _Exception | None,
57+
exc_type: type[E] | None,
58+
exc_value: E | None,
6359
traceback: TracebackType
6460
) -> bool | None:
6561
self._stop = time.perf_counter()
@@ -75,7 +71,7 @@ def elapsed(self) -> float:
7571
return self._stop - self._start
7672

7773

78-
def timed_func(fn: Callable[..., _T]) -> Callable[..., _T]:
74+
def timed_func[F: Callable](fn: F) -> F:
7975
"""
8076
Log the execution time of the decorated function.
8177
Logs "Calling ``<func.__qualname__>``" before the wrapped function is called,
@@ -101,13 +97,13 @@ def polygons(self):
10197
fn_logger = logging.getLogger(fn.__module__)
10298

10399
@functools.wraps(fn)
104-
def wrapper(*args: Any, **kwargs: Any) -> _T:
100+
def wrapper(*args, **kwargs): # type: ignore
105101
fn_logger.debug("Calling %s", fn.__qualname__)
106102
with PerfTimer() as timer:
107103
value = fn(*args, **kwargs)
108104
fn_logger.debug("Completed %s in %fs", fn.__qualname__, timer.elapsed)
109105
return value
110-
return wrapper
106+
return cast(F, wrapper)
111107

112108

113109
def to_netcdf_with_fixes(
@@ -376,7 +372,7 @@ def extract_vars(
376372
return dataset.drop_vars(drop_vars)
377373

378374

379-
def pairwise(iterable: Iterable[_T]) -> Iterable[tuple[_T, _T]]:
375+
def pairwise[T](iterable: Iterable[T]) -> Iterable[tuple[T, T]]:
380376
"""
381377
Iterate over values in an iterator in pairs.
382378
@@ -734,15 +730,15 @@ def __init__(self, extra: str) -> None:
734730
self.extra = extra
735731

736732

737-
def requires_extra(
733+
def requires_extra[T](
738734
extra: str,
739735
import_error: ImportError | None,
740736
exception_class: type[RequiresExtraException] = RequiresExtraException,
741-
) -> Callable[[_T], _T]:
737+
) -> Callable[[T], T]:
742738
if import_error is None:
743739
return lambda fn: fn
744740

745-
def error_decorator(fn: _T) -> _T:
741+
def error_decorator(fn: T) -> T:
746742
@functools.wraps(fn) # type: ignore
747743
def error(*args: Any, **kwargs: Any) -> Any:
748744
raise exception_class(extra) from import_error

0 commit comments

Comments
 (0)