Skip to content

Commit b4440c1

Browse files
Michael Vincent ManninoMichael Vincent Mannino
authored andcommitted
extract logic into different functions; add plot (WIP)
1 parent 8609ea5 commit b4440c1

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

pandas/plotting/_matplotlib/core.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,11 @@ def _make_plot(self, fig: Figure) -> None:
13371337
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical)
13381338
cb = self._get_colorbar(c_values, c_is_column)
13391339

1340+
orig_invalid_colors = not self._are_valid_colors(c_values)
1341+
if orig_invalid_colors:
1342+
unique_color_labels, c_values = self._convert_str_to_colors(c_values)
1343+
cb = False
1344+
13401345
if self.legend:
13411346
label = self.label
13421347
else:
@@ -1367,6 +1372,15 @@ def _make_plot(self, fig: Figure) -> None:
13671372
label, # type: ignore[arg-type]
13681373
)
13691374

1375+
if orig_invalid_colors:
1376+
for s in unique_color_labels:
1377+
self._append_legend_handles_labels(
1378+
# error: Argument 2 to "_append_legend_handles_labels" of
1379+
# "MPLPlot" has incompatible type "Hashable"; expected "str"
1380+
scatter,
1381+
s, # type: ignore[arg-type]
1382+
)
1383+
13701384
errors_x = self._get_errorbars(label=x, index=0, yerr=False)
13711385
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
13721386
if len(errors_x) > 0 or len(errors_y) > 0:
@@ -1388,37 +1402,31 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
13881402
c_values = self.data[c].values
13891403
else:
13901404
c_values = c
1405+
return c_values
13911406

1392-
return self._prevalidate_c_values(c_values)
1393-
1394-
def _prevalidate_c_values(self, c_values):
1395-
# if c_values contains strings, pre-check whether these are valid mpl colors
1396-
# should we determine c_values are valid to this point, no changes are made
1397-
# to the object
1398-
1407+
def _are_valid_colors(self, c_values):
13991408
# check if c_values contains strings. no need to check numerics as these
14001409
# will be validated for us in .Axes.scatter._parse_scatter_color_args(...)
14011410
if not (
14021411
np.iterable(c_values) and len(c_values) > 0 and isinstance(c_values[0], str)
14031412
):
1404-
return c_values
1413+
return True
14051414

14061415
try:
1407-
_ = mpl.colors.to_rgba_array(c_values)
1408-
14091416
# similar to above, if this conversion is successful, remaining validation
14101417
# will be done in .Axes.scatter._parse_scatter_color_args(...)
1411-
return c_values
1418+
_ = mpl.colors.to_rgba_array(c_values)
1419+
return True
14121420

14131421
except (TypeError, ValueError) as _:
1414-
# invalid color strings, build numerics based off this
1415-
# map N unique str to N evenly spaced values [0, 1], colors
1416-
# will be automattically assigned based off this mapping
1417-
unique = np.unique(c_values)
1418-
colors = np.linspace(0, 1, len(unique))
1419-
color_mapping = dict(zip(unique, colors))
1420-
1421-
return np.array(list(map(color_mapping.get, c_values)))
1422+
return False
1423+
1424+
def _convert_str_to_colors(self, c_values):
1425+
unique = np.unique(c_values)
1426+
colors = np.linspace(0, 1, len(unique))
1427+
color_mapping = dict(zip(unique, colors))
1428+
1429+
return unique, np.array(list(map(color_mapping.get, c_values)))
14221430

14231431
def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
14241432
c = self.c

0 commit comments

Comments
 (0)