Skip to content

Commit 5c5f161

Browse files
authored
Merge pull request #92 from csiro-coasts/convention-coordinates
Convention coordinates
2 parents 5de0975 + f4ea06b commit 5c5f161

File tree

8 files changed

+238
-23
lines changed

8 files changed

+238
-23
lines changed

docs/releases/development.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ Next release (in development)
99
are dropped from the returned dataset,
1010
or filled with a sensible fill value
1111
(:pr:`90`).
12+
* Align automatic coordinate detection of time and depth with CF Conventions.
13+
Add :attr:`.Convention.time_coordinate` and :attr:`.Convention.depth_coordinate`,
14+
deprecate :meth:`.Convention.get_times()` and :meth:`.Convention.get_depths()`
15+
(:pr:`92`).

src/emsarray/conventions/_base.py

Lines changed: 133 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from emsarray import utils
2121
from emsarray.compat.shapely import SpatialIndex
22+
from emsarray.exceptions import NoSuchCoordinateError
2223
from emsarray.operations import depth
2324
from emsarray.plot import (
2425
_requires_plot, animate_on_figure, plot_on_figure, polygons_to_collection
@@ -282,22 +283,122 @@ def _get_data_array(self, data_array: DataArrayOrName) -> xr.DataArray:
282283
else:
283284
return self.dataset[data_array]
284285

286+
@cached_property
287+
def time_coordinate(self) -> xr.DataArray:
288+
"""The time coordinate for this dataset.
289+
290+
Returns
291+
-------
292+
xarray.DataArray
293+
The variable for the time coordinate for this dataset.
294+
295+
Raises
296+
------
297+
exceptions.NoSuchCoordinateError
298+
If no time coordinate was found
299+
300+
See Also
301+
--------
302+
get_time_name
303+
"""
304+
return self.dataset[self.get_time_name()]
305+
306+
@cached_property
307+
def depth_coordinate(self) -> xr.DataArray:
308+
"""The depth coordinate for this dataset.
309+
310+
Returns
311+
-------
312+
xarray.DataArray
313+
The variable for the depth coordinate for this dataset.
314+
315+
Raises
316+
------
317+
exceptions.NoSuchCoordinateError
318+
If no depth coordinate was found
319+
320+
See Also
321+
--------
322+
get_depth_name
323+
"""
324+
return self.dataset[self.get_depth_name()]
325+
285326
def get_time_name(self) -> Hashable:
286-
"""Get the name of the time variable in this dataset."""
327+
"""Get the name of the time variable in this dataset.
328+
329+
Returns
330+
-------
331+
Hashable
332+
The name of the time coordinate.
333+
334+
Raises
335+
------
336+
exceptions.NoSuchCoordinateError
337+
If no time coordinate was found
338+
339+
Notes
340+
-----
341+
The CF Conventions state that
342+
a time variable is defined by having a `units` attribute
343+
formatted according to the UDUNITS package [1]_.
344+
345+
xarray will find all time variables and convert them to numpy datetimes.
346+
347+
References
348+
----------
349+
.. [1] `CF Conventions v1.10, 4.4 Time Coordinate <https://cfconventions.org/Data/cf-conventions/cf-conventions-1.10/cf-conventions.html#time-coordinate>`_
350+
"""
287351
for name, variable in self.dataset.variables.items():
288-
if variable.attrs.get('standard_name') == 'time':
289-
return name
290-
raise KeyError("Dataset does not have a time dimension")
352+
# xarray will automatically decode all time variables
353+
# and move the 'units' attribute over to encoding to store this change.
354+
if 'units' in variable.encoding:
355+
units = variable.encoding['units']
356+
# A time variable must have units of the form '<units> since <epoch>'
357+
if 'since' in units:
358+
# The variable must now be a numpy datetime
359+
if variable.dtype.type == np.datetime64:
360+
return name
361+
raise NoSuchCoordinateError("Could not find time coordinate in dataset")
291362

292363
def get_depth_name(self) -> Hashable:
293364
"""Get the name of the layer depth coordinate variable.
365+
294366
For datasets with multiple depth variables, this should be the one that
295367
represents the centre of the layer, not the bounds.
296368
297369
Note that this is the name of the coordinate variable,
298370
not the name of the dimension, for datasets where these differ.
371+
372+
Returns
373+
-------
374+
Hashable
375+
The name of the depth coordinate.
376+
377+
Raises
378+
------
379+
exceptions.NoSuchCoordinateError
380+
If no depth coordinate was found
381+
382+
Notes
383+
-----
384+
The CF Conventions state that
385+
a depth variable is identifiable by units of pressure; or
386+
the presence of the ``positive`` attribute with value of ``up`` or ``down``
387+
[2]_.
388+
389+
In practice, many datasets do not follow this convention.
390+
In addition to checking for the ``positive`` attribute,
391+
all coordinates are checked for a ``standard_name: "depth"``,
392+
``coordinate_type: "Z"``, or ``axis: "Z"``.
393+
394+
References
395+
----------
396+
.. [2] `CF Conventions v1.10, 4.3 Vertical (Height or Depth) Coordinate <https://cfconventions.org/Data/cf-conventions/cf-conventions-1.10/cf-conventions.html#vertical-coordinate>`_
299397
"""
300-
return self.get_all_depth_names()[0]
398+
try:
399+
return self.get_all_depth_names()[0]
400+
except IndexError:
401+
raise NoSuchCoordinateError("Could not find depth coordinate in dataset")
301402

302403
def get_all_depth_names(self) -> List[Hashable]:
303404
"""Get the names of all depth layers.
@@ -312,7 +413,8 @@ def get_all_depth_names(self) -> List[Hashable]:
312413
data_array = self.dataset[name]
313414

314415
if not (
315-
data_array.attrs.get('axis') == 'Z'
416+
data_array.attrs.get('positive', '').lower() in {'up', 'down'}
417+
or data_array.attrs.get('axis') == 'Z'
316418
or data_array.attrs.get('cartesian_axis') == 'Z'
317419
or data_array.attrs.get('coordinate_type') == 'Z'
318420
or data_array.attrs.get('standard_name') == 'depth'
@@ -333,27 +435,49 @@ def get_all_depth_names(self) -> List[Hashable]:
333435

334436
return depth_names
335437

438+
@utils.deprecated(
439+
(
440+
"Convention.get_depths() is deprecated. "
441+
"Use Convention.depth_coordinate.values instead."
442+
),
443+
DeprecationWarning,
444+
)
336445
def get_depths(self) -> np.ndarray:
337446
"""Get the depth of each vertical layer in this dataset.
338447
448+
.. deprecated:: 0.5.0
449+
This method is replaced by
450+
:attr:`Convention.depth_coordinate.values <Convention.depth_coordinate>`.
451+
339452
Returns
340453
-------
341454
:class:`numpy.ndarray`
342455
An array of depths, one per vertical layer in the dataset.
343456
"""
344-
return cast(np.ndarray, self.dataset.variables[self.get_depth_name()].values)
345-
457+
return cast(np.ndarray, self.depth_coordinate.values)
458+
459+
@utils.deprecated(
460+
(
461+
"Convention.get_times() is deprecated. "
462+
"Use Convention.time_coordinate.values instead."
463+
),
464+
DeprecationWarning,
465+
)
346466
def get_times(self) -> np.ndarray:
347467
"""Get all timesteps in this dataset.
348468
469+
.. deprecated:: 0.5.0
470+
This method is replaced by
471+
:attr:`Convention.time_coordinate.values <Convention.time_coordinate>`.
472+
349473
Returns
350474
-------
351475
:class:`numpy.ndarray`
352476
An array of datetimes.
353477
The datetimes will be whatever native format the dataset uses,
354478
likely :class:`numpy.datetime64`.
355479
"""
356-
return cast(np.ndarray, self.dataset.variables[self.get_time_name()].values)
480+
return cast(np.ndarray, self.time_coordinate.values)
357481

358482
@abc.abstractmethod
359483
def ravel_index(self, index: Index) -> int:

src/emsarray/conventions/shoc.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
import xarray as xr
2424

25+
from emsarray.exceptions import NoSuchCoordinateError
26+
2527
from ._base import Specificity
2628
from .arakawa_c import ArakawaC, ArakawaCGridKind
2729
from .grid import CFGrid2D, CFGrid2DTopology
@@ -48,13 +50,21 @@ class ShocStandard(ArakawaC):
4850
}
4951

5052
def get_depth_name(self) -> Hashable:
51-
return 'z_centre'
53+
name = 'z_centre'
54+
if name not in self.dataset.variables:
55+
raise NoSuchCoordinateError(
56+
f"SHOC dataset did not have expected depth coordinate {name!r}")
57+
return name
5258

5359
def get_all_depth_names(self) -> List[Hashable]:
5460
return ['z_centre', 'z_grid']
5561

5662
def get_time_name(self) -> Hashable:
57-
return 't'
63+
name = 't'
64+
if name not in self.dataset.variables:
65+
raise NoSuchCoordinateError(
66+
f"SHOC dataset did not have expected time coordinate {name!r}")
67+
return name
5868

5969
def drop_geometry(self) -> xr.Dataset:
6070
dataset = super().drop_geometry()
@@ -99,10 +109,18 @@ def check_dataset(cls, dataset: xr.Dataset) -> Optional[int]:
99109
return Specificity.HIGH
100110

101111
def get_time_name(self) -> Hashable:
102-
return 'time'
112+
name = 'time'
113+
if name not in self.dataset.variables:
114+
raise NoSuchCoordinateError(
115+
f"SHOC dataset did not have expected time coordinate {name!r}")
116+
return name
103117

104118
def get_depth_name(self) -> Hashable:
105-
return 'zc'
119+
name = 'zc'
120+
if name not in self.dataset.variables:
121+
raise NoSuchCoordinateError(
122+
f"SHOC dataset did not have expected depth coordinate {name!r}")
123+
return name
106124

107125
def get_all_depth_names(self) -> List[Hashable]:
108126
return [self.get_depth_name()]

src/emsarray/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,11 @@ class ConventionViolationWarning(UserWarning):
2222
For example, an attribute has an invalid type,
2323
but is still interpretable.
2424
"""
25+
26+
27+
class NoSuchCoordinateError(KeyError, EmsarrayError):
28+
"""
29+
Raised when a dataset does not have a particular coordinate,
30+
such as in :attr:`.Convention.time_coordinate` and
31+
:attr:`.Convention.depth_coordinate`.
32+
"""

src/emsarray/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import textwrap
1717
import time
18+
import warnings
1819
from types import TracebackType
1920
from typing import (
2021
Any, Callable, Hashable, Iterable, List, Literal, Mapping, MutableMapping,
@@ -721,3 +722,13 @@ def make_polygons_with_holes(
721722
indices=complete_row_indices,
722723
out=out)
723724
return out
725+
726+
727+
def deprecated(message: str, category: Type[Warning] = DeprecationWarning) -> Callable:
728+
def decorator(fn: Callable) -> Callable:
729+
@functools.wraps(fn)
730+
def wrapped(*args: Any, **kwargs: Any) -> Any:
731+
warnings.warn(message, category=category, stacklevel=2)
732+
return fn(*args, **kwargs)
733+
return wrapped
734+
return decorator

tests/conventions/test_base.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import dataclasses
44
import enum
5+
import pathlib
56
from functools import cached_property
6-
from typing import Dict, Hashable, List, Optional, Tuple
7+
from typing import Dict, Hashable, Optional, Tuple
78

89
import matplotlib.pyplot as plt
910
import numpy as np
@@ -14,6 +15,7 @@
1415

1516
from emsarray import masking, utils
1617
from emsarray.conventions import Convention, SpatialIndexItem
18+
from emsarray.exceptions import NoSuchCoordinateError
1719
from emsarray.types import Pathish
1820

1921

@@ -41,15 +43,6 @@ class SimpleConvention(Convention[SimpleGridKind, SimpleGridIndex]):
4143
def check_dataset(cls, dataset: xr.Dataset) -> Optional[int]:
4244
return None
4345

44-
def get_time_name(self) -> Hashable:
45-
return 't'
46-
47-
def get_depth_name(self) -> Hashable:
48-
return 'z'
49-
50-
def get_all_depth_names(self) -> List[Hashable]:
51-
return [self.get_depth_name()]
52-
5346
@cached_property
5447
def shape(self) -> Tuple[int, int]:
5548
y, x = map(int, self.dataset['botz'].shape)
@@ -114,6 +107,42 @@ def apply_clip_mask(self, clip_mask: xr.Dataset, work_dir: Pathish) -> xr.Datase
114107
return masking.mask_grid_dataset(self.dataset, clip_mask, work_dir)
115108

116109

110+
def test_get_time_name(datasets: pathlib.Path) -> None:
111+
dataset = xr.open_dataset(datasets / 'times.nc')
112+
SimpleConvention(dataset).bind()
113+
assert dataset.ems.get_time_name() == 'time'
114+
xr.testing.assert_equal(dataset.ems.time_coordinate, dataset['time'])
115+
116+
117+
def test_get_time_name_missing() -> None:
118+
dataset = xr.Dataset()
119+
SimpleConvention(dataset).bind()
120+
with pytest.raises(NoSuchCoordinateError):
121+
dataset.ems.get_time_name()
122+
123+
124+
@pytest.mark.parametrize('attrs', [
125+
{'positive': 'up'},
126+
{'positive': 'DOWN'},
127+
{'standard_name': 'depth'},
128+
{'axis': 'Z'},
129+
], ids=lambda a: '{}:{}'.format(*next(iter(a.items()))))
130+
def test_get_depth_name(attrs: dict) -> None:
131+
dataset = xr.Dataset({
132+
'name': (['dim'], [0, 1, 2], attrs),
133+
})
134+
SimpleConvention(dataset).bind()
135+
assert dataset.ems.get_depth_name() == 'name'
136+
xr.testing.assert_equal(dataset.ems.depth_coordinate, dataset['name'])
137+
138+
139+
def test_get_depth_name_missing() -> None:
140+
dataset = xr.Dataset()
141+
SimpleConvention(dataset).bind()
142+
with pytest.raises(NoSuchCoordinateError):
143+
dataset.ems.get_depth_name()
144+
145+
117146
def test_mask():
118147
dataset = xr.Dataset({
119148
'values': (['z', 'y', 'x'], np.random.standard_normal((5, 10, 20))),

0 commit comments

Comments
 (0)