2525 polygons_to_collection
2626)
2727from emsarray .state import State
28- from emsarray .types import Bounds , Pathish
28+ from emsarray .types import Bounds , DataArrayOrName , Pathish
2929
3030if TYPE_CHECKING :
3131 # Import these optional dependencies only during type checking
3939logger = logging .getLogger (__name__ )
4040
4141
42- DataArrayOrName = Union [Hashable , xarray .DataArray ]
43-
4442#: Some type that can enumerate the different :ref:`grid types <grids>`
4543#: present in a dataset.
4644#: This can be an :class:`enum.Enum` listing each different kind of grid.
@@ -266,19 +264,15 @@ def bind(self) -> None:
266264 "cannot assign a new convention." )
267265 state .bind_convention (self )
268266
269- @abc .abstractmethod
267+ @utils .deprecated (
268+ (
269+ "Convention._get_data_array() has been deprecated. "
270+ "Use emsarray.utils.name_to_data_array() instead."
271+ ),
272+ DeprecationWarning ,
273+ )
270274 def _get_data_array (self , data_array : DataArrayOrName ) -> xarray .DataArray :
271- """
272- Utility to help get a data array for this dataset.
273- If a string is passed in, the matching data array is fetched from the dataset.
274- If a data array is passed in,
275- it is inspected to ensure the surface dimensions align
276- before being returned as-is.
277-
278- This is useful for methods that support being passed either
279- the name of a data array or a data array instance.
280- """
281- pass
275+ return utils .name_to_data_array (self .dataset , data_array )
282276
283277 @cached_property
284278 def time_coordinate (self ) -> xarray .DataArray :
@@ -902,10 +896,10 @@ def plot_on_figure(
902896 ----------
903897 figure : matplotlib.figure.Figure
904898 The :class:`~matplotlib.figure.Figure` instance to plot this on.
905- scalar : xarray.DataArray or str
899+ scalar : DataArrayOrName
906900 The :class:`~xarray.DataArray` to plot,
907901 or the name of an existing DataArray in this Dataset.
908- vector : tuple of xarray.DataArray or str
902+ vector : tuple of DataArrayOrName
909903 A tuple of the *u* and *v* components of a vector.
910904 The components should be a :class:`~xarray.DataArray`,
911905 or the name of an existing DataArray in this Dataset.
@@ -918,10 +912,10 @@ def plot_on_figure(
918912 :func:`.plot.plot_on_figure` : The underlying implementation
919913 """
920914 if scalar is not None :
921- kwargs ['scalar' ] = self . _get_data_array ( scalar )
915+ kwargs ['scalar' ] = utils . name_to_data_array ( self . dataset , scalar )
922916
923917 if vector is not None :
924- kwargs ['vector' ] = tuple (map (self ._get_data_array , vector ) )
918+ kwargs ['vector' ] = tuple (utils . name_to_data_array (self .dataset , v ) for v in vector )
925919
926920 if title is not None :
927921 kwargs ['title' ] = title
@@ -977,7 +971,7 @@ def animate_on_figure(
977971 ----------
978972 figure : matplotlib.figure.Figure
979973 The :class:`matplotlib.figure.Figure` to plot the animation on
980- data_array : Hashable or xarray.DataArray
974+ data_array : DataArrayOrName
981975 The :class:`xarray.DataArray` to plot.
982976 If a string is passed in,
983977 the variable with that name is taken from :attr:`dataset`.
@@ -1006,26 +1000,26 @@ def animate_on_figure(
10061000
10071001 if coordinate is None :
10081002 # Assume the user wants to plot along the time axis by default.
1009- coordinate = self .get_time_name ()
1010- if isinstance (coordinate , xarray .DataArray ):
1011- utils .check_data_array_dimensions_match (
1012- self .dataset , coordinate , dimensions = coordinate .dims )
1013-
1014- coordinate = self ._get_data_array (coordinate )
1003+ coordinate = self .time_coordinate
1004+ else :
1005+ coordinate = utils .name_to_data_array (self .dataset , coordinate )
10151006
10161007 if len (coordinate .dims ) != 1 :
10171008 raise ValueError ("Coordinate variable must be one dimensional" )
10181009
10191010 coordinate_dim = coordinate .dims [0 ]
10201011
10211012 if scalar is not None :
1022- scalar = self . _get_data_array ( scalar )
1013+ scalar = utils . name_to_data_array ( self . dataset , scalar )
10231014 if coordinate_dim not in scalar .dims :
10241015 raise ValueError ("Scalar dimensions do not match coordinate axis to animate along" )
10251016 kwargs ['scalar' ] = scalar
10261017
10271018 if vector is not None :
1028- vector = (self ._get_data_array (vector [0 ]), self ._get_data_array (vector [1 ]))
1019+ vector = (
1020+ utils .name_to_data_array (self .dataset , vector [0 ]),
1021+ utils .name_to_data_array (self .dataset , vector [1 ]),
1022+ )
10291023 if not all (coordinate_dim in component .dims for component in vector ):
10301024 raise ValueError ("Vector dimensions do not match coordinate axis to animate along" )
10311025 kwargs ['vector' ] = vector
@@ -1119,7 +1113,7 @@ def make_poly_collection(
11191113 "Can not pass both `data_array` and `array` to make_poly_collection"
11201114 )
11211115
1122- data_array = self . _get_data_array ( data_array )
1116+ data_array = utils . name_to_data_array ( self . dataset , data_array )
11231117
11241118 data_array = self .ravel (data_array )
11251119 if len (data_array .dims ) > 1 :
@@ -1189,7 +1183,8 @@ def make_quiver(
11891183 values = numpy .nan , numpy .nan
11901184
11911185 if u is not None and v is not None :
1192- u , v = self ._get_data_array (u ), self ._get_data_array (v )
1186+ u = utils .name_to_data_array (self .dataset , u )
1187+ v = utils .name_to_data_array (self .dataset , v )
11931188
11941189 if u .dims != v .dims :
11951190 raise ValueError (
@@ -1539,15 +1534,15 @@ def drop_geometry(self) -> xarray.Dataset:
15391534 """
15401535 return self .dataset .drop_vars (self .get_all_geometry_names ())
15411536
1542- def select_variables (self , variables : Iterable [Hashable ]) -> xarray .Dataset :
1537+ def select_variables (self , variables : Iterable [DataArrayOrName ]) -> xarray .Dataset :
15431538 """Select only a subset of the variables in this dataset, dropping all others.
15441539
15451540 This will keep all coordinate variables and all geometry variables.
15461541
15471542 Parameters
15481543 ----------
1549- variables : iterable of Hashable
1550- The names of all data variables to select.
1544+ variables : iterable of DataArrayOrName
1545+ The data variables to select.
15511546
15521547 Returns
15531548 -------
@@ -1570,6 +1565,7 @@ def select_variables(self, variables: Iterable[Hashable]) -> xarray.Dataset:
15701565 keep_vars .add (self .get_time_name ())
15711566 except NoSuchCoordinateError :
15721567 pass
1568+ keep_vars = {utils .data_array_to_name (self .dataset , v ) for v in keep_vars }
15731569 return self .dataset .drop_vars (all_vars - keep_vars )
15741570
15751571 @abc .abstractmethod
@@ -1701,7 +1697,7 @@ def ocean_floor(self) -> xarray.Dataset:
17011697 """An alias for :func:`emsarray.operations.depth.ocean_floor`"""
17021698 return depth .ocean_floor (
17031699 self .dataset , self .get_all_depth_names (),
1704- non_spatial_variables = [self .get_time_name () ])
1700+ non_spatial_variables = [self .time_coordinate ])
17051701
17061702 def normalize_depth_variables (
17071703 self ,
@@ -1786,23 +1782,6 @@ def get_grid_kind(self, data_array: xarray.DataArray) -> GridKind:
17861782 return kind
17871783 raise ValueError ("Unknown grid kind" )
17881784
1789- def _get_data_array (self , data_array : DataArrayOrName ) -> xarray .DataArray :
1790- if isinstance (data_array , xarray .DataArray ):
1791- grid_kind = self .get_grid_kind (data_array )
1792- grid_dimensions = self .grid_dimensions [grid_kind ]
1793- for dimension in grid_dimensions :
1794- # The data array already has matching dimension names
1795- # as we found the grid kind using `Convention.get_grid_kind()`.
1796- if self .dataset .sizes [dimension ] != data_array .sizes [dimension ]:
1797- raise ValueError (
1798- f"Mismatched dimension { dimension !r} , "
1799- "dataset has size {self.dataset.sizes[dimension]} but "
1800- "data array has size {data_array.sizes[dimension]}!"
1801- )
1802- return data_array
1803- else :
1804- return self .dataset [data_array ]
1805-
18061785 @abc .abstractmethod
18071786 def unpack_index (self , index : Index ) -> tuple [GridKind , Sequence [int ]]:
18081787 """Convert a native index in to a grid kind and dimension indices.
0 commit comments