Skip to content

Commit 8e16424

Browse files
committed
Add Convention.select_indexes() and related functions
This is a breaking change for plugins as it changes the abstract methods on the Convention class. `Convention.selector_for_indexes()` is a new required method, and `Convention.selector_for_index()` now has a default implementation which calls the new method.
1 parent 43fd43d commit 8e16424

File tree

5 files changed

+232
-54
lines changed

5 files changed

+232
-54
lines changed

src/emsarray/conventions/_base.py

Lines changed: 140 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from collections.abc import Callable, Hashable, Iterable, Sequence
77
from functools import cached_property
8-
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
8+
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
99

1010
import numpy
1111
import xarray
@@ -17,7 +17,7 @@
1717
from emsarray import utils
1818
from emsarray.compat.shapely import SpatialIndex
1919
from emsarray.exceptions import NoSuchCoordinateError
20-
from emsarray.operations import depth
20+
from emsarray.operations import depth, point_extraction
2121
from emsarray.plot import (
2222
_requires_plot, animate_on_figure, make_plot_title, plot_on_figure,
2323
polygons_to_collection
@@ -516,7 +516,7 @@ def get_depth_coordinate_for_data_array(
516516
candidates = [
517517
coordinate
518518
for coordinate in self.depth_coordinates
519-
if coordinate.dims[0] in data_array.dims
519+
if set(coordinate.dims) <= set(data_array.dims)
520520
]
521521
if len(candidates) == 0:
522522
raise NoSuchCoordinateError(f"No depth coordinate found for {name}")
@@ -1471,8 +1471,7 @@ def get_index_for_point(
14711471
polygon=self.polygons[linear_index])
14721472
return None
14731473

1474-
@abc.abstractmethod
1475-
def selector_for_index(self, index: Index) -> dict[Hashable, int]:
1474+
def selector_for_index(self, index: Index) -> xarray.Dataset:
14761475
"""
14771476
Convert a convention native index into a selector
14781477
that can be passed to :meth:`Dataset.isel <xarray.Dataset.isel>`.
@@ -1494,11 +1493,24 @@ def selector_for_index(self, index: Index) -> dict[Hashable, int]:
14941493
:meth:`.select_point`
14951494
:ref:`indexing`
14961495
"""
1496+
index_dimension = utils.find_unused_dimension(self.dataset, 'index')
1497+
dataset = self.selector_for_indexes([index], index_dimension=index_dimension)
1498+
dataset = dataset.squeeze(dim=index_dimension, drop=False)
1499+
return dataset
1500+
1501+
@abc.abstractmethod
1502+
def selector_for_indexes(
1503+
self,
1504+
indexes: list[Index],
1505+
*,
1506+
index_dimension: Hashable | None = None,
1507+
) -> xarray.Dataset:
14971508
pass
14981509

14991510
def select_index(
15001511
self,
15011512
index: Index,
1513+
drop_geometry: bool = True,
15021514
) -> xarray.Dataset:
15031515
"""
15041516
Return a new dataset that contains values only from a single index.
@@ -1516,6 +1528,62 @@ def select_index(
15161528
index : :data:`Index`
15171529
The index to select.
15181530
The index must be for the default grid kind for this dataset.
1531+
drop_geometry : bool, default True
1532+
Whether to drop geometry variables from the returned point dataset.
1533+
If the geometry data is kept
1534+
the associated geometry data will no longer conform to the dataset convention
1535+
and may not conform to any sensible convention at all.
1536+
The format of the geometry data left after selecting points is convention-dependent.
1537+
1538+
Returns
1539+
-------
1540+
:class:`xarray.Dataset`
1541+
A new dataset that is subset to the one index.
1542+
1543+
Notes
1544+
-----
1545+
1546+
The returned dataset will most likely not have sufficient coordinate data
1547+
to be used with a particular :class:`Convention` any more.
1548+
The ``dataset.ems`` accessor will raise an error if accessed on the new dataset.
1549+
"""
1550+
index_dimension = utils.find_unused_dimension(self.dataset, 'index')
1551+
dataset = self.select_indexes([index], index_dimension=index_dimension, drop_geometry=drop_geometry)
1552+
dataset = dataset.squeeze(dim=index_dimension, drop=False)
1553+
return dataset
1554+
1555+
def select_indexes(
1556+
self,
1557+
indexes: list[Index],
1558+
*,
1559+
index_dimension: Hashable | None = None,
1560+
drop_geometry: bool = True,
1561+
) -> xarray.Dataset:
1562+
"""
1563+
Return a new dataset that contains values only at the selected indexes.
1564+
This is much like doing a :func:`xarray.Dataset.isel()` on some indexes,
1565+
but works with convention native index types.
1566+
1567+
An index is associated with a grid kind.
1568+
The returned dataset will only contain variables that were defined on this grid,
1569+
with only the indexed points selected.
1570+
For example, if the indexes of faces are passed in,
1571+
the returned dataset will not contain any variables defined on an edge.
1572+
1573+
Parameters
1574+
----------
1575+
indexes : list of :data:`Index`
1576+
The indexes to select.
1577+
All of the indexes must be for the default grid kind for this dataset.
1578+
index_dimension : str, optional
1579+
The name of the new dimension added for each index to select.
1580+
Defaults to the :func:`first unused dimension <.utils.find_unused_dimension>` with prefix ``index``.
1581+
drop_geometry : bool, default True
1582+
Whether to drop geometry variables from the returned point dataset.
1583+
If the geometry data is kept
1584+
the associated geometry data will no longer conform to the dataset convention
1585+
and may not conform to any sensible convention at all.
1586+
The format of the geometry data left after selecting points is convention-dependent.
15191587
15201588
Returns
15211589
-------
@@ -1529,16 +1597,21 @@ def select_index(
15291597
to be used with a particular :class:`Convention` any more.
15301598
The ``dataset.ems`` accessor will raise an error if accessed on the new dataset.
15311599
"""
1532-
selector = self.selector_for_index(index)
1600+
selector = self.selector_for_indexes(indexes, index_dimension=index_dimension)
15331601

15341602
# Make a new dataset consisting of only data arrays that use at least
15351603
# one of these dimensions.
1536-
dims = set(selector.keys())
1604+
if drop_geometry:
1605+
dataset = self.drop_geometry()
1606+
else:
1607+
dataset = self.dataset
1608+
1609+
dims = set(selector.variables.keys())
15371610
names = [
1538-
name for name, data_array in self.dataset.items()
1611+
name for name, data_array in dataset.items()
15391612
if dims.intersection(data_array.dims)
15401613
]
1541-
dataset = utils.extract_vars(self.dataset, names)
1614+
dataset = utils.extract_vars(dataset, names)
15421615

15431616
# Select just the requested indexes
15441617
return dataset.isel(selector)
@@ -1565,6 +1638,38 @@ def select_point(self, point: Point) -> xarray.Dataset:
15651638
raise ValueError("Point did not intersect dataset")
15661639
return self.select_index(index.index)
15671640

1641+
def select_points(
1642+
self,
1643+
points: list[Point],
1644+
*,
1645+
point_dimension: Hashable | None = None,
1646+
missing_points: Literal['error', 'drop'] = 'error',
1647+
) -> xarray.Dataset:
1648+
"""
1649+
Extract values from all variables on the default grid at a sequence of points.
1650+
1651+
Parameters
1652+
----------
1653+
points : list of shapely.Point
1654+
The points to extract
1655+
point_dimension : str, optional
1656+
The name of the new dimension used to index points.
1657+
Defaults to 'point', or 'point_0', 'point_1', etc if those dimensions already exist.
1658+
missing_points : {'error', 'drop'}, default 'error'
1659+
What to do if a point does not intersect the dataset.
1660+
'error' will raise an error, while 'drop' will drop those points.
1661+
1662+
Returns
1663+
-------
1664+
xarray.Dataset
1665+
A dataset with values extracted from the points.
1666+
Variables that are not defined on the default grid, and all geometry variables, will be omitted.
1667+
"""
1668+
if point_dimension is None:
1669+
point_dimension = utils.find_unused_dimension(self.dataset, 'point')
1670+
return point_extraction.extract_points(
1671+
self.dataset, points, point_dimension=point_dimension, missing_points=missing_points)
1672+
15681673
@abc.abstractmethod
15691674
def get_all_geometry_names(self) -> list[Hashable]:
15701675
"""
@@ -1936,7 +2041,30 @@ def wind(
19362041
dimensions=dimensions, sizes=sizes,
19372042
linear_dimension=linear_dimension)
19382043

1939-
def selector_for_index(self, index: Index) -> dict[Hashable, int]:
1940-
grid_kind, indexes = self.unpack_index(index)
2044+
def selector_for_indexes(
2045+
self,
2046+
indexes: list[Index],
2047+
*,
2048+
index_dimension: Hashable | None = None,
2049+
) -> xarray.Dataset:
2050+
if index_dimension is None:
2051+
index_dimension = utils.find_unused_dimension(self.dataset, 'index')
2052+
if len(indexes) == 0:
2053+
raise ValueError("Need at least one index to select")
2054+
2055+
grid_kinds, index_tuples = zip(*[self.unpack_index(index) for index in indexes])
2056+
2057+
unique_grid_kinds = set(grid_kinds)
2058+
if len(unique_grid_kinds) > 1:
2059+
raise ValueError(
2060+
"All indexes must be on the same grid kind, got "
2061+
+ ", ".join(map(repr, unique_grid_kinds)))
2062+
2063+
grid_kind = grid_kinds[0]
19412064
dimensions = self.grid_dimensions[grid_kind]
1942-
return dict(zip(dimensions, indexes))
2065+
# This array will have shape (len(indexes), len(dimensions))
2066+
index_array = numpy.array(index_tuples)
2067+
return xarray.Dataset({
2068+
dimension: (index_dimension, index_array[:, i])
2069+
for i, dimension in enumerate(dimensions)
2070+
})

src/emsarray/operations/point_extraction.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import xarray
2424
import xarray.core.dtypes as xrdtypes
2525

26+
import emsarray.conventions
2627
from emsarray import utils
27-
from emsarray.conventions import Convention
2828

2929

3030
@dataclasses.dataclass
@@ -67,7 +67,7 @@ def extract_points(
6767
dataset: xarray.Dataset,
6868
points: list[shapely.Point],
6969
*,
70-
point_dimension: Hashable = 'point',
70+
point_dimension: Hashable | None = None,
7171
missing_points: Literal['error', 'drop'] = 'error',
7272
) -> xarray.Dataset:
7373
"""
@@ -106,7 +106,10 @@ def extract_points(
106106
--------
107107
:func:`extract_dataframe`
108108
"""
109-
convention: Convention = dataset.ems
109+
convention: emsarray.conventions.Convention = dataset.ems
110+
111+
if point_dimension is None:
112+
point_dimension = utils.find_unused_dimension(dataset, 'point')
110113

111114
# Find the indexer for each given point
112115
indexes = numpy.array([convention.get_index_for_point(point) for point in points])
@@ -118,21 +121,17 @@ def extract_points(
118121
indexes=out_of_bounds,
119122
points=[points[i] for i in out_of_bounds])
120123

121-
# Make a DataFrame out of all point indexers
122-
selector = convention.selector_for_indexes([i.index for i in indexes])
123-
selector_df = pandas.DataFrame([
124-
convention.selector_for_index(index.index)
125-
for index in indexes
126-
if index is not None])
127-
point_indexes = [i for i, index in enumerate(indexes) if index is not None]
124+
point_ds = convention.select_indexes(
125+
[index.index for index in indexes if index is not None],
126+
index_dimension=point_dimension,
127+
drop_geometry=True)
128128

129-
# Subset the dataset to the points
130-
point_ds = convention.drop_geometry()
131-
selector_ds = _dataframe_to_dataset(selector_df, dimension_name=point_dimension)
132-
point_ds = point_ds.isel(selector_ds)
129+
# Number the points
130+
point_indexes = [i for i, index in enumerate(indexes) if index is not None]
133131
point_ds = point_ds.assign_coords({
134132
point_dimension: ([point_dimension], point_indexes),
135133
})
134+
136135
return point_ds
137136

138137

tests/conventions/test_base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,20 @@ def wind_index(
7171
def ravel_index(self, indexes: SimpleGridIndex) -> int:
7272
return int(numpy.ravel_multi_index((indexes.y, indexes.x), self.shape))
7373

74-
def selector_for_index(self, index: SimpleGridIndex) -> dict[Hashable, int]:
75-
return {'x': index.x, 'y': index.y}
74+
def selector_for_indexes(
75+
self,
76+
indexes: list[SimpleGridIndex],
77+
*,
78+
index_dimension: Hashable | None = None,
79+
) -> xarray.Dataset:
80+
if index_dimension is None:
81+
index_dimension = utils.find_unused_dimension(self.dataset, 'index')
82+
index_array = numpy.array([[index.x, index.y] for index in indexes])
83+
84+
return xarray.Dataset({
85+
'x': (index_dimension, index_array[:, 0]),
86+
'y': (index_dimension, index_array[:, 1]),
87+
})
7688

7789
def ravel(
7890
self,

0 commit comments

Comments
 (0)