55from typing import TYPE_CHECKING , TypeAlias
66
77import pandas as pd
8- from pandas import DataFrame
8+ from pandas import DataFrame , Index
99
10- from .palette import ColorPalette , MarkerPalette , PaletteStyle , get_palette
10+ from .palette import PaletteStyle , get_color_palette , get_marker_palette
1111
1212if TYPE_CHECKING :
13- from collections .abc import Iterator
14- from typing import Self
13+ from collections .abc import Iterator , Sequence
14+ from typing import Any , Self
1515
1616 from xlviews .chart .axes import Axes
1717 from xlviews .chart .series import Series
@@ -70,8 +70,8 @@ def set(
7070 size : int | None = None ,
7171 ) -> Self :
7272 index = self .data .index .to_frame (index = False )
73- marker_palette = get_palette ( MarkerPalette , index , marker )
74- color_palette = get_palette ( ColorPalette , index , color )
73+ marker_palette = get_marker_palette ( index , marker )
74+ color_palette = get_color_palette ( index , color )
7575
7676 for key , s in zip (self .keys (), self .series_collection , strict = True ):
7777 s .set (
@@ -85,6 +85,34 @@ def set(
8585
8686 return self
8787
88+ @classmethod
89+ def facet (
90+ cls ,
91+ axes : Axes ,
92+ data : DataFrame ,
93+ index : str | list [str ] | None = None ,
94+ columns : str | list [str ] | None = None ,
95+ ) -> Iterator [tuple [dict [str , Any ], Self ]]:
96+ left = axes .chart .left
97+ top = axes .chart .top
98+ width = axes .chart .width
99+ height = axes .chart .height
100+
101+ for r , rkey in enumerate (iterrows (data .index , index )):
102+ for c , ckey in enumerate (iterrows (data .index , columns )):
103+ key = rkey | ckey
104+ sub = xs (data , key )
105+
106+ if len (sub ) == 0 :
107+ continue
108+
109+ if r == 0 and c == 0 :
110+ axes_ = axes
111+ else :
112+ axes_ = axes .copy (left = left + c * width , top = top + r * height )
113+
114+ yield key , cls (axes_ , sub )
115+
88116
89117def get_label (label : Label , key : dict [str , Hashable ]) -> str :
90118 if isinstance (label , str ):
@@ -95,3 +123,29 @@ def get_label(label: Label, key: dict[str, Hashable]) -> str:
95123
96124 msg = f"Invalid label: { label } "
97125 raise ValueError (msg )
126+
127+
128+ def iterrows (
129+ index : Index ,
130+ levels : int | str | Sequence [int | str ] | None ,
131+ ) -> Iterator [dict [str , Any ]]:
132+ if levels is None :
133+ yield {}
134+ return
135+
136+ if isinstance (levels , int | str ):
137+ levels = [levels ]
138+
139+ if levels :
140+ values = {level : index .get_level_values (level ) for level in levels }
141+ it = DataFrame (values ).drop_duplicates ().iterrows ()
142+
143+ for _ , s in it :
144+ yield s .to_dict ()
145+
146+
147+ def xs (df : DataFrame , index : dict [str , Any ] | None ) -> DataFrame :
148+ if index :
149+ df = df .xs (tuple (index .values ()), 0 , tuple (index .keys ()), drop_level = False ) # type: ignore
150+
151+ return df
0 commit comments