Skip to content

Commit 571c0c8

Browse files
Michael Vincent ManninoMichael Vincent Mannino
authored andcommitted
create labels for custom colors
1 parent b4440c1 commit 571c0c8

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

pandas/plotting/_matplotlib/core.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Iterator,
1111
Sequence,
1212
)
13+
from random import shuffle
1314
from typing import (
1415
TYPE_CHECKING,
1516
Any,
@@ -1337,10 +1338,12 @@ def _make_plot(self, fig: Figure) -> None:
13371338
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical)
13381339
cb = self._get_colorbar(c_values, c_is_column)
13391340

1340-
orig_invalid_colors = not self._are_valid_colors(c_values)
1341-
if orig_invalid_colors:
1342-
unique_color_labels, c_values = self._convert_str_to_colors(c_values)
1343-
cb = False
1341+
# if a list of non color strings is passed in as c, generate a list
1342+
# colored by uniqueness of the strings, such same strings get same color
1343+
create_colors = not self._are_valid_colors(c_values)
1344+
if create_colors:
1345+
color_mapping, c_values = self._uniquely_color_strs(c_values)
1346+
cb = False # no colorbar; opt for legend
13441347

13451348
if self.legend:
13461349
label = self.label
@@ -1372,14 +1375,14 @@ def _make_plot(self, fig: Figure) -> None:
13721375
label, # type: ignore[arg-type]
13731376
)
13741377

1375-
if orig_invalid_colors:
1376-
for s in unique_color_labels:
1377-
self._append_legend_handles_labels(
1378-
# error: Argument 2 to "_append_legend_handles_labels" of
1379-
# "MPLPlot" has incompatible type "Hashable"; expected "str"
1380-
scatter,
1381-
s, # type: ignore[arg-type]
1382-
)
1378+
# build legend for labeling custom colors
1379+
if create_colors:
1380+
ax.legend(
1381+
handles=[
1382+
mpl.patches.Circle((0, 0), facecolor=color, label=string)
1383+
for string, color in color_mapping.items()
1384+
]
1385+
)
13831386

13841387
errors_x = self._get_errorbars(label=x, index=0, yerr=False)
13851388
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
@@ -1404,29 +1407,31 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
14041407
c_values = c
14051408
return c_values
14061409

1407-
def _are_valid_colors(self, c_values):
1408-
# check if c_values contains strings. no need to check numerics as these
1409-
# will be validated for us in .Axes.scatter._parse_scatter_color_args(...)
1410-
if not (
1411-
np.iterable(c_values) and len(c_values) > 0 and isinstance(c_values[0], str)
1412-
):
1413-
return True
1414-
1410+
def _are_valid_colors(self, c_values: np.ndarray | list):
1411+
# check if c_values contains strings and if these strings are valid mpl colors
1412+
# no need to check numerics as these (and mpl colors) will be validated for us
1413+
# in .Axes.scatter._parse_scatter_color_args(...)
14151414
try:
1416-
# similar to above, if this conversion is successful, remaining validation
1417-
# will be done in .Axes.scatter._parse_scatter_color_args(...)
1418-
_ = mpl.colors.to_rgba_array(c_values)
1415+
if len(c_values) and all(isinstance(c, str) for c in c_values):
1416+
mpl.colors.to_rgba_array(c_values)
1417+
14191418
return True
14201419

14211420
except (TypeError, ValueError) as _:
14221421
return False
14231422

1424-
def _convert_str_to_colors(self, c_values):
1423+
def _uniquely_color_strs(
1424+
self, c_values: np.ndarray | list
1425+
) -> tuple[dict, np.ndarray]:
1426+
# well, almost uniquely color them (up to 949)
1427+
possible_colors = list(mpl.colors.XKCD_COLORS.values()) # Hex representations
1428+
shuffle(possible_colors) # TODO: find better way of getting colors
1429+
14251430
unique = np.unique(c_values)
1426-
colors = np.linspace(0, 1, len(unique))
1431+
colors = [possible_colors[i % len(possible_colors)] for i in range(len(unique))]
14271432
color_mapping = dict(zip(unique, colors))
14281433

1429-
return unique, np.array(list(map(color_mapping.get, c_values)))
1434+
return color_mapping, np.array(list(map(color_mapping.get, c_values)))
14301435

14311436
def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
14321437
c = self.c

0 commit comments

Comments
 (0)