diff --git a/docs/api/conventions/interface.rst b/docs/api/conventions/interface.rst index 4665dcb8..b984ecb0 100644 --- a/docs/api/conventions/interface.rst +++ b/docs/api/conventions/interface.rst @@ -7,6 +7,8 @@ Convention interface .. contents:: :local: +.. currentmodule:: emsarray.conventions + Each supported convention implements the :class:`~emsarray.conventions.Convention` interface. @@ -25,8 +27,37 @@ the :class:`~emsarray.conventions.Convention` interface. .. autoclass:: emsarray.conventions.SpatialIndexItem :members: -.. autodata:: emsarray.conventions._base.GridKind -.. autodata:: emsarray.conventions._base.Index +.. type:: GridKind + + Some type that can enumerate the different :ref:`grid types ` + present in a dataset. + This can be an :class:`enum.Enum` listing each different kind of grid. + + :type:`Index` values will be included in the feature properties + of exported geometry from :mod:`emsarray.operations.geometry`. + If the index type includes the grid kind, + the grid kind needs to be JSON serializable. + The easiest way to achieve this is to make your GridKind type subclass :class:`str`: + + .. code-block:: python + + class MyGridKind(str, enum.Enum): + face = 'face' + edge = 'edge' + node = 'node' + + For cases where the convention only supports a single grid, + a singleton enum can be used. + + More esoteric cases involving datasets with a potentially unbounded numbers of grids + can use a type that supports this instead. + +.. type:: Index + + An :ref:`index ` to a specific point on a grid in this convention. + For conventions with :ref:`multiple grids ` (e.g. cells, edges, and nodes), + this should be a tuple whos first element is :type:`.GridKind`. + For conventions with a single grid, :type:`.GridKind` is not required. .. autoclass:: emsarray.conventions.Specificity :members: diff --git a/docs/concepts/grids.rst b/docs/concepts/grids.rst index 8bab53d4..d33ae6d7 100644 --- a/docs/concepts/grids.rst +++ b/docs/concepts/grids.rst @@ -19,7 +19,7 @@ from one face to another. These edges represent another grid. Some conventions also define variables on face vertices, called *nodes*. Nodes represent a third grid. -This is represented by the :data:`~.conventions._base.GridKind` type variable. +This is represented by the :type:`~.conventions.GridKind` type variable. Each of the faces, edges, and nodes define an area, line, or point. These areas, lines, or points exist at some geographic location. diff --git a/docs/concepts/indexing.rst b/docs/concepts/indexing.rst index 151d5ca8..7cfcf27d 100644 --- a/docs/concepts/indexing.rst +++ b/docs/concepts/indexing.rst @@ -16,7 +16,7 @@ As each geometry convention may define a different number of grids, each convention has a different method of indexing data in these grids. These are the *convention native indexes*. Each :class:`~.conventions.Convention` implementation -has its own :data:`~.conventions._base.Index` type. +has its own :type:`~.conventions.Index` type. :mod:`CF grid datasets <.conventions.grid>` have only one grid - faces. Each face can be indexed using two numbers *x* and *y*. diff --git a/docs/releases/development.rst b/docs/releases/development.rst index 8f826de1..fe3651ff 100644 --- a/docs/releases/development.rst +++ b/docs/releases/development.rst @@ -18,3 +18,9 @@ Next release (in development) ``spatial_index()``, ``get_grid_kind_and_size()``, and ``NonIntersectingPoints.indices`` (:pr:`202`). +* Use `PEP 695 `_ style type parameters. + This drops the `Index` and `GridKind` type variables + which were exported in `emsarray.conventions`, + which is a backwards incompatible change + but is difficult to add meaningful backwards compatible support + (:issue:`109`, :pr:`203`) diff --git a/src/emsarray/cli/commands/plot.py b/src/emsarray/cli/commands/plot.py index ec824a8a..0aae53c8 100644 --- a/src/emsarray/cli/commands/plot.py +++ b/src/emsarray/cli/commands/plot.py @@ -3,23 +3,30 @@ import logging from collections.abc import Callable from pathlib import Path -from typing import Any, TypeVar +from typing import Any, overload import emsarray from emsarray.cli import BaseCommand, CommandException -T = TypeVar('T') - logger = logging.getLogger(__name__) -def key_value(arg: str, value_type: Callable = str) -> dict[str, T]: +@overload +def key_value(arg: str) -> dict[str, str]: ... # noqa: E704 +@overload +def key_value[T](arg: str, value_type: Callable[[str], T]) -> dict[str, T]: ... # noqa: E704 + + +def key_value[T](arg: str, value_type: Callable[[str], T] | None = None) -> dict[str, T] | dict[str, str]: try: name, value = arg.split("=", 2) except ValueError: raise argparse.ArgumentTypeError( "Coordinate / dimension indexes must be given as `name=value` pairs") - return {name: value_type(value)} + if value_type is None: + return {name: value} + else: + return {name: value_type(value)} class UpdateDict(argparse.Action): diff --git a/src/emsarray/conventions/__init__.py b/src/emsarray/conventions/__init__.py index 2e7ffd11..97014398 100644 --- a/src/emsarray/conventions/__init__.py +++ b/src/emsarray/conventions/__init__.py @@ -14,8 +14,7 @@ Refer to each Convention implementation for details. """ from ._base import ( - Convention, DimensionConvention, GridKind, Index, SpatialIndexItem, - Specificity + Convention, DimensionConvention, SpatialIndexItem, Specificity ) from ._registry import get_dataset_convention, register_convention from ._utils import open_dataset @@ -25,7 +24,7 @@ from .ugrid import UGrid __all__ = [ - "Convention", "DimensionConvention", "GridKind", "Index", + "Convention", "DimensionConvention", "SpatialIndexItem", "Specificity", "get_dataset_convention", "register_convention", "open_dataset", diff --git a/src/emsarray/conventions/_base.py b/src/emsarray/conventions/_base.py index ac85ef67..2f4d5da5 100644 --- a/src/emsarray/conventions/_base.py +++ b/src/emsarray/conventions/_base.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Callable, Hashable, Iterable, Sequence from functools import cached_property -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, Literal, cast import numpy import shapely @@ -38,39 +38,8 @@ logger = logging.getLogger(__name__) -#: Some type that can enumerate the different :ref:`grid types ` -#: present in a dataset. -#: This can be an :class:`enum.Enum` listing each different kind of grid. -#: -#: :data:`Index` values will be included in the feature properties -#: of exported geometry from :mod:`emsarray.operations.geometry`. -#: If the index type includes the grid kind, -#: the grid kind needs to be JSON serializable. -#: The easiest way to achieve this is to make your GridKind type subclass :class:`str`: -#: -#: .. code-block:: python -#: -#: class MyGridKind(str, enum.Enum): -#: face = 'face' -#: edge = 'edge' -#: node = 'node' -#: -#: For cases where the convention only supports a single grid, -#: a singleton enum can be used. -#: -#: More esoteric cases involving datasets with a potentially unbounded numbers of grids -#: can use a type that supports this instead. -GridKind = TypeVar("GridKind") - -#: An :ref:`index ` to a specific point on a grid in this convention. -#: For conventions with :ref:`multiple grids ` (e.g. cells, edges, and nodes), -#: this should be a tuple whos first element is :data:`.GridKind`. -#: For conventions with a single grid, :data:`.GridKind` is not required. -Index = TypeVar("Index") - - @dataclasses.dataclass -class SpatialIndexItem(Generic[Index]): +class SpatialIndexItem[Index]: """Information about an item in the :class:`~shapely.strtree.STRtree` spatial index for a dataset. @@ -124,7 +93,7 @@ class Specificity(enum.IntEnum): HIGH = 30 -class Convention(abc.ABC, Generic[GridKind, Index]): +class Convention[GridKind, Index](abc.ABC): """ Each supported geometry convention represents data differently. The :class:`Convention` class abstracts these differences away, @@ -1749,7 +1718,7 @@ def hash_geometry(self, hash: "hashlib._Hash") -> None: hash_attributes(hash, data_array.attrs) -class DimensionConvention(Convention[GridKind, Index]): +class DimensionConvention[GridKind, Index](Convention[GridKind, Index]): """ A Convention subclass where different grid kinds are always defined on unique sets of dimension. diff --git a/src/emsarray/conventions/grid.py b/src/emsarray/conventions/grid.py index 2997a2d8..6bc39df7 100644 --- a/src/emsarray/conventions/grid.py +++ b/src/emsarray/conventions/grid.py @@ -9,7 +9,7 @@ from collections.abc import Hashable, Sequence from contextlib import suppress from functools import cached_property -from typing import Generic, TypeVar, cast +from typing import cast import numpy import xarray @@ -183,10 +183,7 @@ def size(self) -> int: return int(numpy.prod(self.shape)) -Topology = TypeVar('Topology', bound=CFGridTopology) - - -class CFGrid(Generic[Topology], DimensionConvention[CFGridKind, CFGridIndex]): +class CFGrid[Topology: CFGridTopology](DimensionConvention[CFGridKind, CFGridIndex]): """ A base class for CF grid datasets. There are two concrete subclasses: :class:`CFGrid1D` and :class:`CFGrid2D`. diff --git a/src/emsarray/operations/geometry.py b/src/emsarray/operations/geometry.py index 62229e8b..3b06cb4a 100644 --- a/src/emsarray/operations/geometry.py +++ b/src/emsarray/operations/geometry.py @@ -7,7 +7,7 @@ import pathlib from collections.abc import Generator, Iterable, Iterator from contextlib import contextmanager -from typing import IO, Any, Generic, TypeVar +from typing import IO, Any import geojson import shapefile @@ -16,10 +16,8 @@ from emsarray.types import Pathish -T = TypeVar('T') - -class _dumpable_iterator(Generic[T], list): +class _dumpable_iterator[T](list): """ Wrap an iterator / generator so it can be used in `json.dumps()`. No guarantees that it works for anything else! diff --git a/src/emsarray/transect.py b/src/emsarray/transect.py index ebc47557..45b31d17 100644 --- a/src/emsarray/transect.py +++ b/src/emsarray/transect.py @@ -1,7 +1,7 @@ import dataclasses from collections.abc import Callable, Iterable from functools import cached_property -from typing import Any, Generic, cast +from typing import Any, cast import cfunits import numpy @@ -16,7 +16,7 @@ from matplotlib.figure import Figure from matplotlib.ticker import EngFormatter, Formatter -from emsarray.conventions import Convention, Index +from emsarray.conventions import Convention from emsarray.plot import _requires_plot, make_plot_title from emsarray.types import DataArrayOrName, Landmark from emsarray.utils import move_dimensions_to_end, name_to_data_array @@ -86,7 +86,7 @@ class TransectPoint: @dataclasses.dataclass -class TransectSegment(Generic[Index]): +class TransectSegment: """ A TransectSegment holds information about each intersecting segment of the transect path and the dataset cells. @@ -96,7 +96,6 @@ class TransectSegment(Generic[Index]): intersection: shapely.LineString start_distance: float end_distance: float - index: Index linear_index: int polygon: shapely.Polygon @@ -281,7 +280,7 @@ def points( return points @cached_property - def segments(self) -> list[TransectSegment[Index]]: + def segments(self) -> list[TransectSegment]: """ A list of :class:`.TransectSegmens` for each intersecting segment of the transect line and the dataset geometry. Segments are listed in order from the start of the line to the end of the line. @@ -293,7 +292,6 @@ def segments(self) -> list[TransectSegment[Index]]: for linear_index in intersecting_indexes: polygon = self.convention.polygons[linear_index] - index = self.convention.wind_index(linear_index) for intersection in self._intersect_polygon(polygon): # The line will have two ends. # The intersection starts and ends at these points. @@ -314,7 +312,6 @@ def segments(self) -> list[TransectSegment[Index]]: intersection=intersection, start_distance=start[1], end_distance=end[1], - index=index, linear_index=linear_index, polygon=polygon, )) diff --git a/src/emsarray/utils.py b/src/emsarray/utils.py index c0f757e9..78f0b980 100644 --- a/src/emsarray/utils.py +++ b/src/emsarray/utils.py @@ -18,7 +18,7 @@ Callable, Hashable, Iterable, Mapping, MutableMapping, Sequence ) from types import TracebackType -from typing import Any, Literal, TypeVar, cast +from typing import Any, Literal, cast import cftime import netCDF4 @@ -35,10 +35,6 @@ DEFAULT_CALENDAR = 'proleptic_gregorian' -_T = TypeVar("_T") -_Exception = TypeVar("_Exception", bound=BaseException) - - class PerfTimer: __slots__ = ('_start', '_stop', 'running') @@ -56,10 +52,10 @@ def __enter__(self) -> 'PerfTimer': self._start = time.perf_counter() return self - def __exit__( + def __exit__[E: BaseException]( self, - exc_type: type[_Exception] | None, - exc_value: _Exception | None, + exc_type: type[E] | None, + exc_value: E | None, traceback: TracebackType ) -> bool | None: self._stop = time.perf_counter() @@ -75,7 +71,7 @@ def elapsed(self) -> float: return self._stop - self._start -def timed_func(fn: Callable[..., _T]) -> Callable[..., _T]: +def timed_func[F: Callable](fn: F) -> F: """ Log the execution time of the decorated function. Logs "Calling ````" before the wrapped function is called, @@ -101,13 +97,13 @@ def polygons(self): fn_logger = logging.getLogger(fn.__module__) @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> _T: + def wrapper(*args, **kwargs): # type: ignore fn_logger.debug("Calling %s", fn.__qualname__) with PerfTimer() as timer: value = fn(*args, **kwargs) fn_logger.debug("Completed %s in %fs", fn.__qualname__, timer.elapsed) return value - return wrapper + return cast(F, wrapper) def to_netcdf_with_fixes( @@ -376,7 +372,7 @@ def extract_vars( return dataset.drop_vars(drop_vars) -def pairwise(iterable: Iterable[_T]) -> Iterable[tuple[_T, _T]]: +def pairwise[T](iterable: Iterable[T]) -> Iterable[tuple[T, T]]: """ Iterate over values in an iterator in pairs. @@ -734,15 +730,15 @@ def __init__(self, extra: str) -> None: self.extra = extra -def requires_extra( +def requires_extra[T]( extra: str, import_error: ImportError | None, exception_class: type[RequiresExtraException] = RequiresExtraException, -) -> Callable[[_T], _T]: +) -> Callable[[T], T]: if import_error is None: return lambda fn: fn - def error_decorator(fn: _T) -> _T: + def error_decorator(fn: T) -> T: @functools.wraps(fn) # type: ignore def error(*args: Any, **kwargs: Any) -> Any: raise exception_class(extra) from import_error