11from __future__ import annotations
22
33from abc import ABC , abstractmethod
4- from collections .abc import Hashable
4+ from collections .abc import Callable , Hashable
55from itertools import cycle , islice
66from typing import TYPE_CHECKING , Generic , TypeAlias , TypeVar
77
@@ -102,12 +102,11 @@ def get(self, value: Hashable) -> int:
102102
103103 return self .index [value ]
104104
105- def __getitem__ (self , value : Hashable | dict ) -> T :
106- if value == {None : 0 }: # from series
105+ def __getitem__ (self , key : dict ) -> T :
106+ if key == {None : 0 }: # from series
107107 return self .items [0 ]
108108
109- if isinstance (value , dict ):
110- value = tuple (value [k ] for k in self .columns )
109+ value = tuple (key [k ] for k in self .columns )
111110
112111 return self .items [self .get (value )]
113112
@@ -145,35 +144,63 @@ def cycle_colors(skips: Iterable[str] | None = None) -> Iterator[str]:
145144 yield color
146145
147146
147+ class FunctionPalette (Generic [T ]):
148+ columns : str | list [str ]
149+ func : Callable [[Hashable ], T ]
150+
151+ def __init__ (self , columns : str | list [str ], func : Callable [[Hashable ], T ]) -> None :
152+ self .columns = columns
153+ self .func = func
154+
155+ def __getitem__ (self , key : dict ) -> T :
156+ if isinstance (self .columns , str ):
157+ return self .func (key [self .columns ])
158+
159+ value = tuple (key [k ] for k in self .columns )
160+ return self .func (value )
161+
162+
148163PaletteStyle : TypeAlias = (
149164 str
150165 | list [str ]
151166 | dict [Hashable , str ]
152- | tuple [str | list [str ], dict [Hashable , str ] | list [str ]]
167+ | Callable [[Hashable ], str ]
168+ | tuple [str | list [str ], list [str ] | dict [Hashable , str ]]
169+ | tuple [str | list [str ], Callable [[Hashable ], str ]]
153170 | Palette
171+ | FunctionPalette
154172)
155173
156174
157175def get_palette (
158176 cls : type [Palette ],
159177 data : DataFrame ,
160178 style : PaletteStyle | None ,
161- ) -> Palette | None :
179+ ) -> Palette | FunctionPalette | None :
162180 """Get a palette from a style."""
163- if isinstance (style , Palette ):
181+ if isinstance (style , Palette | FunctionPalette ):
164182 return style
165183
166184 if style is None :
167185 return None
168186
187+ if isinstance (style , Callable ):
188+ if isinstance (data .index , MultiIndex ):
189+ return FunctionPalette (data .index .names , style )
190+ return FunctionPalette (data .index .name , style )
191+
169192 if data .index .name is not None or isinstance (data .index , MultiIndex ):
170193 data = data .index .to_frame (index = False )
171194
172195 if isinstance (style , dict ):
173196 return cls (data , data .columns .to_list (), style )
174197
175198 if isinstance (style , tuple ):
176- return cls (data , * style )
199+ columns , default = style
200+ if callable (default ):
201+ return FunctionPalette (columns , default )
202+
203+ return cls (data , columns , default )
177204
178205 columns = style
179206
@@ -192,12 +219,12 @@ def get_palette(
192219def get_marker_palette (
193220 data : DataFrame ,
194221 marker : PaletteStyle | None ,
195- ) -> MarkerPalette | None :
222+ ) -> MarkerPalette | FunctionPalette | None :
196223 return get_palette (MarkerPalette , data , marker ) # type: ignore
197224
198225
199226def get_color_palette (
200227 data : DataFrame ,
201228 color : PaletteStyle | None ,
202- ) -> ColorPalette | None :
229+ ) -> ColorPalette | FunctionPalette | None :
203230 return get_palette (ColorPalette , data , color ) # type: ignore
0 commit comments