55import warnings
66from collections .abc import Callable , Hashable , Iterable , Sequence
77from functools import cached_property
8- from typing import TYPE_CHECKING , Any , Generic , TypeVar , cast
8+ from typing import TYPE_CHECKING , Any , Generic , Literal , TypeVar , cast
99
1010import numpy
1111import xarray
1717from emsarray import utils
1818from emsarray .compat .shapely import SpatialIndex
1919from emsarray .exceptions import NoSuchCoordinateError
20- from emsarray .operations import depth
20+ from emsarray .operations import depth , point_extraction
2121from emsarray .plot import (
2222 _requires_plot , animate_on_figure , make_plot_title , plot_on_figure ,
2323 polygons_to_collection
@@ -516,7 +516,7 @@ def get_depth_coordinate_for_data_array(
516516 candidates = [
517517 coordinate
518518 for coordinate in self .depth_coordinates
519- if coordinate .dims [ 0 ] in data_array .dims
519+ if set ( coordinate .dims ) <= set ( data_array .dims )
520520 ]
521521 if len (candidates ) == 0 :
522522 raise NoSuchCoordinateError (f"No depth coordinate found for { name } " )
@@ -1471,8 +1471,7 @@ def get_index_for_point(
14711471 polygon = self .polygons [linear_index ])
14721472 return None
14731473
1474- @abc .abstractmethod
1475- def selector_for_index (self , index : Index ) -> dict [Hashable , int ]:
1474+ def selector_for_index (self , index : Index ) -> xarray .Dataset :
14761475 """
14771476 Convert a convention native index into a selector
14781477 that can be passed to :meth:`Dataset.isel <xarray.Dataset.isel>`.
@@ -1494,11 +1493,24 @@ def selector_for_index(self, index: Index) -> dict[Hashable, int]:
14941493 :meth:`.select_point`
14951494 :ref:`indexing`
14961495 """
1496+ index_dimension = utils .find_unused_dimension (self .dataset , 'index' )
1497+ dataset = self .selector_for_indexes ([index ], index_dimension = index_dimension )
1498+ dataset = dataset .squeeze (dim = index_dimension , drop = False )
1499+ return dataset
1500+
1501+ @abc .abstractmethod
1502+ def selector_for_indexes (
1503+ self ,
1504+ indexes : list [Index ],
1505+ * ,
1506+ index_dimension : Hashable | None = None ,
1507+ ) -> xarray .Dataset :
14971508 pass
14981509
14991510 def select_index (
15001511 self ,
15011512 index : Index ,
1513+ drop_geometry : bool = True ,
15021514 ) -> xarray .Dataset :
15031515 """
15041516 Return a new dataset that contains values only from a single index.
@@ -1516,6 +1528,62 @@ def select_index(
15161528 index : :data:`Index`
15171529 The index to select.
15181530 The index must be for the default grid kind for this dataset.
1531+ drop_geometry : bool, default True
1532+ Whether to drop geometry variables from the returned point dataset.
1533+ If the geometry data is kept
1534+ the associated geometry data will no longer conform to the dataset convention
1535+ and may not conform to any sensible convention at all.
1536+ The format of the geometry data left after selecting points is convention-dependent.
1537+
1538+ Returns
1539+ -------
1540+ :class:`xarray.Dataset`
1541+ A new dataset that is subset to the one index.
1542+
1543+ Notes
1544+ -----
1545+
1546+ The returned dataset will most likely not have sufficient coordinate data
1547+ to be used with a particular :class:`Convention` any more.
1548+ The ``dataset.ems`` accessor will raise an error if accessed on the new dataset.
1549+ """
1550+ index_dimension = utils .find_unused_dimension (self .dataset , 'index' )
1551+ dataset = self .select_indexes ([index ], index_dimension = index_dimension , drop_geometry = drop_geometry )
1552+ dataset = dataset .squeeze (dim = index_dimension , drop = False )
1553+ return dataset
1554+
1555+ def select_indexes (
1556+ self ,
1557+ indexes : list [Index ],
1558+ * ,
1559+ index_dimension : Hashable | None = None ,
1560+ drop_geometry : bool = True ,
1561+ ) -> xarray .Dataset :
1562+ """
1563+ Return a new dataset that contains values only at the selected indexes.
1564+ This is much like doing a :meth:`xarray.Dataset.isel` on some indexes,
1565+ but works with convention native index types.
1566+
1567+ An index is associated with a grid kind.
1568+ The returned dataset will only contain variables that were defined on this grid,
1569+ with only the points at the given indexes selected.
1570+ For example, if the index of a face is passed in,
1571+ the returned dataset will not contain any variables defined on an edge.
1572+
1573+ Parameters
1574+ ----------
1575+ indexes : list of :data:`Index`
1576+ The indexes to select.
1577+ The indexes must be for the default grid kind for this dataset.
1578+ index_dimension : str, optional
1579+ The name of the new dimension added for each index to select.
1580+ Defaults to the :func:`first unused dimension <.utils.find_unused_dimension>` with prefix ``index``.
1581+ drop_geometry : bool, default True
1582+ Whether to drop geometry variables from the returned point dataset.
1583+ If the geometry data is kept
1584+ the associated geometry data will no longer conform to the dataset convention
1585+ and may not conform to any sensible convention at all.
1586+ The format of the geometry data left after selecting points is convention-dependent.
15191587
15201588 Returns
15211589 -------
@@ -1529,16 +1597,21 @@ def select_index(
15291597 to be used with a particular :class:`Convention` any more.
15301598 The ``dataset.ems`` accessor will raise an error if accessed on the new dataset.
15311599 """
1532- selector = self .selector_for_index ( index )
1600+ selector = self .selector_for_indexes ( indexes , index_dimension = index_dimension )
15331601
15341602 # Make a new dataset consisting of only data arrays that use at least
15351603 # one of these dimensions.
1536- dims = set (selector .keys ())
1604+ if drop_geometry :
1605+ dataset = self .drop_geometry ()
1606+ else :
1607+ dataset = self .dataset
1608+
1609+ dims = set (selector .variables .keys ())
15371610 names = [
1538- name for name , data_array in self . dataset .items ()
1611+ name for name , data_array in dataset .items ()
15391612 if dims .intersection (data_array .dims )
15401613 ]
1541- dataset = utils .extract_vars (self . dataset , names )
1614+ dataset = utils .extract_vars (dataset , names )
15421615
15431616 # Select just this point
15441617 return dataset .isel (selector )
@@ -1565,6 +1638,38 @@ def select_point(self, point: Point) -> xarray.Dataset:
15651638 raise ValueError ("Point did not intersect dataset" )
15661639 return self .select_index (index .index )
15671640
1641+ def select_points (
1642+ self ,
1643+ points : list [Point ],
1644+ * ,
1645+ point_dimension : Hashable | None = None ,
1646+ missing_points : Literal ['error' , 'drop' ] = 'error' ,
1647+ ) -> xarray .Dataset :
1648+ """
1649+ Extract values from all variables on the default grid at a sequence of points.
1650+
1651+ Parameters
1652+ ----------
1653+ points : list of shapely.Point
1654+ The points to extract
1655+ point_dimension : str, optional
1656+ The name of the new dimension used to index points.
1657+ Defaults to 'point', or 'point_0', 'point_1', etc if those dimensions already exist.
1658+ missing_points : {'error', 'drop'}, default 'error'
1659+ What to do if a point does not intersect the dataset.
1660+ 'error' will raise an error, while 'drop' will drop those points.
1661+
1662+ Returns
1663+ -------
1664+ xarray.Dataset
1665+ A dataset with values extracted from the points.
1666+ Variables not defined on the default grid, and geometry variables, will not be present.
1667+ """
1668+ if point_dimension is None :
1669+ point_dimension = utils .find_unused_dimension (self .dataset , 'point' )
1670+ return point_extraction .extract_points (
1671+ self .dataset , points , point_dimension = point_dimension , missing_points = missing_points )
1672+
15681673 @abc .abstractmethod
15691674 def get_all_geometry_names (self ) -> list [Hashable ]:
15701675 """
@@ -1936,7 +2041,30 @@ def wind(
19362041 dimensions = dimensions , sizes = sizes ,
19372042 linear_dimension = linear_dimension )
19382043
1939- def selector_for_index (self , index : Index ) -> dict [Hashable , int ]:
1940- grid_kind , indexes = self .unpack_index (index )
2044+ def selector_for_indexes (
2045+ self ,
2046+ indexes : list [Index ],
2047+ * ,
2048+ index_dimension : Hashable | None = None ,
2049+ ) -> xarray .Dataset :
2050+ if index_dimension is None :
2051+ index_dimension = utils .find_unused_dimension (self .dataset , 'index' )
2052+ if len (indexes ) == 0 :
2053+ raise ValueError ("Need at least one index to select" )
2054+
2055+ grid_kinds , index_tuples = zip (* [self .unpack_index (index ) for index in indexes ])
2056+
2057+ unique_grid_kinds = set (grid_kinds )
2058+ if len (unique_grid_kinds ) > 1 :
2059+ raise ValueError (
2060+ "All indexes must be on the same grid kind, got "
2061+ + ", " .join (map (repr , unique_grid_kinds )))
2062+
2063+ grid_kind = grid_kinds [0 ]
19412064 dimensions = self .grid_dimensions [grid_kind ]
1942- return dict (zip (dimensions , indexes ))
2065+ # This array will have shape (len(indexes), len(dimensions))
2066+ index_array = numpy .array (index_tuples )
2067+ return xarray .Dataset ({
2068+ dimension : (index_dimension , index_array [:, i ])
2069+ for i , dimension in enumerate (dimensions )
2070+ })
0 commit comments