Skip to content

Commit 747fe26

Browse files
authored
scatter plot by order of the first appearance of hue (#4723)
* plot by order of first appearance * use ravel to avoid copying the data * update whats-new.rst * add a test to make sure the legend labels and the mappable labels match * test with upstream-dev [test-upstream] * add a comment about the reason for using pd.unique [skip-ci] * empty commit [skip-ci]
1 parent 1ce8938 commit 747fe26

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ Bug fixes
6464
By `Richard Kleijn <https://github.com/rhkleijn>`_ .
6565
- Remove dictionary unpacking when using ``.loc`` to avoid collision with ``.sel`` parameters (:pull:`4695`).
6666
By `Anderson Banihirwe <https://github.com/andersy005>`_
67+
- Fix the legend created by :py:meth:`Dataset.plot.scatter` (:issue:`4641`, :pull:`4723`).
68+
By `Justus Magin <https://github.com/keewis>`_.
6769
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`).
6870
By `Alessandro Amici <https://github.com/alexamici>`_
6971
- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations,

xarray/plot/dataset_plot.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def newplotfunc(
291291
allargs = locals().copy()
292292
allargs["plotfunc"] = globals()[plotfunc.__name__]
293293
allargs["data"] = ds
294-
# TODO dcherian: why do I need to remove kwargs?
294+
# remove kwargs to avoid passing the information twice
295295
for arg in ["meta_data", "kwargs", "ds"]:
296296
del allargs[arg]
297297

@@ -422,7 +422,10 @@ def scatter(ds, x, y, ax, **kwargs):
422422

423423
if hue_style == "discrete":
424424
primitive = []
425-
for label in np.unique(data["hue"].values):
425+
# use pd.unique instead of np.unique because that keeps the order of the labels,
426+
# which is important to keep them in sync with the ones used in
427+
# FacetGrid.add_legend
428+
for label in pd.unique(data["hue"].values.ravel()):
426429
mask = data["hue"] == label
427430
if data["sizes"] is not None:
428431
kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten())

xarray/tests/test_plot.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,6 +2290,17 @@ def test_legend_labels(self):
22902290
lines = ds2.plot.scatter(x="A", y="B", hue="hue")
22912291
assert [t.get_text() for t in lines[0].axes.get_legend().texts] == ["a", "b"]
22922292

2293+
def test_legend_labels_facetgrid(self):
2294+
ds2 = self.ds.copy()
2295+
ds2["hue"] = ["d", "a", "c", "b"]
2296+
g = ds2.plot.scatter(x="A", y="B", hue="hue", col="col")
2297+
legend_labels = tuple(t.get_text() for t in g.figlegend.texts)
2298+
attached_labels = [
2299+
tuple(m.get_label() for m in mappables_per_ax)
2300+
for mappables_per_ax in g._mappables
2301+
]
2302+
assert list(set(attached_labels)) == [legend_labels]
2303+
22932304
def test_add_legend_by_default(self):
22942305
sc = self.ds.plot.scatter(x="A", y="B", hue="hue")
22952306
assert len(sc.figure.axes) == 2

0 commit comments

Comments
 (0)