Skip to content

Commit b811ea9

Browse files
committed
Fix error when plotting empty transects
Fixes #119
1 parent 395514d commit b811ea9

File tree

6 files changed

+120
-21
lines changed

6 files changed

+120
-21
lines changed

docs/releases/development.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ Next release (in development)
55
* Fix a ``FutureWarning`` on accessing :attr:`xarray.Dataset.dims`
66
with xarray >= 2023.12.0
77
(:pr:`124`, :pr:`pydata/xarray#8500`).
8+
* Fix an error when creating a transect plot that does not intersect the model geometry.
9+
Previously this would raise a cryptic error, now it returns an empty transect dataset
10+
(:issue:`119`, :pr:`120`).

src/emsarray/conventions/_base.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,21 +268,19 @@ def bind(self) -> None:
268268
"cannot assign a new convention.")
269269
state.bind_convention(self)
270270

271+
@abc.abstractmethod
271272
def _get_data_array(self, data_array: DataArrayOrName) -> xarray.DataArray:
272273
"""
273274
Utility to help get a data array for this dataset.
274275
If a string is passed in, the matching data array is fetched from the dataset.
275-
If a data array is passed in, it is inspected to ensure the dimensions match
276+
If a data array is passed in,
277+
it is inspected to ensure the surface dimensions align
276278
before being returned as-is.
277279
278280
This is useful for methods that support being passed either
279281
the name of a data array or a data array instance.
280282
"""
281-
if isinstance(data_array, xarray.DataArray):
282-
utils.check_data_array_dimensions_match(self.dataset, data_array)
283-
return data_array
284-
else:
285-
return self.dataset[data_array]
283+
pass
286284

287285
@cached_property
288286
def time_coordinate(self) -> xarray.DataArray:
@@ -1011,6 +1009,9 @@ def animate_on_figure(
10111009
if coordinate is None:
10121010
# Assume the user wants to plot along the time axis by default.
10131011
coordinate = self.get_time_name()
1012+
if isinstance(coordinate, xarray.DataArray):
1013+
utils.check_data_array_dimensions_match(
1014+
self.dataset, coordinate, dimensions=coordinate.dims)
10141015

10151016
coordinate = self._get_data_array(coordinate)
10161017

@@ -1787,6 +1788,23 @@ def get_grid_kind(self, data_array: xarray.DataArray) -> GridKind:
17871788
return kind
17881789
raise ValueError("Unknown grid kind")
17891790

1791+
def _get_data_array(self, data_array: DataArrayOrName) -> xarray.DataArray:
1792+
if isinstance(data_array, xarray.DataArray):
1793+
grid_kind = self.get_grid_kind(data_array)
1794+
grid_dimensions = self.grid_dimensions[grid_kind]
1795+
for dimension in grid_dimensions:
1796+
# The data array already has matching dimension names
1797+
# as we found the grid kind using `Convention.get_grid_kind()`.
1798+
if self.dataset.sizes[dimension] != data_array.sizes[dimension]:
1799+
raise ValueError(
1800+
f"Mismatched dimension {dimension!r}, "
1801+
"dataset has size {self.dataset.sizes[dimension]} but "
1802+
"data array has size {data_array.sizes[dimension]}!"
1803+
)
1804+
return data_array
1805+
else:
1806+
return self.dataset[data_array]
1807+
17901808
@abc.abstractmethod
17911809
def unpack_index(self, index: Index) -> Tuple[GridKind, Sequence[int]]:
17921810
"""Convert a native index in to a grid kind and dimension indices.

src/emsarray/transect.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,17 @@ def transect_dataset(self) -> xarray.Dataset:
188188
dims=(depth_dimension, 'bounds'),
189189
)
190190
distance_bounds = xarray.DataArray(
191-
data=[
192-
[segment.start_distance, segment.end_distance]
193-
for segment in self.segments
194-
],
191+
data=numpy.fromiter(
192+
(
193+
[segment.start_distance, segment.end_distance]
194+
for segment in self.segments
195+
),
196+
# Be explicit here, to handle the case when len(self.segments) == 0.
197+
# This happens when the transect line does not intersect the dataset.
198+
# This will result in an empty transect plot.
199+
count=len(self.segments),
200+
dtype=numpy.dtype((float, 2)),
201+
),
195202
dims=('index', 'bounds'),
196203
attrs={
197204
'long_name': 'Distance along transect',
@@ -762,8 +769,15 @@ def _plot_on_figure(
762769

763770
cmap = colormaps[cmap].copy()
764771
cmap.set_bad(ocean_floor_colour)
765-
collection = self.make_poly_collection(
766-
cmap=cmap, clim=(numpy.nanmin(data_array), numpy.nanmax(data_array)))
772+
773+
if data_array.size != 0:
774+
clim = (numpy.nanmin(data_array), numpy.nanmax(data_array))
775+
else:
776+
# An empty data array happens when the transect line does not
777+
# intersect the dataset geometry.
778+
clim = None
779+
780+
collection = self.make_poly_collection(cmap=cmap, clim=clim, edgecolor='face')
767781
axes.add_collection(collection)
768782

769783
if bathymetry is not None:

src/emsarray/utils.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -427,38 +427,54 @@ def dimensions_from_coords(
427427
return dimensions
428428

429429

430-
def check_data_array_dimensions_match(dataset: xarray.Dataset, data_array: xarray.DataArray) -> None:
430+
def check_data_array_dimensions_match(
431+
dataset: xarray.Dataset,
432+
data_array: xarray.DataArray,
433+
*,
434+
dimensions: Optional[Sequence[Hashable]] = None,
435+
) -> None:
431436
"""
432437
Check that the dimensions of a :class:`xarray.DataArray`
433438
match the dimensions of a :class:`xarray.Dataset`.
434439
This is useful when using the metadata of a particular dataset to display a data array,
435440
without requiring the data array to be taken directly from the dataset.
436441
437-
If the dimensions do not match, a ValueError is raised, indicating the mismatched dimension.
442+
If the dimensions do not match a ValueError is raised indicating the mismatched dimension.
438443
439444
Parameters
440445
----------
441-
dataset
446+
dataset : xarray.Dataset
442447
The dataset used as a reference
443-
data_array
448+
data_array : xarray.DataArray
444449
The data array to check the dimensions of
450+
dimensions : list of Hashable, optional
451+
The dimension names to check for equal sizes.
452+
Optional, defaults to checking all dimensions on the data array.
445453
446454
Raises
447455
------
448456
ValueError
449457
Raised if the dimensions do not match
450458
"""
451-
for dimension, data_array_size in zip(data_array.dims, data_array.shape):
452-
if dimension not in dataset.dims:
459+
if dimensions is None:
460+
dimensions = data_array.dims
461+
462+
for dimension in dimensions:
463+
if dimension not in dataset.dims and dimension not in data_array.dims:
453464
raise ValueError(
454-
f"Data array has unknown dimension {dimension} of size {data_array_size}"
465+
f"Dimension {dimension!r} not present on either dataset or data array"
455466
)
456-
467+
elif dimension not in dataset.dims:
468+
raise ValueError(f"Dataset does not have dimension {dimension!r}")
469+
elif dimension not in data_array.dims:
470+
raise ValueError(f"Data array does not have dimension {dimension!r}")
457471
dataset_size = dataset.sizes[dimension]
472+
data_array_size = data_array.sizes[dimension]
473+
458474
if data_array_size != dataset_size:
459475
raise ValueError(
460476
"Dimension mismatch between dataset and data array: "
461-
f"Dataset dimension {dimension} has size {dataset_size}, "
477+
f"Dataset dimension {dimension!r} has size {dataset_size}, "
462478
f"data array has size {data_array_size}"
463479
)
464480

tests/conventions/test_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ class SimpleConvention(Convention[SimpleGridKind, SimpleGridIndex]):
4444
def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]:
4545
return None
4646

47+
def _get_data_array(self, data_array_or_name) -> xarray.DataArray:
48+
if isinstance(data_array_or_name, str):
49+
return self.dataset[data_array_or_name]
50+
else:
51+
return data_array_or_name
52+
4753
@cached_property
4854
def shape(self) -> Tuple[int, int]:
4955
y, x = map(int, self.dataset['botz'].shape)

tests/test_transect.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,45 @@ def test_plot(
5555

5656
matplotlib.pyplot.savefig(tmp_path / 'plot.png')
5757
logger.info("Saved plot to %r", tmp_path / 'plot.png')
58+
59+
60+
@pytest.mark.matplotlib(mock_coast=True)
61+
@pytest.mark.tutorial
62+
def test_plot_no_intersection(
63+
datasets: pathlib.Path,
64+
tmp_path: pathlib.Path,
65+
):
66+
"""
67+
Transects that do not intersect the dataset geometry need special handling.
68+
This should produce an empty transect plot, which is better than raising an error.
69+
"""
70+
dataset = emsarray.tutorial.open_dataset('gbr4')
71+
temp = dataset['temp'].copy()
72+
temp = temp.isel(time=-1)
73+
74+
# This line goes through the Bass Strait, no where near the GBR.
75+
# Someone picked the wrong dataset...
76+
line = shapely.LineString([
77+
[142.097168, -39.206719],
78+
[145.393066, -39.3088],
79+
[149.798584, -39.172659],
80+
])
81+
emsarray.transect.plot(
82+
dataset, line, temp,
83+
bathymetry=dataset['botz'])
84+
85+
figure = matplotlib.pyplot.gcf()
86+
axes = figure.axes[0]
87+
# This is assembled from the variable long_name and the time coordinate
88+
assert axes.get_title() == 'Temperature\n2022-05-11T14:00'
89+
# This is the long_name of the depth coordinate
90+
assert axes.get_ylabel() == 'Z coordinate'
91+
# This is made up
92+
assert axes.get_xlabel() == 'Distance along transect'
93+
94+
colorbar = figure.axes[-1]
95+
# This is the variable units
96+
assert colorbar.get_ylabel() == 'degrees C'
97+
98+
matplotlib.pyplot.savefig(tmp_path / 'plot.png')
99+
logger.info("Saved plot to %r", tmp_path / 'plot.png')

0 commit comments

Comments
 (0)