10
10
Iterator ,
11
11
Sequence ,
12
12
)
13
- from random import shuffle
14
13
from typing import (
15
14
TYPE_CHECKING ,
16
15
Any ,
22
21
23
22
import matplotlib as mpl
24
23
import numpy as np
24
+ from seaborn ._base import (
25
+ HueMapping ,
26
+ VectorPlotter ,
27
+ )
25
28
26
29
from pandas ._libs import lib
27
30
from pandas .errors import AbstractMethodError
@@ -1340,27 +1343,47 @@ def _make_plot(self, fig: Figure) -> None:
1340
1343
norm , cmap = self ._get_norm_and_cmap (c_values , color_by_categorical )
1341
1344
cb = self ._get_colorbar (c_values , c_is_column )
1342
1345
1343
- # if a list of non color strings is passed in as c, generate a list
1344
- # colored by uniqueness of the strings, such same strings get same color
1345
- create_colors = not self ._are_valid_colors (c_values )
1346
- if create_colors :
1347
- custom_color_mapping , c_values = self ._uniquely_color_strs (c_values )
1348
- cb = False # no colorbar; opt for legend
1349
-
1350
1346
if self .legend :
1351
1347
label = self .label
1352
1348
else :
1353
1349
label = None
1354
- scatter = ax .scatter (
1355
- data [x ].values ,
1356
- data [y ].values ,
1357
- c = c_values ,
1358
- label = label ,
1359
- cmap = cmap ,
1360
- norm = norm ,
1361
- s = self .s ,
1362
- ** self .kwds ,
1363
- )
1350
+
1351
+ # if a list of non color strings is passed in as c, color points
1352
+ # by uniqueness of the strings, such same strings get same color
1353
+ create_colors = not self ._are_valid_colors (c_values )
1354
+
1355
+ # Plot as normal
1356
+ if not create_colors :
1357
+ scatter = ax .scatter (
1358
+ data [x ].values ,
1359
+ data [y ].values ,
1360
+ c = c_values ,
1361
+ label = label ,
1362
+ cmap = cmap ,
1363
+ norm = norm ,
1364
+ s = self .s ,
1365
+ ** self .kwds ,
1366
+ )
1367
+ # Have to custom color
1368
+ else :
1369
+ scatter = ax .scatter (
1370
+ data [x ].values ,
1371
+ data [y ].values ,
1372
+ label = label ,
1373
+ cmap = cmap ,
1374
+ norm = norm ,
1375
+ s = self .s ,
1376
+ ** self .kwds ,
1377
+ )
1378
+
1379
+ # set colors via Seaborn as it contains all the logic for handling color
1380
+ # decision all nicely packaged
1381
+ scatter .set_facecolor (
1382
+ HueMapping (
1383
+ VectorPlotter (data = data , variables = {"x" : x , "y" : y , "hue" : c })
1384
+ )(c_values )
1385
+ )
1386
+
1364
1387
if cb :
1365
1388
cbar_label = c if c_is_column else ""
1366
1389
cbar = self ._plot_colorbar (ax , fig = fig , label = cbar_label )
@@ -1377,15 +1400,6 @@ def _make_plot(self, fig: Figure) -> None:
1377
1400
label , # type: ignore[arg-type]
1378
1401
)
1379
1402
1380
- # build legend for labeling custom colors
1381
- if create_colors :
1382
- ax .legend (
1383
- handles = [
1384
- mpl .patches .Circle ((0 , 0 ), facecolor = color , label = string )
1385
- for string , color in custom_color_mapping .items ()
1386
- ]
1387
- )
1388
-
1389
1403
errors_x = self ._get_errorbars (label = x , index = 0 , yerr = False )
1390
1404
errors_y = self ._get_errorbars (label = y , index = 0 , xerr = False )
1391
1405
if len (errors_x ) > 0 or len (errors_y ) > 0 :
@@ -1409,38 +1423,20 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
1409
1423
c_values = c
1410
1424
return c_values
1411
1425
1412
- def _are_valid_colors (self , c_values : np .ndarray | list ):
1426
+ def _are_valid_colors (self , c_values : np .ndarray ):
1413
1427
# check if c_values contains strings and if these strings are valid mpl colors.
1414
1428
# no need to check numerics as these (and mpl colors) will be validated for us
1415
1429
# in .Axes.scatter._parse_scatter_color_args(...)
1430
+ unique = np .unique (c_values )
1416
1431
try :
1417
- if len (c_values ) and all (isinstance (c , str ) for c in c_values ):
1418
- mpl .colors .to_rgba_array (c_values )
1432
+ if len (c_values ) and all (isinstance (c , str ) for c in unique ):
1433
+ mpl .colors .to_rgba_array (unique )
1419
1434
1420
1435
return True
1421
1436
1422
1437
except (TypeError , ValueError ) as _ :
1423
1438
return False
1424
1439
1425
- def _uniquely_color_strs (
1426
- self , c_values : np .ndarray | list
1427
- ) -> tuple [dict , np .ndarray ]:
1428
- # well, almost uniquely color them (up to 949)
1429
- unique = np .unique (c_values )
1430
-
1431
- # for up to 7, lets keep colors consistent
1432
- if len (unique ) <= 7 :
1433
- possible_colors = list (mpl .colors .BASE_COLORS .values ()) # Hex
1434
- # explore better ways to handle this case
1435
- else :
1436
- possible_colors = list (mpl .colors .XKCD_COLORS .values ()) # Hex
1437
- shuffle (possible_colors )
1438
-
1439
- colors = [possible_colors [i % len (possible_colors )] for i in range (len (unique ))]
1440
- color_mapping = dict (zip (unique , colors ))
1441
-
1442
- return color_mapping , np .array (list (map (color_mapping .get , c_values )))
1443
-
1444
1440
def _get_norm_and_cmap (self , c_values , color_by_categorical : bool ):
1445
1441
c = self .c
1446
1442
if self .colormap is not None :
0 commit comments