Skip to content

Commit 0c4471d

Browse files
authored
Merge pull request #51 from daizutabi/50-color-and-marker-accept-callable
Enhance Plot and Palette modules with flexible marker and color generation
2 parents 03e6f90 + d517c28 commit 0c4471d

File tree

7 files changed

+193
-47
lines changed

7 files changed

+193
-47
lines changed

src/xlviews/figure/palette.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from collections.abc import Hashable
4+
from collections.abc import Callable, Hashable
55
from itertools import cycle, islice
66
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar
77

@@ -102,12 +102,11 @@ def get(self, value: Hashable) -> int:
102102

103103
return self.index[value]
104104

105-
def __getitem__(self, value: Hashable | dict) -> T:
106-
if value == {None: 0}: # from series
105+
def __getitem__(self, key: dict) -> T:
106+
if key == {None: 0}: # from series
107107
return self.items[0]
108108

109-
if isinstance(value, dict):
110-
value = tuple(value[k] for k in self.columns)
109+
value = tuple(key[k] for k in self.columns)
111110

112111
return self.items[self.get(value)]
113112

@@ -145,35 +144,63 @@ def cycle_colors(skips: Iterable[str] | None = None) -> Iterator[str]:
145144
yield color
146145

147146

147+
class FunctionPalette(Generic[T]):
148+
columns: str | list[str]
149+
func: Callable[[Hashable], T]
150+
151+
def __init__(self, columns: str | list[str], func: Callable[[Hashable], T]) -> None:
152+
self.columns = columns
153+
self.func = func
154+
155+
def __getitem__(self, key: dict) -> T:
156+
if isinstance(self.columns, str):
157+
return self.func(key[self.columns])
158+
159+
value = tuple(key[k] for k in self.columns)
160+
return self.func(value)
161+
162+
148163
PaletteStyle: TypeAlias = (
149164
str
150165
| list[str]
151166
| dict[Hashable, str]
152-
| tuple[str | list[str], dict[Hashable, str] | list[str]]
167+
| Callable[[Hashable], str]
168+
| tuple[str | list[str], list[str] | dict[Hashable, str]]
169+
| tuple[str | list[str], Callable[[Hashable], str]]
153170
| Palette
171+
| FunctionPalette
154172
)
155173

156174

157175
def get_palette(
158176
cls: type[Palette],
159177
data: DataFrame,
160178
style: PaletteStyle | None,
161-
) -> Palette | None:
179+
) -> Palette | FunctionPalette | None:
162180
"""Get a palette from a style."""
163-
if isinstance(style, Palette):
181+
if isinstance(style, Palette | FunctionPalette):
164182
return style
165183

166184
if style is None:
167185
return None
168186

187+
if isinstance(style, Callable):
188+
if isinstance(data.index, MultiIndex):
189+
return FunctionPalette(data.index.names, style)
190+
return FunctionPalette(data.index.name, style)
191+
169192
if data.index.name is not None or isinstance(data.index, MultiIndex):
170193
data = data.index.to_frame(index=False)
171194

172195
if isinstance(style, dict):
173196
return cls(data, data.columns.to_list(), style)
174197

175198
if isinstance(style, tuple):
176-
return cls(data, *style)
199+
columns, default = style
200+
if callable(default):
201+
return FunctionPalette(columns, default)
202+
203+
return cls(data, columns, default)
177204

178205
columns = style
179206

@@ -192,12 +219,12 @@ def get_palette(
192219
def get_marker_palette(
193220
data: DataFrame,
194221
marker: PaletteStyle | None,
195-
) -> MarkerPalette | None:
222+
) -> MarkerPalette | FunctionPalette | None:
196223
return get_palette(MarkerPalette, data, marker) # type: ignore
197224

198225

199226
def get_color_palette(
200227
data: DataFrame,
201228
color: PaletteStyle | None,
202-
) -> ColorPalette | None:
229+
) -> ColorPalette | FunctionPalette | None:
203230
return get_palette(ColorPalette, data, color) # type: ignore

src/xlviews/figure/plot.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,8 @@ def set(
6969
weight: float | None = None,
7070
size: int | None = None,
7171
) -> Self:
72-
index = self.data.index.to_frame(index=False)
73-
marker_palette = get_marker_palette(index, marker)
74-
color_palette = get_color_palette(index, color)
72+
marker_palette = get_marker_palette(self.data, marker)
73+
color_palette = get_color_palette(self.data, color)
7574

7675
for key, s in zip(self.keys(), self.series_collection, strict=True):
7776
s.set(

src/xlviews/testing/figure/plot.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
from xlwings.constants import ChartType
46

57
from xlviews.chart.axes import Axes
68
from xlviews.figure.plot import Plot
79
from xlviews.testing.chart import Base
810
from xlviews.testing.common import create_sheet
911

12+
if TYPE_CHECKING:
13+
from collections.abc import Hashable
14+
1015
if __name__ == "__main__":
1116
sheet = create_sheet()
1217
fc = Base(sheet, style=True)
@@ -21,23 +26,49 @@
2126
.set(label="abc", marker="o", color="blue", alpha=0.6)
2227
)
2328

24-
ax = Axes(chart_type=ChartType.xlXYScatterLines)
29+
ax = Axes()
2530
data = sf.groupby("b").agg(include_sheetname=True)
2631
p = (
2732
Plot(ax, data)
28-
.add("x", "y")
33+
.add("x", "y", ChartType.xlXYScatterLines)
2934
.set(label="b={b}", marker=["o", "s"], color={"s": "red", "t": "blue"})
3035
)
3136

32-
ax = Axes(chart_type=ChartType.xlXYScatterLines)
37+
ax = Axes()
3338
data = sf.groupby(["b", "c"]).agg(include_sheetname=True)
3439
p = (
3540
Plot(ax, data)
36-
.add("x", "y")
41+
.add("x", "y", ChartType.xlXYScatterLines)
3742
.set(
3843
label=lambda x: f"{x['b']},{x['c']}",
3944
marker="b",
4045
color=("c", ["red", "green"]),
4146
size=10,
4247
)
4348
)
49+
50+
def m(x: Hashable) -> str:
51+
if x == "s":
52+
return "o"
53+
return "^"
54+
55+
ax = Axes(left=0)
56+
data = sf.groupby("b").agg(include_sheetname=True)
57+
p = (
58+
Plot(ax, data)
59+
.add("x", "y", ChartType.xlXYScatter)
60+
.set(label="{b}", marker=m, size=10)
61+
)
62+
63+
def c(x: Hashable) -> str:
64+
if x == ("s", 100):
65+
return "red"
66+
return "blue"
67+
68+
ax = Axes()
69+
data = sf.groupby(["b", "c"]).agg(include_sheetname=True)
70+
p = (
71+
Plot(ax, data)
72+
.add("x", "y", ChartType.xlXYScatter)
73+
.set(label="{b}_{c}", color=c, marker=("b", m), size=10)
74+
)

tests/dataframes/dist_frame/test_sigma.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from xlwings import Sheet
55

66
from xlviews.core.range import Range
7+
from xlviews.testing import is_app_available
8+
9+
pytestmark = pytest.mark.skipif(not is_app_available(), reason="Excel not installed")
710

811

912
@pytest.fixture(params=["norm", "weibull"])

tests/figure/test_grid.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from xlviews.chart.axes import Axes
55
from xlviews.figure.grid import Grid, Series
6+
from xlviews.testing import is_app_available
7+
8+
pytestmark = pytest.mark.skipif(not is_app_available(), reason="Excel not installed")
69

710

811
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)