Skip to content

Commit f52a95c

Browse files
authored
Fix interval labels with units (#4794)
* make sure we actually get the correct axis label * append the interval position suffix before the units * update whats-new.rst * also update the call in facetgrid * test both x and y using parametrize
1 parent 9bb0302 commit f52a95c

File tree

5 files changed

+31
-12
lines changed

5 files changed

+31
-12
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ Bug fixes
6868
By `Alessandro Amici <https://github.com/alexamici>`_
6969
- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling <https://github.com/illviljan>`_.
7070
- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo <https://github.com/mesejo>`_.
71+
- Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`).
72+
By `Justus Magin <https://github.com/keewis>`_.
7173

7274
Documentation
7375
~~~~~~~~~~~~~

xarray/plot/facetgrid.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,11 @@ def map_dataarray_line(
306306
)
307307
self._mappables.append(mappable)
308308

309-
_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
309+
xplt, yplt, hueplt, huelabel = _infer_line_data(
310310
darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue
311311
)
312+
xlabel = label_from_attrs(xplt)
313+
ylabel = label_from_attrs(yplt)
312314

313315
self._hue_var = hueplt
314316
self._hue_label = huelabel

xarray/plot/plot.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,7 @@ def _infer_line_data(darray, x, y, hue):
107107
huelabel = label_from_attrs(darray[huename])
108108
hueplt = darray[huename]
109109

110-
xlabel = label_from_attrs(xplt)
111-
ylabel = label_from_attrs(yplt)
112-
113-
return xplt, yplt, hueplt, xlabel, ylabel, huelabel
110+
return xplt, yplt, hueplt, huelabel
114111

115112

116113
def plot(
@@ -292,12 +289,14 @@ def line(
292289
assert "args" not in kwargs
293290

294291
ax = get_axis(figsize, size, aspect, ax)
295-
xplt, yplt, hueplt, xlabel, ylabel, hue_label = _infer_line_data(darray, x, y, hue)
292+
xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue)
296293

297294
# Remove pd.Intervals if contained in xplt.values and/or yplt.values.
298-
xplt_val, yplt_val, xlabel, ylabel, kwargs = _resolve_intervals_1dplot(
299-
xplt.values, yplt.values, xlabel, ylabel, kwargs
295+
xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
296+
xplt.values, yplt.values, kwargs
300297
)
298+
xlabel = label_from_attrs(xplt, extra=x_suffix)
299+
ylabel = label_from_attrs(yplt, extra=y_suffix)
301300

302301
_ensure_plottable(xplt_val, yplt_val)
303302

xarray/plot/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,12 +503,14 @@ def _interval_to_double_bound_points(xarray, yarray):
503503
return xarray, yarray
504504

505505

506-
def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs):
506+
def _resolve_intervals_1dplot(xval, yval, kwargs):
507507
"""
508508
Helper function to replace the values of x and/or y coordinate arrays
509509
containing pd.Interval with their mid-points or - for step plots - double
510510
points which double the length.
511511
"""
512+
x_suffix = ""
513+
y_suffix = ""
512514

513515
# Is it a step plot? (see matplotlib.Axes.step)
514516
if kwargs.get("drawstyle", "").startswith("steps-"):
@@ -534,13 +536,13 @@ def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs):
534536
# Convert intervals to mid points and adjust labels
535537
if _valid_other_type(xval, [pd.Interval]):
536538
xval = _interval_to_mid_points(xval)
537-
xlabel += "_center"
539+
x_suffix = "_center"
538540
if _valid_other_type(yval, [pd.Interval]):
539541
yval = _interval_to_mid_points(yval)
540-
ylabel += "_center"
542+
y_suffix = "_center"
541543

542544
# return converted arguments
543-
return xval, yval, xlabel, ylabel, kwargs
545+
return xval, yval, x_suffix, y_suffix, kwargs
544546

545547

546548
def _resolve_intervals_2dplot(val, func_name):

xarray/tests/test_plot.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,20 @@ def test_coord_with_interval_xy(self):
592592
bins = [-1, 0, 1, 2]
593593
self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot()
594594

595+
@pytest.mark.parametrize("dim", ("x", "y"))
596+
def test_labels_with_units_with_interval(self, dim):
597+
"""Test line plot with intervals and a units attribute."""
598+
bins = [-1, 0, 1, 2]
599+
arr = self.darray.groupby_bins("dim_0", bins).mean(...)
600+
arr.dim_0_bins.attrs["units"] = "m"
601+
602+
(mappable,) = arr.plot(**{dim: "dim_0_bins"})
603+
ax = mappable.figure.gca()
604+
actual = getattr(ax, f"get_{dim}label")()
605+
606+
expected = "dim_0_bins_center [m]"
607+
assert actual == expected
608+
595609

596610
class TestPlot1D(PlotTestCase):
597611
@pytest.fixture(autouse=True)

0 commit comments

Comments
 (0)