Skip to content

Commit 9951491

Browse files
Illviljanmathause
andauthored
Align naming convention with plt.subplots (#7194)
* Align axs naming with plt.subplots * Update plotting.rst * Update facetgrid.py * Update facetgrid.py * Update doc/user-guide/plotting.rst Co-authored-by: Mathias Hauser <[email protected]> * Update facetgrid.py * Update whats-new.rst * Update facetgrid.py Co-authored-by: Mathias Hauser <[email protected]>
1 parent c4677ce commit 9951491

File tree

3 files changed

+69
-40
lines changed

3 files changed

+69
-40
lines changed

doc/user-guide/plotting.rst

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -156,18 +156,18 @@ Keyword arguments work the same way, and are more explicit.
156156

157157
To add the plot to an existing axis pass in the axis as a keyword argument
158158
``ax``. This works for all xarray plotting methods.
159-
In this example ``axes`` is an array consisting of the left and right
159+
In this example ``axs`` is an array consisting of the left and right
160160
axes created by ``plt.subplots``.
161161

162162
.. ipython:: python
163163
:okwarning:
164164
165-
fig, axes = plt.subplots(ncols=2)
165+
fig, axs = plt.subplots(ncols=2)
166166
167-
axes
167+
axs
168168
169-
air1d.plot(ax=axes[0])
170-
air1d.plot.hist(ax=axes[1])
169+
air1d.plot(ax=axs[0])
170+
air1d.plot.hist(ax=axs[1])
171171
172172
plt.tight_layout()
173173
@@ -348,8 +348,8 @@ The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes d
348348
349349
In addition, one can use ``xscale, yscale`` to set axes scaling;
350350
``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits.
351-
These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``,
352-
``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively.
351+
These accept the same values as the matplotlib methods ``ax.set_(x,y)scale()``,
352+
``ax.set_(x,y)ticks()``, ``ax.set_(x,y)lim()``, respectively.
353353

354354

355355
Two Dimensions
@@ -701,12 +701,12 @@ that links a :py:class:`DataArray` to a matplotlib figure with a particular stru
701701
This object can be used to control the behavior of the multiple plots.
702702
It borrows an API and code from `Seaborn's FacetGrid
703703
<https://seaborn.pydata.org/tutorial/axis_grids.html>`_.
704-
The structure is contained within the ``axes`` and ``name_dicts``
704+
The structure is contained within the ``axs`` and ``name_dicts``
705705
attributes, both 2d NumPy object arrays.
706706

707707
.. ipython:: python
708708
709-
g.axes
709+
g.axs
710710
711711
g.name_dicts
712712
@@ -726,10 +726,10 @@ they have been plotted.
726726
727727
g = t.plot.imshow(x="lon", y="lat", col="time", col_wrap=3, robust=True)
728728
729-
for i, ax in enumerate(g.axes.flat):
729+
for i, ax in enumerate(g.axs.flat):
730730
ax.set_title("Air Temperature %d" % i)
731731
732-
bottomright = g.axes[-1, -1]
732+
bottomright = g.axs[-1, -1]
733733
bottomright.annotate("bottom right", (240, 40))
734734
735735
@savefig plot_facet_iterator.png
@@ -928,7 +928,7 @@ by faceting are accessible in the object returned by ``plot``:
928928
col="time",
929929
subplot_kws={"projection": ccrs.Orthographic(-80, 35)},
930930
)
931-
for ax in p.axes.flat:
931+
for ax in p.axs.flat:
932932
ax.coastlines()
933933
ax.gridlines()
934934
@savefig plotting_maps_cartopy_facetting.png width=100%
@@ -958,11 +958,11 @@ These are provided for user convenience; they all call the same code.
958958
import xarray.plot as xplt
959959
960960
da = xr.DataArray(range(5))
961-
fig, axes = plt.subplots(ncols=2, nrows=2)
962-
da.plot(ax=axes[0, 0])
963-
da.plot.line(ax=axes[0, 1])
964-
xplt.plot(da, ax=axes[1, 0])
965-
xplt.line(da, ax=axes[1, 1])
961+
fig, axs = plt.subplots(ncols=2, nrows=2)
962+
da.plot(ax=axs[0, 0])
963+
da.plot.line(ax=axs[0, 1])
964+
xplt.plot(da, ax=axs[1, 0])
965+
xplt.line(da, ax=axs[1, 1])
966966
plt.tight_layout()
967967
@savefig plotting_ways_to_use.png width=6in
968968
plt.draw()

doc/whats-new.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ Deprecations
4242

4343
- Positional arguments for all plot methods have been deprecated (:issue:`6949`, :pull:`7052`).
4444
By `Michael Niklas <https://github.com/headtr1ck>`_.
45+
- ``xarray.plot.FacetGrid.axes`` has been renamed to ``xarray.plot.FacetGrid.axs``
46+
because it's not clear if ``axes`` refers to single or multiple ``Axes`` instances.
47+
This aligns with ``matplotlib.pyplot.subplots``. (:pull:`7194`)
48+
By `Jimmy Westling <https://github.com/illviljan>`_.
4549

4650
Bug fixes
4751
~~~~~~~~~
@@ -60,7 +64,8 @@ Documentation
6064
By `Jessica Scheick <https://github.com/jessicas11>`_.
6165
- Add example of using :py:meth:`DataArray.coarsen.construct` to User Guide. (:pull:`7192`)
6266
By `Tom Nicholas <https://github.com/TomNicholas>`_.
63-
67+
- Rename ``axes`` to ``axs`` in plotting to align with ``matplotlib.pyplot.subplots``. (:pull:`7194`)
68+
By `Jimmy Westling <https://github.com/illviljan>`_.
6469

6570
Internal Changes
6671
~~~~~~~~~~~~~~~~

xarray/plot/facetgrid.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class FacetGrid(Generic[T_Xarray]):
9494
9595
Attributes
9696
----------
97-
axes : ndarray of matplotlib.axes.Axes
97+
axs : ndarray of matplotlib.axes.Axes
9898
Array containing axes in corresponding position, as returned from
9999
:py:func:`matplotlib.pyplot.subplots`.
100100
col_labels : list of matplotlib.text.Annotation
@@ -112,7 +112,7 @@ class FacetGrid(Generic[T_Xarray]):
112112
data: T_Xarray
113113
name_dicts: np.ndarray
114114
fig: Figure
115-
axes: np.ndarray
115+
axs: np.ndarray
116116
row_names: list[np.ndarray]
117117
col_names: list[np.ndarray]
118118
figlegend: Legend | None
@@ -223,7 +223,7 @@ def __init__(
223223
cbar_space = 1
224224
figsize = (ncol * size * aspect + cbar_space, nrow * size)
225225

226-
fig, axes = plt.subplots(
226+
fig, axs = plt.subplots(
227227
nrow,
228228
ncol,
229229
sharex=sharex,
@@ -258,7 +258,7 @@ def __init__(
258258
self.data = data
259259
self.name_dicts = name_dicts
260260
self.fig = fig
261-
self.axes = axes
261+
self.axs = axs
262262
self.row_names = row_names
263263
self.col_names = col_names
264264

@@ -282,13 +282,37 @@ def __init__(
282282
self._mappables = []
283283
self._finalized = False
284284

285+
@property
286+
def axes(self) -> np.ndarray:
287+
warnings.warn(
288+
(
289+
"self.axes is deprecated since 2022.11 in order to align with "
290+
"matplotlibs plt.subplots, use self.axs instead."
291+
),
292+
DeprecationWarning,
293+
stacklevel=2,
294+
)
295+
return self.axs
296+
297+
@axes.setter
298+
def axes(self, axs: np.ndarray) -> None:
299+
warnings.warn(
300+
(
301+
"self.axes is deprecated since 2022.11 in order to align with "
302+
"matplotlibs plt.subplots, use self.axs instead."
303+
),
304+
DeprecationWarning,
305+
stacklevel=2,
306+
)
307+
self.axs = axs
308+
285309
@property
286310
def _left_axes(self) -> np.ndarray:
287-
return self.axes[:, 0]
311+
return self.axs[:, 0]
288312

289313
@property
290314
def _bottom_axes(self) -> np.ndarray:
291-
return self.axes[-1, :]
315+
return self.axs[-1, :]
292316

293317
def map_dataarray(
294318
self: T_FacetGrid,
@@ -347,7 +371,7 @@ def map_dataarray(
347371
rgb=kwargs.get("rgb", None),
348372
)
349373

350-
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
374+
for d, ax in zip(self.name_dicts.flat, self.axs.flat):
351375
# None is the sentinel value
352376
if d is not None:
353377
subset = self.data.loc[d]
@@ -449,7 +473,7 @@ def map_plot1d(
449473
func_kwargs["add_legend"] = False
450474
func_kwargs["add_title"] = False
451475

452-
add_labels_ = np.zeros(self.axes.shape + (3,), dtype=bool)
476+
add_labels_ = np.zeros(self.axs.shape + (3,), dtype=bool)
453477
if kwargs.get("z") is not None:
454478
# 3d plots looks better with all labels. 3d plots can't sharex either so it
455479
# is easy to get lost while rotating the plots:
@@ -478,7 +502,7 @@ def map_plot1d(
478502

479503
# Plot the data for each subplot:
480504
for add_lbls, d, ax in zip(
481-
add_labels_.reshape((self.axes.size, -1)), name_dicts.flat, self.axes.flat
505+
add_labels_.reshape((self.axs.size, -1)), name_dicts.flat, self.axs.flat
482506
):
483507
func_kwargs["add_labels"] = add_lbls
484508
# None is the sentinel value
@@ -542,7 +566,7 @@ def map_dataarray_line(
542566
) -> T_FacetGrid:
543567
from .dataarray_plot import _infer_line_data
544568

545-
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
569+
for d, ax in zip(self.name_dicts.flat, self.axs.flat):
546570
# None is the sentinel value
547571
if d is not None:
548572
subset = self.data.loc[d]
@@ -609,7 +633,7 @@ def map_dataset(
609633
raise ValueError("Please provide scale.")
610634
# TODO: come up with an algorithm for reasonable scale choice
611635

612-
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
636+
for d, ax in zip(self.name_dicts.flat, self.axs.flat):
613637
# None is the sentinel value
614638
if d is not None:
615639
subset = self.data.loc[d]
@@ -643,7 +667,7 @@ def _finalize_grid(self, *axlabels: Hashable) -> None:
643667
self.set_titles()
644668
self.fig.tight_layout()
645669

646-
for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
670+
for ax, namedict in zip(self.axs.flat, self.name_dicts.flat):
647671
if namedict is None:
648672
ax.set_visible(False)
649673

@@ -703,15 +727,15 @@ def add_colorbar(self, **kwargs: Any) -> None:
703727
if "label" not in kwargs:
704728
kwargs.setdefault("label", label_from_attrs(self.data))
705729
self.cbar = self.fig.colorbar(
706-
self._mappables[-1], ax=list(self.axes.flat), **kwargs
730+
self._mappables[-1], ax=list(self.axs.flat), **kwargs
707731
)
708732

709733
def add_quiverkey(self, u: Hashable, v: Hashable, **kwargs: Any) -> None:
710734
kwargs = kwargs.copy()
711735

712736
magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v])
713737
units = self.data[u].attrs.get("units", "")
714-
self.quiverkey = self.axes.flat[-1].quiverkey(
738+
self.quiverkey = self.axs.flat[-1].quiverkey(
715739
self._mappables[-1],
716740
X=0.8,
717741
Y=0.9,
@@ -747,7 +771,7 @@ def _get_largest_lims(self) -> dict[str, tuple[float, float]]:
747771
for axis in ("x", "y", "z"):
748772
# Find the plot with the largest xlim values:
749773
lower, upper = lims_largest[axis]
750-
for ax in self.axes.flat:
774+
for ax in self.axs.flat:
751775
get_lim: None | Callable[[], tuple[float, float]] = getattr(
752776
ax, f"get_{axis}lim", None
753777
)
@@ -781,13 +805,13 @@ def _set_lims(
781805
>>> ds = xr.tutorial.scatter_example_dataset(seed=42)
782806
>>> fg = ds.plot.scatter(x="A", y="B", hue="y", row="x", col="w")
783807
>>> fg._set_lims(x=(-0.3, 0.3), y=(0, 2), z=(0, 4))
784-
>>> fg.axes[0, 0].get_xlim(), fg.axes[0, 0].get_ylim()
808+
>>> fg.axs[0, 0].get_xlim(), fg.axs[0, 0].get_ylim()
785809
((-0.3, 0.3), (0.0, 2.0))
786810
"""
787811
lims_largest = self._get_largest_lims()
788812

789813
# Set limits:
790-
for ax in self.axes.flat:
814+
for ax in self.axs.flat:
791815
for (axis, data_limit), parameter_limit in zip(
792816
lims_largest.items(), (x, y, z)
793817
):
@@ -858,7 +882,7 @@ def set_titles(
858882
nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)
859883

860884
if self._single_group:
861-
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
885+
for d, ax in zip(self.name_dicts.flat, self.axs.flat):
862886
# Only label the ones with data
863887
if d is not None:
864888
coord, value = list(d.items()).pop()
@@ -867,7 +891,7 @@ def set_titles(
867891
else:
868892
# The row titles on the right edge of the grid
869893
for index, (ax, row_name, handle) in enumerate(
870-
zip(self.axes[:, -1], self.row_names, self.row_labels)
894+
zip(self.axs[:, -1], self.row_names, self.row_labels)
871895
):
872896
title = nicetitle(coord=self._row_var, value=row_name, maxchar=maxchar)
873897
if not handle:
@@ -886,7 +910,7 @@ def set_titles(
886910

887911
# The column titles on the top row
888912
for index, (ax, col_name, handle) in enumerate(
889-
zip(self.axes[0, :], self.col_names, self.col_labels)
913+
zip(self.axs[0, :], self.col_names, self.col_labels)
890914
):
891915
title = nicetitle(coord=self._col_var, value=col_name, maxchar=maxchar)
892916
if not handle:
@@ -922,7 +946,7 @@ def set_ticks(
922946
x_major_locator = MaxNLocator(nbins=max_xticks)
923947
y_major_locator = MaxNLocator(nbins=max_yticks)
924948

925-
for ax in self.axes.flat:
949+
for ax in self.axs.flat:
926950
ax.xaxis.set_major_locator(x_major_locator)
927951
ax.yaxis.set_major_locator(y_major_locator)
928952
for tick in itertools.chain(
@@ -957,7 +981,7 @@ def map(
957981
"""
958982
plt = import_matplotlib_pyplot()
959983

960-
for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
984+
for ax, namedict in zip(self.axs.flat, self.name_dicts.flat):
961985
if namedict is not None:
962986
data = self.data.loc[namedict]
963987
plt.sca(ax)

0 commit comments

Comments
 (0)