Skip to content

Commit c093f81

Browse files
committed
Merge branch 'grid' into feature
2 parents 1328dd6 + 92d0f11 commit c093f81

File tree

27 files changed

+456
-179
lines changed

27 files changed

+456
-179
lines changed

src/xlviews/chart/axes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,14 @@ def copy(
135135
has_legend = self.chart.api[1].HasLegend
136136
include_in_layout = self.chart.api[1].Legend.IncludeInLayout
137137

138+
if left == 0:
139+
left = self.chart.left + self.chart.width
140+
top = self.chart.top
141+
142+
if top == 0:
143+
left = self.chart.left
144+
top = self.chart.top + self.chart.height
145+
138146
return self.__class__(
139147
row=row,
140148
column=column,

src/xlviews/core/formula.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,14 @@ def _aggregate(
7070
if func is None:
7171
return column
7272

73-
if isinstance(func, str) and func in AGG_FUNCS:
74-
return f"AGGREGATE({AGG_FUNCS[func]},{option},{column})"
73+
if isinstance(func, str):
74+
if func in AGG_FUNCS:
75+
return f"AGGREGATE({AGG_FUNCS[func]},{option},{column})"
7576

76-
if isinstance(func, Range | RangeImpl):
77-
ref = func.get_address(column_absolute=False, row_absolute=False)
78-
else:
79-
ref = func
77+
msg = f"Invalid aggregate function: {func}"
78+
raise ValueError(msg)
8079

80+
ref = func.get_address(column_absolute=False, row_absolute=False)
8181
func = f"LOOKUP({ref},{{{AGG_FUNC_NAMES}}},{{{AGG_FUNC_INTS}}})"
8282
soa = aggregate("soa", ranges, option=option, **kwargs)
8383
return f'IF({ref}="soa",{soa},AGGREGATE({func},{option},{column}))'

src/xlviews/dataframes/heat_frame.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,9 @@ def xs(
198198
columns: dict[Hashable, Any] | None,
199199
) -> DataFrame:
200200
if index:
201-
for key, value in index.items():
202-
df = df.xs(value, level=key, axis=0) # type: ignore
201+
df = df.xs(tuple(index.values()), 0, tuple(index.keys())) # type: ignore
203202

204203
if columns:
205-
for key, value in columns.items():
206-
df = df.xs(value, level=key, axis=1) # type: ignore
204+
df = df.xs(tuple(columns.values()), 1, tuple(columns.keys())) # type: ignore
207205

208206
return df

src/xlviews/dataframes/style.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,8 @@ def _set_heat_border(sf: HeatFrame) -> None:
264264
ec = rcParams["heat.border.color"]
265265

266266
for row in iter_group_locs(sf.index, offset=r):
267-
if row[0] == row[1]:
268-
continue
269-
270267
for col in iter_group_locs(sf.columns, offset=c):
271-
if col[0] == col[1]:
268+
if row[0] == row[1] and col[0] == col[1]:
272269
continue
273270

274271
rng = sf.sheet.range((row[0], col[0]), (row[1], col[1]))

src/xlviews/figure/groupby.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

src/xlviews/figure/palette.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from itertools import cycle, islice
66
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar
77

8+
from pandas import MultiIndex
9+
810
from xlviews.chart.style import COLORS, MARKER_DICT
911

1012
if TYPE_CHECKING:
@@ -164,6 +166,9 @@ def get_palette(
164166
if style is None:
165167
return None
166168

169+
if data.index.name is not None or isinstance(data.index, MultiIndex):
170+
data = data.index.to_frame(index=False)
171+
167172
if isinstance(style, dict):
168173
return cls(data, data.columns.to_list(), style)
169174

@@ -182,3 +187,17 @@ def get_palette(
182187
return cls(data, data.columns.tolist(), default) # type: ignore
183188

184189
return cls(data, columns)
190+
191+
192+
def get_marker_palette(
193+
data: DataFrame,
194+
marker: PaletteStyle | None,
195+
) -> MarkerPalette | None:
196+
return get_palette(MarkerPalette, data, marker) # type: ignore
197+
198+
199+
def get_color_palette(
200+
data: DataFrame,
201+
color: PaletteStyle | None,
202+
) -> ColorPalette | None:
203+
return get_palette(ColorPalette, data, color) # type: ignore

src/xlviews/figure/plot.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from typing import TYPE_CHECKING, TypeAlias
66

77
import pandas as pd
8-
from pandas import DataFrame
8+
from pandas import DataFrame, Index
99

10-
from .palette import ColorPalette, MarkerPalette, PaletteStyle, get_palette
10+
from .palette import PaletteStyle, get_color_palette, get_marker_palette
1111

1212
if TYPE_CHECKING:
13-
from collections.abc import Iterator
14-
from typing import Self
13+
from collections.abc import Iterator, Sequence
14+
from typing import Any, Self
1515

1616
from xlviews.chart.axes import Axes
1717
from xlviews.chart.series import Series
@@ -70,8 +70,8 @@ def set(
7070
size: int | None = None,
7171
) -> Self:
7272
index = self.data.index.to_frame(index=False)
73-
marker_palette = get_palette(MarkerPalette, index, marker)
74-
color_palette = get_palette(ColorPalette, index, color)
73+
marker_palette = get_marker_palette(index, marker)
74+
color_palette = get_color_palette(index, color)
7575

7676
for key, s in zip(self.keys(), self.series_collection, strict=True):
7777
s.set(
@@ -85,6 +85,34 @@ def set(
8585

8686
return self
8787

88+
@classmethod
89+
def facet(
90+
cls,
91+
axes: Axes,
92+
data: DataFrame,
93+
index: str | list[str] | None = None,
94+
columns: str | list[str] | None = None,
95+
) -> Iterator[tuple[dict[str, Any], Self]]:
96+
left = axes.chart.left
97+
top = axes.chart.top
98+
width = axes.chart.width
99+
height = axes.chart.height
100+
101+
for r, rkey in enumerate(iterrows(data.index, index)):
102+
for c, ckey in enumerate(iterrows(data.index, columns)):
103+
key = rkey | ckey
104+
sub = xs(data, key)
105+
106+
if len(sub) == 0:
107+
continue
108+
109+
if r == 0 and c == 0:
110+
axes_ = axes
111+
else:
112+
axes_ = axes.copy(left=left + c * width, top=top + r * height)
113+
114+
yield key, cls(axes_, sub)
115+
88116

89117
def get_label(label: Label, key: dict[str, Hashable]) -> str:
90118
if isinstance(label, str):
@@ -95,3 +123,29 @@ def get_label(label: Label, key: dict[str, Hashable]) -> str:
95123

96124
msg = f"Invalid label: {label}"
97125
raise ValueError(msg)
126+
127+
128+
def iterrows(
129+
index: Index,
130+
levels: int | str | Sequence[int | str] | None,
131+
) -> Iterator[dict[str, Any]]:
132+
if levels is None:
133+
yield {}
134+
return
135+
136+
if isinstance(levels, int | str):
137+
levels = [levels]
138+
139+
if levels:
140+
values = {level: index.get_level_values(level) for level in levels}
141+
it = DataFrame(values).drop_duplicates().iterrows()
142+
143+
for _, s in it:
144+
yield s.to_dict()
145+
146+
147+
def xs(df: DataFrame, index: dict[str, Any] | None) -> DataFrame:
148+
if index:
149+
df = df.xs(tuple(index.values()), 0, tuple(index.keys()), drop_level=False) # type: ignore
150+
151+
return df
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from xlviews.chart.axes import Axes
6+
from xlviews.figure.palette import get_color_palette, get_marker_palette
7+
from xlviews.figure.plot import Plot
8+
from xlviews.testing.common import create_sheet
9+
from xlviews.testing.sheet_frame.pivot import Pivot
10+
11+
if TYPE_CHECKING:
12+
from collections.abc import Iterator
13+
from typing import Any
14+
15+
from xlviews.dataframes.sheet_frame import SheetFrame
16+
17+
18+
def facet(sf: SheetFrame) -> Iterator[tuple[dict[str, Any], Plot]]:
19+
sf.set_adjacent_column_width(1)
20+
21+
ax = Axes(2, 11)
22+
data = sf.groupby(["A", "B", "X"]).agg(include_sheetname=True)
23+
cp = get_color_palette(data, "X")
24+
mp = get_marker_palette(data, "A")
25+
26+
for key, plot in Plot.facet(ax, data, index="B", columns="A"):
27+
plot.add("u", "v").set(color=cp, marker=mp, label="{X}", alpha=0.8)
28+
plot.axes.title = "{A}_{B}".format(**key)
29+
yield key, plot
30+
31+
32+
if __name__ == "__main__":
33+
sheet = create_sheet()
34+
fc = Pivot(sheet, style=True)
35+
list(facet(fc.sf))

src/xlviews/testing/heat_frame/pair.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@
99

1010
if TYPE_CHECKING:
1111
from collections.abc import Hashable, Iterator
12-
from typing import Any
12+
from typing import Any, Literal
1313

1414
from xlviews.dataframes.sheet_frame import SheetFrame
1515

1616

17-
def pair(sf: SheetFrame) -> Iterator[tuple[dict[Hashable, Any], HeatFrame]]:
17+
def pair(
18+
sf: SheetFrame,
19+
values: str | list[str] | None = None,
20+
columns: str | list[str] | None = None,
21+
axis: Literal[0, 1] | None = None,
22+
) -> Iterator[tuple[dict[Hashable, Any], HeatFrame]]:
1823
sf.set_adjacent_column_width(1)
1924

2025
rng = sf.get_range("u")
@@ -27,7 +32,15 @@ def pair(sf: SheetFrame) -> Iterator[tuple[dict[Hashable, Any], HeatFrame]]:
2732

2833
df = sf.pivot_table(["u", "v"], ["B", "Y", "y"], ["A", "X", "x"], formula=True)
2934

30-
for key, frame in HeatFrame.pair(2, 13, df, index="B", columns="A", axis=1):
35+
for key, frame in HeatFrame.pair(
36+
2,
37+
13,
38+
df,
39+
values=values,
40+
index="B",
41+
columns=columns,
42+
axis=axis,
43+
):
3144
frame.autofit()
3245
frame.set_adjacent_column_width(1)
3346
if key["value"] == "u":
@@ -42,4 +55,6 @@ def pair(sf: SheetFrame) -> Iterator[tuple[dict[Hashable, Any], HeatFrame]]:
4255
if __name__ == "__main__":
4356
sheet = create_sheet()
4457
fc = Pivot(sheet, style=True)
45-
list(pair(fc.sf))
58+
# list(pair(fc.sf, values="u", columns=None))
59+
# list(pair(fc.sf, values=None, columns="A", axis=1))
60+
list(pair(fc.sf, values="v", columns="A"))

tests/chart/axes/test_position.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,15 @@ def test_chart_position_top_left(sheet: Sheet):
6767
c = Axes(sheet=sheet, top=0, left=200)
6868
assert c.chart.left == 200
6969
assert c.chart.top == TOP
70+
71+
72+
def test_copy_right(sheet: Sheet):
73+
a = Axes(sheet=sheet)
74+
b = a.copy(left=0)
75+
assert b.chart.left == a.chart.left + a.chart.width
76+
77+
78+
def test_copy_bottom(sheet: Sheet):
79+
a = Axes(sheet=sheet)
80+
b = a.copy(top=0)
81+
assert b.chart.top == a.chart.top + a.chart.height

0 commit comments

Comments
 (0)