Skip to content

Commit d6df99e

Browse files
committed
Fix transect title and units
The data array attributes were being dropped in `Transect.prepare_data_array_for_transect`. A new `emsarray.plot.make_plot_title()` function has been added. It is used for surface plots and transects.
1 parent f7fffc2 commit d6df99e

File tree

6 files changed

+117
-30
lines changed

6 files changed

+117
-30
lines changed

src/emsarray/conventions/_base.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from emsarray.exceptions import NoSuchCoordinateError
2424
from emsarray.operations import depth
2525
from emsarray.plot import (
26-
_requires_plot, animate_on_figure, plot_on_figure, polygons_to_collection
26+
_requires_plot, animate_on_figure, make_plot_title, plot_on_figure,
27+
polygons_to_collection
2728
)
2829
from emsarray.state import State
2930
from emsarray.types import Bounds, Pathish
@@ -938,25 +939,7 @@ def plot_on_figure(
938939
#
939940
# Users can supply their own titles
940941
# if this automatic behaviour is insufficient
941-
title_bits: list[str] = []
942-
long_name = kwargs['scalar'].attrs.get('long_name')
943-
if long_name is not None:
944-
title_bits.append(str(long_name))
945-
try:
946-
time_coordinate = self.dataset.variables[self.get_time_name()]
947-
except KeyError:
948-
pass
949-
else:
950-
# Add a time stamp when the time coordinate has a single value.
951-
# This happens when you `.sel()` a single time slice to plot -
952-
# as long as the time coordinate is a proper coordinate with
953-
# matching dimension name, not an auxiliary coordinate.
954-
if time_coordinate.size == 1:
955-
time = time_coordinate.values[0]
956-
title_bits.append(str(time))
957-
958-
if title_bits:
959-
kwargs['title'] = '\n'.join(title_bits)
942+
kwargs['title'] = make_plot_title(self.dataset, kwargs['scalar'])
960943

961944
plot_on_figure(figure, self, **kwargs)
962945

src/emsarray/plot.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy
99
import xarray
1010

11+
from emsarray.exceptions import NoSuchCoordinateError
1112
from emsarray.types import Landmark
1213
from emsarray.utils import requires_extra
1314

@@ -217,6 +218,48 @@ def polygons_to_collection(
217218
)
218219

219220

221+
def make_plot_title(
222+
dataset: xarray.Dataset,
223+
data_array: xarray.DataArray,
224+
) -> Optional[str]:
225+
"""
226+
Make a suitable plot title for a variable.
227+
This will attempt to find a name for the variable by looking through the attributes.
228+
If the variable has a time coordinate,
229+
and the time coordinate has a single value,
230+
the time step is appended after the title.
231+
"""
232+
if 'long_name' in data_array.attrs:
233+
title = str(data_array.attrs['long_name'])
234+
elif 'standard_name' in data_array.attrs:
235+
title = str(data_array.attrs['standard_name'])
236+
elif data_array.name is not None:
237+
title = str(data_array.name)
238+
else:
239+
return None
240+
241+
# Check if this variable has a time coordinate
242+
try:
243+
time_coordinate = dataset.ems.time_coordinate
244+
except NoSuchCoordinateError:
245+
return title
246+
if time_coordinate.name not in data_array.coords:
247+
return title
248+
# Fetch the coordinate from the data array itself,
249+
# in case someone did `data_array = dataset['temp'].isel(time=0)`
250+
time_coordinate = data_array.coords[time_coordinate.name]
251+
252+
if len(time_coordinate.dims) == 0:
253+
time_value = time_coordinate.values
254+
elif time_coordinate.size == 1:
255+
time_value = time_coordinate.values[0]
256+
else:
257+
return title
258+
259+
time_string = numpy.datetime_as_string(time_value, unit='auto')
260+
return f'{title}\n{time_string}'
261+
262+
220263
@_requires_plot
221264
def plot_on_figure(
222265
figure: Figure,

src/emsarray/transect.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from matplotlib.ticker import EngFormatter, Formatter
2121

2222
from emsarray.conventions import Convention, Index
23-
from emsarray.plot import _requires_plot
23+
from emsarray.plot import _requires_plot, make_plot_title
2424
from emsarray.types import Landmark
2525
from emsarray.utils import move_dimensions_to_end
2626

@@ -493,6 +493,10 @@ def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarra
493493
The input data array transformed to have the correct shape
494494
for plotting on the transect.
495495
"""
496+
# Some of the following operations drop attrs,
497+
# so keep a reference to the original ones
498+
attrs = data_array.attrs
499+
496500
data_array = self.convention.ravel(data_array)
497501

498502
depth_dimension = self.transect_dataset.coords['depth'].dims[0]
@@ -502,6 +506,9 @@ def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarra
502506
linear_indices = self.transect_dataset['linear_index'].values
503507
data_array = data_array.isel({index_dimension: linear_indices})
504508

509+
# Restore attrs after reformatting
510+
data_array.attrs.update(attrs)
511+
505512
return data_array
506513

507514
def _find_depth_bounds(self, data_array: xarray.DataArray) -> Tuple[int, int]:
@@ -749,7 +756,7 @@ def _plot_on_figure(
749756
axes.set_ylim(depth_limit_deep, depth_limit_shallow)
750757

751758
if title is None:
752-
title = data_array.attrs.get('long_name')
759+
title = make_plot_title(self.dataset, data_array)
753760
if title is not None:
754761
axes.set_title(title)
755762

tests/conventions/test_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ def test_plot():
443443
'botz': (['y', 'x'], numpy.random.standard_normal((10, 20)) - 10),
444444
})
445445
convention = SimpleConvention(dataset)
446+
convention.bind()
446447

447448
# Naming a simple variable should work fine
448449
convention.plot('botz')

tests/test_plot.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,15 @@ def test_plot(
5858
datasets: pathlib.Path,
5959
tmp_path: pathlib.Path,
6060
):
61-
"""
62-
Test plotting a variable with no long_name attribute works.
63-
Regression test for https://github.com/csiro-coasts/emsarray/issues/105
64-
"""
65-
dataset = emsarray.tutorial.open_dataset('gbr4')
61+
dataset = emsarray.tutorial.open_dataset('fraser')
6662
temp = dataset['temp'].copy()
6763
temp = temp.isel(time=0, k=-1)
6864

6965
dataset.ems.plot(temp)
7066

7167
figure = matplotlib.pyplot.gcf()
7268
axes = figure.axes[0]
73-
assert axes.get_title() == 'Temperature\n2022-05-11T14:00:00.000000000'
69+
assert axes.get_title() == 'Temperature\n2022-05-11T14:00'
7470

7571
matplotlib.pyplot.savefig(tmp_path / 'plot.png')
7672
logger.info("Saved plot to %r", tmp_path / 'plot.png')
@@ -86,7 +82,7 @@ def test_plot_no_long_name(
8682
Test plotting a variable with no long_name attribute works.
8783
Regression test for https://github.com/csiro-coasts/emsarray/issues/105
8884
"""
89-
dataset = emsarray.tutorial.open_dataset('gbr4')
85+
dataset = emsarray.tutorial.open_dataset('fraser')
9086
temp = dataset['temp'].copy()
9187
temp = temp.isel(time=0, k=-1)
9288
del temp.attrs['long_name']
@@ -95,7 +91,7 @@ def test_plot_no_long_name(
9591

9692
figure = matplotlib.pyplot.gcf()
9793
axes = figure.axes[0]
98-
assert axes.get_title() == '2022-05-11T14:00:00.000000000'
94+
assert axes.get_title() == 'temp\n2022-05-11T14:00'
9995

10096
matplotlib.pyplot.savefig(tmp_path / 'plot.png')
10197
logger.info("Saved plot to %r", tmp_path / 'plot.png')

tests/test_transect.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import logging
2+
import pathlib
3+
4+
import matplotlib
5+
import pytest
6+
import shapely
7+
8+
import emsarray.transect
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@pytest.mark.matplotlib(mock_coast=True)
14+
@pytest.mark.tutorial
15+
def test_plot(
16+
datasets: pathlib.Path,
17+
tmp_path: pathlib.Path,
18+
):
19+
dataset = emsarray.tutorial.open_dataset('gbr4')
20+
temp = dataset['temp'].copy()
21+
temp = temp.isel(time=-1)
22+
23+
line = shapely.LineString([
24+
[152.9768944, -25.4827962],
25+
[152.9701996, -25.4420345],
26+
[152.9727745, -25.3967620],
27+
[152.9623032, -25.3517828],
28+
[152.9401588, -25.3103560],
29+
[152.9173279, -25.2538563],
30+
[152.8962135, -25.1942238],
31+
[152.8692627, -25.0706729],
32+
[152.8623962, -24.9698750],
33+
[152.8472900, -24.8415806],
34+
[152.8308105, -24.6470172],
35+
[152.7607727, -24.3521012],
36+
[152.6392365, -24.1906056],
37+
[152.4792480, -24.0615124],
38+
])
39+
emsarray.transect.plot(
40+
dataset, line, temp,
41+
bathymetry=dataset['botz'])
42+
43+
figure = matplotlib.pyplot.gcf()
44+
axes = figure.axes[0]
45+
# This is assembled from the variable long_name and the time coordinate
46+
assert axes.get_title() == 'Temperature\n2022-05-11T14:00'
47+
# This is the long_name of the depth coordinate
48+
assert axes.get_ylabel() == 'Z coordinate'
49+
# This is made up
50+
assert axes.get_xlabel() == 'Distance along transect'
51+
52+
colorbar = figure.axes[-1]
53+
# This is the variable units
54+
assert colorbar.get_ylabel() == 'degrees C'
55+
56+
matplotlib.pyplot.savefig(tmp_path / 'plot.png')
57+
logger.info("Saved plot to %r", tmp_path / 'plot.png')

0 commit comments

Comments
 (0)