Skip to content

Commit ae95a93

Browse files
committed
Expand more functions that accept DataArrayOrName
Two new functions `emsarray.utils.name_to_data_array` and `emsarray.utils.data_array_to_name` have been added to accomplish this. `Convention._get_data_array()` has been deprecated in favour of `name_to_data_array`.
1 parent e95a620 commit ae95a93

File tree

10 files changed

+243
-93
lines changed

10 files changed

+243
-93
lines changed

docs/api/types.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
.. module:: emsarray.types
2-
31
==============
42
emsarray.types
53
==============
64

75
.. automodule:: emsarray.types
8-
:noindex:
96
:members:

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
napoleon_type_aliases = {
7575
'xarray.core.dataset.Dataset': ':class:`~xarray.Dataset',
7676
'xarray.core.dataarray.DataArray': ':class:`~xarray.DataArray',
77+
'DataArrayOrName': ':data:`data array or name <emsarray.types.DataArrayOrName>`',
7778
}
7879

7980

docs/releases/development.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@ Next release (in development)
3030
(:pr:`137`).
3131
* Lint Python code in `docs/` and `scripts/`
3232
(:pr:`141`).
33+
* Add :func:`emsarray.utils.name_to_data_array()` and :func:`~emsarray.utils.data_array_to_name()` functions.
34+
Allow more functions to interchangeably take either a data array or the name of a data array
35+
(:pr:`142`).

src/emsarray/conventions/_base.py

Lines changed: 30 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
polygons_to_collection
2626
)
2727
from emsarray.state import State
28-
from emsarray.types import Bounds, Pathish
28+
from emsarray.types import Bounds, DataArrayOrName, Pathish
2929

3030
if TYPE_CHECKING:
3131
# Import these optional dependencies only during type checking
@@ -39,8 +39,6 @@
3939
logger = logging.getLogger(__name__)
4040

4141

42-
DataArrayOrName = Union[Hashable, xarray.DataArray]
43-
4442
#: Some type that can enumerate the different :ref:`grid types <grids>`
4543
#: present in a dataset.
4644
#: This can be an :class:`enum.Enum` listing each different kind of grid.
@@ -266,19 +264,15 @@ def bind(self) -> None:
266264
"cannot assign a new convention.")
267265
state.bind_convention(self)
268266

269-
@abc.abstractmethod
267+
@utils.deprecated(
268+
(
269+
"Convention._get_data_array() has been deprecated. "
270+
"Use emsarray.utils.name_to_data_array() instead."
271+
),
272+
DeprecationWarning,
273+
)
270274
def _get_data_array(self, data_array: DataArrayOrName) -> xarray.DataArray:
271-
"""
272-
Utility to help get a data array for this dataset.
273-
If a string is passed in, the matching data array is fetched from the dataset.
274-
If a data array is passed in,
275-
it is inspected to ensure the surface dimensions align
276-
before being returned as-is.
277-
278-
This is useful for methods that support being passed either
279-
the name of a data array or a data array instance.
280-
"""
281-
pass
275+
return utils.name_to_data_array(self.dataset, data_array)
282276

283277
@cached_property
284278
def time_coordinate(self) -> xarray.DataArray:
@@ -902,10 +896,10 @@ def plot_on_figure(
902896
----------
903897
figure : matplotlib.figure.Figure
904898
The :class:`~matplotlib.figure.Figure` instance to plot this on.
905-
scalar : xarray.DataArray or str
899+
scalar : DataArrayOrName
906900
The :class:`~xarray.DataArray` to plot,
907901
or the name of an existing DataArray in this Dataset.
908-
vector : tuple of xarray.DataArray or str
902+
vector : tuple of DataArrayOrName
909903
A tuple of the *u* and *v* components of a vector.
910904
The components should be a :class:`~xarray.DataArray`,
911905
or the name of an existing DataArray in this Dataset.
@@ -918,10 +912,10 @@ def plot_on_figure(
918912
:func:`.plot.plot_on_figure` : The underlying implementation
919913
"""
920914
if scalar is not None:
921-
kwargs['scalar'] = self._get_data_array(scalar)
915+
kwargs['scalar'] = utils.name_to_data_array(self.dataset, scalar)
922916

923917
if vector is not None:
924-
kwargs['vector'] = tuple(map(self._get_data_array, vector))
918+
kwargs['vector'] = tuple(utils.name_to_data_array(self.dataset, v) for v in vector)
925919

926920
if title is not None:
927921
kwargs['title'] = title
@@ -977,7 +971,7 @@ def animate_on_figure(
977971
----------
978972
figure : matplotlib.figure.Figure
979973
The :class:`matplotlib.figure.Figure` to plot the animation on
980-
data_array : Hashable or xarray.DataArray
974+
data_array : DataArrayOrName
981975
The :class:`xarray.DataArray` to plot.
982976
If a string is passed in,
983977
the variable with that name is taken from :attr:`dataset`.
@@ -1006,26 +1000,26 @@ def animate_on_figure(
10061000

10071001
if coordinate is None:
10081002
# Assume the user wants to plot along the time axis by default.
1009-
coordinate = self.get_time_name()
1010-
if isinstance(coordinate, xarray.DataArray):
1011-
utils.check_data_array_dimensions_match(
1012-
self.dataset, coordinate, dimensions=coordinate.dims)
1013-
1014-
coordinate = self._get_data_array(coordinate)
1003+
coordinate = self.time_coordinate
1004+
else:
1005+
coordinate = utils.name_to_data_array(self.dataset, coordinate)
10151006

10161007
if len(coordinate.dims) != 1:
10171008
raise ValueError("Coordinate variable must be one dimensional")
10181009

10191010
coordinate_dim = coordinate.dims[0]
10201011

10211012
if scalar is not None:
1022-
scalar = self._get_data_array(scalar)
1013+
scalar = utils.name_to_data_array(self.dataset, scalar)
10231014
if coordinate_dim not in scalar.dims:
10241015
raise ValueError("Scalar dimensions do not match coordinate axis to animate along")
10251016
kwargs['scalar'] = scalar
10261017

10271018
if vector is not None:
1028-
vector = (self._get_data_array(vector[0]), self._get_data_array(vector[1]))
1019+
vector = (
1020+
utils.name_to_data_array(self.dataset, vector[0]),
1021+
utils.name_to_data_array(self.dataset, vector[1]),
1022+
)
10291023
if not all(coordinate_dim in component.dims for component in vector):
10301024
raise ValueError("Vector dimensions do not match coordinate axis to animate along")
10311025
kwargs['vector'] = vector
@@ -1119,7 +1113,7 @@ def make_poly_collection(
11191113
"Can not pass both `data_array` and `array` to make_poly_collection"
11201114
)
11211115

1122-
data_array = self._get_data_array(data_array)
1116+
data_array = utils.name_to_data_array(self.dataset, data_array)
11231117

11241118
data_array = self.ravel(data_array)
11251119
if len(data_array.dims) > 1:
@@ -1189,7 +1183,8 @@ def make_quiver(
11891183
values = numpy.nan, numpy.nan
11901184

11911185
if u is not None and v is not None:
1192-
u, v = self._get_data_array(u), self._get_data_array(v)
1186+
u = utils.name_to_data_array(self.dataset, u)
1187+
v = utils.name_to_data_array(self.dataset, v)
11931188

11941189
if u.dims != v.dims:
11951190
raise ValueError(
@@ -1539,15 +1534,15 @@ def drop_geometry(self) -> xarray.Dataset:
15391534
"""
15401535
return self.dataset.drop_vars(self.get_all_geometry_names())
15411536

1542-
def select_variables(self, variables: Iterable[Hashable]) -> xarray.Dataset:
1537+
def select_variables(self, variables: Iterable[DataArrayOrName]) -> xarray.Dataset:
15431538
"""Select only a subset of the variables in this dataset, dropping all others.
15441539
15451540
This will keep all coordinate variables and all geometry variables.
15461541
15471542
Parameters
15481543
----------
1549-
variables : iterable of Hashable
1550-
The names of all data variables to select.
1544+
variables : iterable of DataArrayOrName
1545+
The data variables to select.
15511546
15521547
Returns
15531548
-------
@@ -1570,6 +1565,7 @@ def select_variables(self, variables: Iterable[Hashable]) -> xarray.Dataset:
15701565
keep_vars.add(self.get_time_name())
15711566
except NoSuchCoordinateError:
15721567
pass
1568+
keep_vars = {utils.data_array_to_name(self.dataset, v) for v in keep_vars}
15731569
return self.dataset.drop_vars(all_vars - keep_vars)
15741570

15751571
@abc.abstractmethod
@@ -1701,7 +1697,7 @@ def ocean_floor(self) -> xarray.Dataset:
17011697
"""An alias for :func:`emsarray.operations.depth.ocean_floor`"""
17021698
return depth.ocean_floor(
17031699
self.dataset, self.get_all_depth_names(),
1704-
non_spatial_variables=[self.get_time_name()])
1700+
non_spatial_variables=[self.time_coordinate])
17051701

17061702
def normalize_depth_variables(
17071703
self,
@@ -1786,23 +1782,6 @@ def get_grid_kind(self, data_array: xarray.DataArray) -> GridKind:
17861782
return kind
17871783
raise ValueError("Unknown grid kind")
17881784

1789-
def _get_data_array(self, data_array: DataArrayOrName) -> xarray.DataArray:
1790-
if isinstance(data_array, xarray.DataArray):
1791-
grid_kind = self.get_grid_kind(data_array)
1792-
grid_dimensions = self.grid_dimensions[grid_kind]
1793-
for dimension in grid_dimensions:
1794-
# The data array already has matching dimension names
1795-
# as we found the grid kind using `Convention.get_grid_kind()`.
1796-
if self.dataset.sizes[dimension] != data_array.sizes[dimension]:
1797-
raise ValueError(
1798-
f"Mismatched dimension {dimension!r}, "
1799-
"dataset has size {self.dataset.sizes[dimension]} but "
1800-
"data array has size {data_array.sizes[dimension]}!"
1801-
)
1802-
return data_array
1803-
else:
1804-
return self.dataset[data_array]
1805-
18061785
@abc.abstractmethod
18071786
def unpack_index(self, index: Index) -> tuple[GridKind, Sequence[int]]:
18081787
"""Convert a native index in to a grid kind and dimension indices.

src/emsarray/operations/depth.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,20 @@
55
import warnings
66
from collections import defaultdict
77
from collections.abc import Hashable
8-
from typing import Optional, cast
8+
from typing import Iterable, Optional, cast
99

1010
import numpy
1111
import xarray
1212

1313
from emsarray import utils
14+
from emsarray.types import DataArrayOrName
1415

1516

1617
def ocean_floor(
1718
dataset: xarray.Dataset,
18-
depth_variables: list[Hashable],
19+
depth_coordinates: Iterable[DataArrayOrName],
1920
*,
20-
non_spatial_variables: Optional[list[Hashable]] = None,
21+
non_spatial_variables: Optional[Iterable[DataArrayOrName]] = None,
2122
) -> xarray.Dataset:
2223
"""Make a new :class:`xarray.Dataset` reduced along the given depth
2324
coordinates to only contain values along the ocean floor.
@@ -26,12 +27,12 @@ def ocean_floor(
2627
----------
2728
dataset
2829
The dataset to reduce.
29-
depth_variables
30-
The names of depth coordinate variables.
30+
depth_coordinates : iterable of DataArrayOrName
31+
The depth coordinate variables.
3132
For supported conventions, use :meth:`.Convention.get_all_depth_names()`.
32-
non_spatial_variables
33+
non_spatial_variables : iterable of DataArrayOrName
3334
Optional.
34-
A list of the names of any non-spatial coordinate variables, such as time.
35+
A list of any non-spatial coordinate variables, such as time.
3536
The ocean floor is assumed to be static across non-spatial dimensions.
3637
For supported conventions, use :meth:`.Convention.get_time_name()`.
3738
@@ -74,7 +75,7 @@ def ocean_floor(
7475
7576
>>> operations.ocean_floor(
7677
... big_dataset['temp'].isel(record=0).to_dataset(),
77-
... depth_variables=big_dataset.ems.get_all_depth_names())
78+
... depth_coordinates=big_dataset.ems.get_all_depth_names())
7879
<xarray.Dataset>
7980
Dimensions: (y: 5, x: 5)
8081
Coordinates:
@@ -106,15 +107,17 @@ def ocean_floor(
106107
# and that for a combination of depth dimension and spatial dimensions,
107108
# the ocean floor is static.
108109

110+
depth_coordinates = list(depth_coordinates)
111+
109112
dataset = normalize_depth_variables(
110-
dataset, depth_variables,
113+
dataset, depth_coordinates,
111114
positive_down=True, deep_to_shallow=False)
112115

113116
if non_spatial_variables is None:
114117
non_spatial_variables = []
115118

116119
# The name of all the relevant _dimensions_, not _coordinates_
117-
depth_dimensions = utils.dimensions_from_coords(dataset, depth_variables)
120+
depth_dimensions = utils.dimensions_from_coords(dataset, depth_coordinates)
118121
non_spatial_dimensions = utils.dimensions_from_coords(dataset, non_spatial_variables)
119122

120123
for depth_dimension in sorted(depth_dimensions, key=hash):
@@ -196,7 +199,7 @@ def _find_ocean_floor_indices(
196199

197200
def normalize_depth_variables(
198201
dataset: xarray.Dataset,
199-
depth_variables: list[Hashable],
202+
depth_variables: Iterable[DataArrayOrName],
200203
*,
201204
positive_down: Optional[bool] = None,
202205
deep_to_shallow: Optional[bool] = None,
@@ -216,10 +219,8 @@ def normalize_depth_variables(
216219
----------
217220
dataset : xarray.Dataset
218221
The dataset to normalize
219-
depth_variables : list of Hashable
220-
The names of the depth coordinate variables.
221-
This should be the names of the variables, not the dimensions,
222-
for datasets where these differ.
222+
depth_variables : iterable of DataArrayOrName
223+
The depth coordinate variables.
223224
positive_down : bool, optional
224225
If True, positive values will indicate depth below the surface.
225226
If False, negative values indicate depth below the surface.
@@ -240,8 +241,9 @@ def normalize_depth_variables(
240241
:meth:`.Convention.get_all_depth_names`
241242
"""
242243
new_dataset = dataset.copy()
243-
for name in depth_variables:
244-
variable = dataset[name]
244+
for variable in depth_variables:
245+
variable = utils.name_to_data_array(dataset, variable)
246+
name = variable.name
245247
if len(variable.dims) != 1:
246248
raise ValueError(
247249
f"Can't normalize multidimensional depth variable {name!r} "

src/emsarray/transect.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import dataclasses
2-
from collections.abc import Hashable, Iterable
2+
from collections.abc import Iterable
33
from functools import cached_property
44
from typing import Any, Callable, Generic, Optional, Union, cast
55

@@ -18,8 +18,8 @@
1818

1919
from emsarray.conventions import Convention, Index
2020
from emsarray.plot import _requires_plot, make_plot_title
21-
from emsarray.types import Landmark
22-
from emsarray.utils import move_dimensions_to_end
21+
from emsarray.types import DataArrayOrName, Landmark
22+
from emsarray.utils import move_dimensions_to_end, name_to_data_array
2323

2424
# Useful for calculating distances in an AzimuthalEquidistant projection
2525
# centred on some point:
@@ -116,13 +116,13 @@ def __init__(
116116
self,
117117
dataset: xarray.Dataset,
118118
line: shapely.LineString,
119-
depth: Optional[Union[Hashable, xarray.DataArray]] = None,
119+
depth: Optional[DataArrayOrName] = None,
120120
):
121121
self.dataset = dataset
122122
self.convention = dataset.ems
123123
self.line = line
124124
if depth is not None:
125-
self.depth = self.convention._get_data_array(depth)
125+
self.depth = name_to_data_array(dataset, depth)
126126
else:
127127
self.depth = self.convention.depth_coordinate
128128

src/emsarray/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
"""
44

55
import os
6+
from collections.abc import Hashable
67
from typing import Union
78

89
import shapely
10+
import xarray
911

1012
#: Something that can be used as a path.
1113
Pathish = Union[os.PathLike, str]
@@ -17,3 +19,6 @@
1719
#: A landmark for a plot.
1820
#: This is a tuple of the landmark name and and its location.
1921
Landmark = tuple[str, shapely.Point]
22+
23+
#: Either an :class:`xarray.DataArray`, or the name of a data array in a dataset.
24+
DataArrayOrName = Union[Hashable, xarray.DataArray]

0 commit comments

Comments
 (0)