Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions docs/api/conventions/interface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Convention interface
.. contents::
:local:

.. currentmodule:: emsarray.conventions

Each supported convention implements
the :class:`~emsarray.conventions.Convention` interface.

Expand All @@ -25,8 +27,37 @@ the :class:`~emsarray.conventions.Convention` interface.
.. autoclass:: emsarray.conventions.SpatialIndexItem
:members:

.. autodata:: emsarray.conventions._base.GridKind
.. autodata:: emsarray.conventions._base.Index
.. type:: GridKind

Some type that can enumerate the different :ref:`grid types <grids>`
present in a dataset.
This can be an :class:`enum.Enum` listing each different kind of grid.

:type:`Index` values will be included in the feature properties
of exported geometry from :mod:`emsarray.operations.geometry`.
If the index type includes the grid kind,
the grid kind needs to be JSON serializable.
The easiest way to achieve this is to make your GridKind type subclass :class:`str`:

.. code-block:: python

class MyGridKind(str, enum.Enum):
face = 'face'
edge = 'edge'
node = 'node'

For cases where the convention only supports a single grid,
a singleton enum can be used.

More esoteric cases involving datasets with a potentially unbounded numbers of grids
can use a type that supports this instead.

.. type:: Index

An :ref:`index <indexing>` to a specific point on a grid in this convention.
For conventions with :ref:`multiple grids <grids>` (e.g. cells, edges, and nodes),
this should be a tuple whos first element is :type:`.GridKind`.
For conventions with a single grid, :type:`.GridKind` is not required.

.. autoclass:: emsarray.conventions.Specificity
:members:
Expand Down
2 changes: 1 addition & 1 deletion docs/concepts/grids.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from one face to another.
These edges represent another grid.
Some conventions also define variables on face vertices, called *nodes*.
Nodes represent a third grid.
This is represented by the :data:`~.conventions._base.GridKind` type variable.
This is represented by the :type:`~.conventions.GridKind` type variable.

Each of the faces, edges, and nodes define an area, line, or point.
These areas, lines, or points exist at some geographic location.
Expand Down
2 changes: 1 addition & 1 deletion docs/concepts/indexing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ As each geometry convention may define a different number of grids,
each convention has a different method of indexing data in these grids.
These are the *convention native indexes*.
Each :class:`~.conventions.Convention` implementation
has its own :data:`~.conventions._base.Index` type.
has its own :type:`~.conventions.Index` type.

:mod:`CF grid datasets <.conventions.grid>` have only one grid - faces.
Each face can be indexed using two numbers *x* and *y*.
Expand Down
6 changes: 6 additions & 0 deletions docs/releases/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,9 @@ Next release (in development)
``spatial_index()``, ``get_grid_kind_and_size()``,
and ``NonIntersectingPoints.indices``
(:pr:`202`).
* Use `PEP 695 <https://peps.python.org/pep-0695/>`_ style type parameters.
This drops the `Index` and `GridKind` type variables
which were exported in `emsarray.conventions`,
which is a backwards incompatible change
but is difficult to add meaningful backwards compatible support
(:issue:`109`, :pr:`203`)
17 changes: 12 additions & 5 deletions src/emsarray/cli/commands/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,30 @@
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, TypeVar
from typing import Any, overload

import emsarray
from emsarray.cli import BaseCommand, CommandException

T = TypeVar('T')

logger = logging.getLogger(__name__)


def key_value(arg: str, value_type: Callable = str) -> dict[str, T]:
@overload
def key_value(arg: str) -> dict[str, str]: ... # noqa: E704
@overload
def key_value[T](arg: str, value_type: Callable[[str], T]) -> dict[str, T]: ... # noqa: E704


def key_value[T](arg: str, value_type: Callable[[str], T] | None = None) -> dict[str, T] | dict[str, str]:
try:
name, value = arg.split("=", 2)
except ValueError:
raise argparse.ArgumentTypeError(
"Coordinate / dimension indexes must be given as `name=value` pairs")
return {name: value_type(value)}
if value_type is None:
return {name: value}
else:
return {name: value_type(value)}


class UpdateDict(argparse.Action):
Expand Down
5 changes: 2 additions & 3 deletions src/emsarray/conventions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
Refer to each Convention implementation for details.
"""
from ._base import (
Convention, DimensionConvention, GridKind, Index, SpatialIndexItem,
Specificity
Convention, DimensionConvention, SpatialIndexItem, Specificity
)
from ._registry import get_dataset_convention, register_convention
from ._utils import open_dataset
Expand All @@ -25,7 +24,7 @@
from .ugrid import UGrid

__all__ = [
"Convention", "DimensionConvention", "GridKind", "Index",
"Convention", "DimensionConvention",
"SpatialIndexItem", "Specificity",
"get_dataset_convention", "register_convention",
"open_dataset",
Expand Down
39 changes: 4 additions & 35 deletions src/emsarray/conventions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from collections.abc import Callable, Hashable, Iterable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, cast

import numpy
import shapely
Expand Down Expand Up @@ -38,39 +38,8 @@
logger = logging.getLogger(__name__)


#: Some type that can enumerate the different :ref:`grid types <grids>`
#: present in a dataset.
#: This can be an :class:`enum.Enum` listing each different kind of grid.
#:
#: :data:`Index` values will be included in the feature properties
#: of exported geometry from :mod:`emsarray.operations.geometry`.
#: If the index type includes the grid kind,
#: the grid kind needs to be JSON serializable.
#: The easiest way to achieve this is to make your GridKind type subclass :class:`str`:
#:
#: .. code-block:: python
#:
#: class MyGridKind(str, enum.Enum):
#: face = 'face'
#: edge = 'edge'
#: node = 'node'
#:
#: For cases where the convention only supports a single grid,
#: a singleton enum can be used.
#:
#: More esoteric cases involving datasets with a potentially unbounded numbers of grids
#: can use a type that supports this instead.
GridKind = TypeVar("GridKind")

#: An :ref:`index <indexing>` to a specific point on a grid in this convention.
#: For conventions with :ref:`multiple grids <grids>` (e.g. cells, edges, and nodes),
#: this should be a tuple whos first element is :data:`.GridKind`.
#: For conventions with a single grid, :data:`.GridKind` is not required.
Index = TypeVar("Index")


@dataclasses.dataclass
class SpatialIndexItem(Generic[Index]):
class SpatialIndexItem[Index]:
"""Information about an item in the :class:`~shapely.strtree.STRtree`
spatial index for a dataset.

Expand Down Expand Up @@ -124,7 +93,7 @@ class Specificity(enum.IntEnum):
HIGH = 30


class Convention(abc.ABC, Generic[GridKind, Index]):
class Convention[GridKind, Index](abc.ABC):
"""
Each supported geometry convention represents data differently.
The :class:`Convention` class abstracts these differences away,
Expand Down Expand Up @@ -1749,7 +1718,7 @@ def hash_geometry(self, hash: "hashlib._Hash") -> None:
hash_attributes(hash, data_array.attrs)


class DimensionConvention(Convention[GridKind, Index]):
class DimensionConvention[GridKind, Index](Convention[GridKind, Index]):
"""
A Convention subclass where different grid kinds
are always defined on unique sets of dimension.
Expand Down
7 changes: 2 additions & 5 deletions src/emsarray/conventions/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Hashable, Sequence
from contextlib import suppress
from functools import cached_property
from typing import Generic, TypeVar, cast
from typing import cast

import numpy
import xarray
Expand Down Expand Up @@ -183,10 +183,7 @@ def size(self) -> int:
return int(numpy.prod(self.shape))


Topology = TypeVar('Topology', bound=CFGridTopology)


class CFGrid(Generic[Topology], DimensionConvention[CFGridKind, CFGridIndex]):
class CFGrid[Topology: CFGridTopology](DimensionConvention[CFGridKind, CFGridIndex]):
"""
A base class for CF grid datasets.
There are two concrete subclasses: :class:`CFGrid1D` and :class:`CFGrid2D`.
Expand Down
6 changes: 2 additions & 4 deletions src/emsarray/operations/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pathlib
from collections.abc import Generator, Iterable, Iterator
from contextlib import contextmanager
from typing import IO, Any, Generic, TypeVar
from typing import IO, Any

import geojson
import shapefile
Expand All @@ -16,10 +16,8 @@

from emsarray.types import Pathish

T = TypeVar('T')


class _dumpable_iterator(Generic[T], list):
class _dumpable_iterator[T](list):
"""
Wrap an iterator / generator so it can be used in `json.dumps()`.
No guarantees that it works for anything else!
Expand Down
11 changes: 4 additions & 7 deletions src/emsarray/transect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
from collections.abc import Callable, Iterable
from functools import cached_property
from typing import Any, Generic, cast
from typing import Any, cast

import cfunits
import numpy
Expand All @@ -16,7 +16,7 @@
from matplotlib.figure import Figure
from matplotlib.ticker import EngFormatter, Formatter

from emsarray.conventions import Convention, Index
from emsarray.conventions import Convention
from emsarray.plot import _requires_plot, make_plot_title
from emsarray.types import DataArrayOrName, Landmark
from emsarray.utils import move_dimensions_to_end, name_to_data_array
Expand Down Expand Up @@ -86,7 +86,7 @@ class TransectPoint:


@dataclasses.dataclass
class TransectSegment(Generic[Index]):
class TransectSegment:
"""
A TransectSegment holds information about each intersecting segment of the
transect path and the dataset cells.
Expand All @@ -96,7 +96,6 @@ class TransectSegment(Generic[Index]):
intersection: shapely.LineString
start_distance: float
end_distance: float
index: Index
linear_index: int
polygon: shapely.Polygon

Expand Down Expand Up @@ -281,7 +280,7 @@ def points(
return points

@cached_property
def segments(self) -> list[TransectSegment[Index]]:
def segments(self) -> list[TransectSegment]:
"""
A list of :class:`.TransectSegmens` for each intersecting segment of the transect line and the dataset geometry.
Segments are listed in order from the start of the line to the end of the line.
Expand All @@ -293,7 +292,6 @@ def segments(self) -> list[TransectSegment[Index]]:

for linear_index in intersecting_indexes:
polygon = self.convention.polygons[linear_index]
index = self.convention.wind_index(linear_index)
for intersection in self._intersect_polygon(polygon):
# The line will have two ends.
# The intersection starts and ends at these points.
Expand All @@ -314,7 +312,6 @@ def segments(self) -> list[TransectSegment[Index]]:
intersection=intersection,
start_distance=start[1],
end_distance=end[1],
index=index,
linear_index=linear_index,
polygon=polygon,
))
Expand Down
26 changes: 11 additions & 15 deletions src/emsarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Callable, Hashable, Iterable, Mapping, MutableMapping, Sequence
)
from types import TracebackType
from typing import Any, Literal, TypeVar, cast
from typing import Any, Literal, cast

import cftime
import netCDF4
Expand All @@ -35,10 +35,6 @@
DEFAULT_CALENDAR = 'proleptic_gregorian'


_T = TypeVar("_T")
_Exception = TypeVar("_Exception", bound=BaseException)


class PerfTimer:
__slots__ = ('_start', '_stop', 'running')

Expand All @@ -56,10 +52,10 @@ def __enter__(self) -> 'PerfTimer':
self._start = time.perf_counter()
return self

def __exit__(
def __exit__[E: BaseException](
self,
exc_type: type[_Exception] | None,
exc_value: _Exception | None,
exc_type: type[E] | None,
exc_value: E | None,
traceback: TracebackType
) -> bool | None:
self._stop = time.perf_counter()
Expand All @@ -75,7 +71,7 @@ def elapsed(self) -> float:
return self._stop - self._start


def timed_func(fn: Callable[..., _T]) -> Callable[..., _T]:
def timed_func[F: Callable](fn: F) -> F:
"""
Log the execution time of the decorated function.
Logs "Calling ``<func.__qualname__>``" before the wrapped function is called,
Expand All @@ -101,13 +97,13 @@ def polygons(self):
fn_logger = logging.getLogger(fn.__module__)

@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> _T:
def wrapper(*args, **kwargs): # type: ignore
fn_logger.debug("Calling %s", fn.__qualname__)
with PerfTimer() as timer:
value = fn(*args, **kwargs)
fn_logger.debug("Completed %s in %fs", fn.__qualname__, timer.elapsed)
return value
return wrapper
return cast(F, wrapper)


def to_netcdf_with_fixes(
Expand Down Expand Up @@ -376,7 +372,7 @@ def extract_vars(
return dataset.drop_vars(drop_vars)


def pairwise(iterable: Iterable[_T]) -> Iterable[tuple[_T, _T]]:
def pairwise[T](iterable: Iterable[T]) -> Iterable[tuple[T, T]]:
"""
Iterate over values in an iterator in pairs.

Expand Down Expand Up @@ -734,15 +730,15 @@ def __init__(self, extra: str) -> None:
self.extra = extra


def requires_extra(
def requires_extra[T](
extra: str,
import_error: ImportError | None,
exception_class: type[RequiresExtraException] = RequiresExtraException,
) -> Callable[[_T], _T]:
) -> Callable[[T], T]:
if import_error is None:
return lambda fn: fn

def error_decorator(fn: _T) -> _T:
def error_decorator(fn: T) -> T:
@functools.wraps(fn) # type: ignore
def error(*args: Any, **kwargs: Any) -> Any:
raise exception_class(extra) from import_error
Expand Down
Loading