-
-
Notifications
You must be signed in to change notification settings - Fork 19.1k
ENH: DataFrame.plot.scatter argument c
now accepts a column of strings, where rows with the same string are colored identically
#59239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
b91e635
8609ea5
b4440c1
571c0c8
e9511d0
1ca57ed
fb0d6e4
4bcdbfc
7972138
45886d9
1713727
62427ad
609fe40
6e86858
5223f2a
d97606c
7e5a02a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
Iterator, | ||
Sequence, | ||
) | ||
from random import shuffle | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
|
@@ -1337,6 +1338,13 @@ def _make_plot(self, fig: Figure) -> None: | |
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical) | ||
cb = self._get_colorbar(c_values, c_is_column) | ||
|
||
# if a list of non color strings is passed in as c, generate a list | ||
# colored by uniqueness of the strings, such same strings get same color | ||
create_colors = not self._are_valid_colors(c_values) | ||
if create_colors: | ||
custom_color_mapping, c_values = self._uniquely_color_strs(c_values) | ||
cb = False # no colorbar; opt for legend | ||
|
||
if self.legend: | ||
label = self.label | ||
else: | ||
|
@@ -1367,6 +1375,15 @@ def _make_plot(self, fig: Figure) -> None: | |
label, # type: ignore[arg-type] | ||
) | ||
|
||
# build legend for labeling custom colors | ||
if create_colors: | ||
ax.legend( | ||
handles=[ | ||
mpl.patches.Circle((0, 0), facecolor=color, label=string) | ||
for string, color in custom_color_mapping.items() | ||
] | ||
) | ||
|
||
errors_x = self._get_errorbars(label=x, index=0, yerr=False) | ||
errors_y = self._get_errorbars(label=y, index=0, xerr=False) | ||
if len(errors_x) > 0 or len(errors_y) > 0: | ||
|
@@ -1390,6 +1407,38 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool): | |
c_values = c | ||
return c_values | ||
|
||
def _are_valid_colors(self, c_values: np.ndarray | list) -> bool:
    """
    Report whether ``c_values`` can be handed to matplotlib unchanged.

    Only input made up entirely of strings is checked here, by asking
    ``mpl.colors.to_rgba_array`` to convert it. Everything else (empty
    input, unsized scalars, sequences holding non-string entries) is
    reported as valid, because numerics and mpl colors are validated later
    in ``Axes.scatter._parse_scatter_color_args(...)``.

    Parameters
    ----------
    c_values : np.ndarray or list
        The resolved color values for the scatter plot.

    Returns
    -------
    bool
        ``False`` only when the input consists solely of strings and
        ``to_rgba_array`` rejects it; ``True`` otherwise.
    """
    try:
        if len(c_values) == 0:
            return True
    except TypeError:
        # Unsized scalar: it is not a collection of strings, so leave its
        # validation to matplotlib instead of reporting "invalid colors".
        return True

    if isinstance(c_values, str):
        # A bare string is a single color spec; pass it through as-is.
        names = c_values
    elif all(isinstance(c, str) for c in c_values):
        # Each distinct string only needs to be checked once.
        names = list(dict.fromkeys(c_values))
    else:
        return True

    # Keep the try body down to the one call whose failure means
    # "these strings are not colors".
    try:
        mpl.colors.to_rgba_array(names)
    except (TypeError, ValueError):
        return False
    return True
|
||
def _uniquely_color_strs(
    self, c_values: np.ndarray | list
) -> tuple[dict, np.ndarray]:
    """
    Assign one color to every distinct value in ``c_values``.

    Up to 7 distinct values are colored from ``mpl.colors.BASE_COLORS`` in
    its stored order; more than 7 are colored from a shuffled copy of
    ``mpl.colors.XKCD_COLORS``. The palette is reused cyclically when there
    are more distinct values than colors, so colors are only "almost"
    unique.

    Parameters
    ----------
    c_values : np.ndarray or list
        Values to color; equal values receive the same color.

    Returns
    -------
    tuple of (dict, np.ndarray)
        The mapping from each distinct value to its color, and an array
        holding the color of each entry of ``c_values`` in input order.
    """
    from random import Random

    unique = np.unique(c_values)

    # for up to 7, lets keep colors consistent
    if len(unique) <= 7:
        possible_colors = list(mpl.colors.BASE_COLORS.values())
    else:
        possible_colors = list(mpl.colors.XKCD_COLORS.values())
        # Shuffle with a private fixed-seed generator: the same data always
        # gets the same colors, and the caller's global ``random`` state is
        # left untouched.
        Random(0).shuffle(possible_colors)

    n_colors = len(possible_colors)
    color_mapping = {
        value: possible_colors[position % n_colors]
        for position, value in enumerate(unique)
    }

    return color_mapping, np.array([color_mapping[value] for value in c_values])
|
||
def _get_norm_and_cmap(self, c_values, color_by_categorical: bool): | ||
c = self.c | ||
if self.colormap is not None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -207,6 +207,21 @@ def test_scatter_with_c_column_name_with_colors(self, cmap): | |
ax = df.plot.scatter(x=0, y=1, c="species", cmap=cmap) | ||
assert ax.collections[0].colorbar is None | ||
|
||
def test_scatter_with_c_column_name_without_colors(self):
    # "state" holds strings that are not matplotlib color names, so the
    # points are expected to be colored per distinct string.
    df = DataFrame(
        {
            "dataX": range(100),
            "dataY": range(100),
            "state": ["NY", "MD", "MA", "CA"] * 25,
        }
    )
    # x/y selected by column label
    df.plot.scatter("dataX", "dataY", c="state")

    # x/y selected by position; must not emit any warning
    with tm.assert_produces_warning(None):
        ax = df.plot.scatter(x=0, y=1, c="state")

    # Compare whole color rows. Without ``axis=0`` np.unique flattens the
    # array and counts distinct channel values instead of distinct colors.
    face_colors = ax.collections[0].get_facecolor()
    assert len(np.unique(face_colors, axis=0)) == 4  # 4 states
|
||
def test_scatter_colors(self): | ||
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) | ||
with pytest.raises(TypeError, match="Specify exactly one of `c` and `color`"): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In what instances is
c_values
a list? I might be misreading, but it would be better if we only worked with a pd.Series and could call .unique on that, instead of checking every single value in a loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to take a
pd.Series
, not np.ndarray | list