Skip to content

Commit 62427ad

Browse files
Michael Vincent Mannino
authored and committed
format
1 parent 1713727 commit 62427ad

File tree

2 files changed

+76
-55
lines changed

2 files changed

+76
-55
lines changed

pandas/plotting/_matplotlib/core.py

Lines changed: 45 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
Iterator,
1111
Sequence,
1212
)
13-
from random import shuffle
1413
from typing import (
1514
TYPE_CHECKING,
1615
Any,
@@ -22,6 +21,10 @@
2221

2322
import matplotlib as mpl
2423
import numpy as np
24+
from seaborn._base import (
25+
HueMapping,
26+
VectorPlotter,
27+
)
2528

2629
from pandas._libs import lib
2730
from pandas.errors import AbstractMethodError
@@ -1340,27 +1343,47 @@ def _make_plot(self, fig: Figure) -> None:
13401343
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical)
13411344
cb = self._get_colorbar(c_values, c_is_column)
13421345

1343-
# if a list of non color strings is passed in as c, generate a list
1344-
# colored by uniqueness of the strings, such same strings get same color
1345-
create_colors = not self._are_valid_colors(c_values)
1346-
if create_colors:
1347-
custom_color_mapping, c_values = self._uniquely_color_strs(c_values)
1348-
cb = False # no colorbar; opt for legend
1349-
13501346
if self.legend:
13511347
label = self.label
13521348
else:
13531349
label = None
1354-
scatter = ax.scatter(
1355-
data[x].values,
1356-
data[y].values,
1357-
c=c_values,
1358-
label=label,
1359-
cmap=cmap,
1360-
norm=norm,
1361-
s=self.s,
1362-
**self.kwds,
1363-
)
1350+
1351+
# if a list of non color strings is passed in as c, color points
1352+
# by uniqueness of the strings, so that identical strings get the same color
1353+
create_colors = not self._are_valid_colors(c_values)
1354+
1355+
# Plot as normal
1356+
if not create_colors:
1357+
scatter = ax.scatter(
1358+
data[x].values,
1359+
data[y].values,
1360+
c=c_values,
1361+
label=label,
1362+
cmap=cmap,
1363+
norm=norm,
1364+
s=self.s,
1365+
**self.kwds,
1366+
)
1367+
# Have to custom color
1368+
else:
1369+
scatter = ax.scatter(
1370+
data[x].values,
1371+
data[y].values,
1372+
label=label,
1373+
cmap=cmap,
1374+
norm=norm,
1375+
s=self.s,
1376+
**self.kwds,
1377+
)
1378+
1379+
# set colors via Seaborn as it contains all the logic for handling color
1380+
# decisions, all nicely packaged
1381+
scatter.set_facecolor(
1382+
HueMapping(
1383+
VectorPlotter(data=data, variables={"x": x, "y": y, "hue": c})
1384+
)(c_values)
1385+
)
1386+
13641387
if cb:
13651388
cbar_label = c if c_is_column else ""
13661389
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
@@ -1377,15 +1400,6 @@ def _make_plot(self, fig: Figure) -> None:
13771400
label, # type: ignore[arg-type]
13781401
)
13791402

1380-
# build legend for labeling custom colors
1381-
if create_colors:
1382-
ax.legend(
1383-
handles=[
1384-
mpl.patches.Circle((0, 0), facecolor=color, label=string)
1385-
for string, color in custom_color_mapping.items()
1386-
]
1387-
)
1388-
13891403
errors_x = self._get_errorbars(label=x, index=0, yerr=False)
13901404
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
13911405
if len(errors_x) > 0 or len(errors_y) > 0:
@@ -1409,38 +1423,20 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
14091423
c_values = c
14101424
return c_values
14111425

1412-
def _are_valid_colors(self, c_values: np.ndarray | list):
1426+
def _are_valid_colors(self, c_values: np.ndarray):
14131427
# check if c_values contains strings and if these strings are valid mpl colors.
14141428
# no need to check numerics as these (and mpl colors) will be validated for us
14151429
# in .Axes.scatter._parse_scatter_color_args(...)
1430+
unique = np.unique(c_values)
14161431
try:
1417-
if len(c_values) and all(isinstance(c, str) for c in c_values):
1418-
mpl.colors.to_rgba_array(c_values)
1432+
if len(c_values) and all(isinstance(c, str) for c in unique):
1433+
mpl.colors.to_rgba_array(unique)
14191434

14201435
return True
14211436

14221437
except (TypeError, ValueError) as _:
14231438
return False
14241439

1425-
def _uniquely_color_strs(
1426-
self, c_values: np.ndarray | list
1427-
) -> tuple[dict, np.ndarray]:
1428-
# well, almost uniquely color them (up to 949)
1429-
unique = np.unique(c_values)
1430-
1431-
# for up to 7, lets keep colors consistent
1432-
if len(unique) <= 7:
1433-
possible_colors = list(mpl.colors.BASE_COLORS.values()) # Hex
1434-
# explore better ways to handle this case
1435-
else:
1436-
possible_colors = list(mpl.colors.XKCD_COLORS.values()) # Hex
1437-
shuffle(possible_colors)
1438-
1439-
colors = [possible_colors[i % len(possible_colors)] for i in range(len(unique))]
1440-
color_mapping = dict(zip(unique, colors))
1441-
1442-
return color_mapping, np.array(list(map(color_mapping.get, c_values)))
1443-
14441440
def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
14451441
c = self.c
14461442
if self.colormap is not None:

pandas/tests/plotting/frame/test_frame_color.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,22 +217,46 @@ def test_scatter_with_c_column_name_with_colors(self, cmap):
217217
ax = df.plot.scatter(x=0, y=1, cmap=cmap, c="species")
218218
else:
219219
ax = df.plot.scatter(x=0, y=1, c="species", cmap=cmap)
220+
221+
assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == 3 # r/g/b
220222
assert ax.collections[0].colorbar is None
221223

222224
def test_scatter_with_c_column_name_without_colors(self):
225+
# Given
226+
colors = ["NY", "MD", "MA", "CA"]
227+
color_count = 4 # 4 unique colors
228+
229+
# When
223230
df = DataFrame(
224231
{
225232
"dataX": range(100),
226233
"dataY": range(100),
227-
"state": ["NY", "MD", "MA", "CA"] * 25,
234+
"color": (colors[i % len(colors)] for i in range(100)),
228235
}
229236
)
230-
df.plot.scatter("dataX", "dataY", c="state")
231237

232-
with tm.assert_produces_warning(None):
233-
ax = df.plot.scatter(x=0, y=1, c="state")
238+
# Then
239+
ax = df.plot.scatter("dataX", "dataY", c="color")
240+
assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == color_count
241+
242+
# Given
243+
colors = ["r", "g", "not-a-color"]
244+
color_count = 3
245+
# Also, since not all are mpl-colors, points matching 'r' or 'g'
246+
# are not necessarily red or green
247+
248+
# When
249+
df = DataFrame(
250+
{
251+
"dataX": range(100),
252+
"dataY": range(100),
253+
"color": (colors[i % len(colors)] for i in range(100)),
254+
}
255+
)
234256

235-
assert len(np.unique(ax.collections[0].get_facecolor())) == 4 # 4 states
257+
# Then
258+
ax = df.plot.scatter("dataX", "dataY", c="color")
259+
assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == color_count
236260

237261
def test_scatter_colors(self):
238262
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
@@ -244,7 +268,8 @@ def test_scatter_colors_not_raising_warnings(self):
244268
# provided via 'c'. Parameters 'cmap' will be ignored
245269
df = DataFrame({"x": [1, 2, 3], "y": [1, 2, 3]})
246270
with tm.assert_produces_warning(None):
247-
df.plot.scatter(x="x", y="y", c="b")
271+
ax = df.plot.scatter(x="x", y="y", c="b")
272+
assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == 1 # b
248273

249274
def test_scatter_colors_default(self):
250275
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})

0 commit comments

Comments
 (0)