21
21
22
22
import matplotlib as mpl
23
23
import numpy as np
24
- from seaborn ._base import (
25
- HueMapping ,
26
- VectorPlotter ,
27
- )
28
24
29
25
from pandas ._libs import lib
30
26
from pandas .errors import AbstractMethodError
@@ -1351,38 +1347,28 @@ def _make_plot(self, fig: Figure) -> None:
1351
1347
# if a list of non color strings is passed in as c, color points
1352
1348
# by uniqueness of the strings, such same strings get same color
1353
1349
create_colors = not self ._are_valid_colors (c_values )
1354
-
1355
- # Plot as normal
1356
- if not create_colors :
1357
- scatter = ax .scatter (
1358
- data [x ].values ,
1359
- data [y ].values ,
1360
- c = c_values ,
1361
- label = label ,
1362
- cmap = cmap ,
1363
- norm = norm ,
1364
- s = self .s ,
1365
- ** self .kwds ,
1366
- )
1367
- # Have to custom color
1368
- else :
1369
- scatter = ax .scatter (
1370
- data [x ].values ,
1371
- data [y ].values ,
1372
- label = label ,
1373
- cmap = cmap ,
1374
- norm = norm ,
1375
- s = self .s ,
1376
- ** self .kwds ,
1350
+ if create_colors :
1351
+ color_mapping = self ._get_color_mapping (c_values )
1352
+ c_values = [color_mapping [s ] for s in c_values ]
1353
+
1354
+ # build legend for labeling custom colors
1355
+ ax .legend (
1356
+ handles = [
1357
+ mpl .patches .Circle ((0 , 0 ), facecolor = c , label = s )
1358
+ for s , c in color_mapping .items ()
1359
+ ]
1377
1360
)
1378
1361
1379
- # set colors via Seaborn as it contains all the logic for handling color
1380
- # decision all nicely packaged
1381
- scatter .set_facecolor (
1382
- HueMapping (
1383
- VectorPlotter (data = data , variables = {"x" : x , "y" : y , "hue" : c })
1384
- )(c_values )
1385
- )
1362
+ scatter = ax .scatter (
1363
+ data [x ].values ,
1364
+ data [y ].values ,
1365
+ c = c_values ,
1366
+ label = label ,
1367
+ cmap = cmap ,
1368
+ norm = norm ,
1369
+ s = self .s ,
1370
+ ** self .kwds ,
1371
+ )
1386
1372
1387
1373
if cb :
1388
1374
cbar_label = c if c_is_column else ""
@@ -1423,7 +1409,7 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
1423
1409
c_values = c
1424
1410
return c_values
1425
1411
1426
- def _are_valid_colors (self , c_values : np . ndarray ):
1412
+ def _are_valid_colors (self , c_values : Series ):
1427
1413
# check if c_values contains strings and if these strings are valid mpl colors.
1428
1414
# no need to check numerics as these (and mpl colors) will be validated for us
1429
1415
# in .Axes.scatter._parse_scatter_color_args(...)
@@ -1437,6 +1423,16 @@ def _are_valid_colors(self, c_values: np.ndarray):
1437
1423
except (TypeError , ValueError ) as _ :
1438
1424
return False
1439
1425
1426
+ def _get_color_mapping (self , c_values : Series ) -> dict [str , str ]:
1427
+ unique = np .unique (c_values )
1428
+ n_colors = len (unique )
1429
+
1430
+ # passing `None` here will default to :rc:`image.cmap`
1431
+ cmap = mpl .colormaps .get_cmap (self .colormap )
1432
+ colors = cmap (np .linspace (0 , 1 , n_colors )) # RGB tuples
1433
+
1434
+ return dict (zip (unique , colors ))
1435
+
1440
1436
def _get_norm_and_cmap (self , c_values , color_by_categorical : bool ):
1441
1437
c = self .c
1442
1438
if self .colormap is not None :
0 commit comments