Skip to content

Commit 767261c

Browse files
committed
feat(DRAFT): Simple cases working?
Borrowing some ideas from #2528, #2680
1 parent d9b918f commit 767261c

File tree

3 files changed

+181
-21
lines changed

3 files changed

+181
-21
lines changed

narwhals/_plan/arrow/dataframe.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Literal, overload
3+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
44

55
import pyarrow as pa # ignore-banned-import
66
import pyarrow.compute as pc # ignore-banned-import
@@ -106,6 +106,14 @@ def drop(self, columns: Sequence[str]) -> Self:
106106
to_drop = list(columns)
107107
return self._with_native(self.native.drop(to_drop))
108108

109+
def rename(self, mapping: Mapping[str, str]) -> Self:
    """Return a new frame whose columns are renamed according to `mapping`.

    Arguments:
        mapping: Old column name -> new column name.

    Returns:
        A new frame wrapping the renamed native table.
    """
    names: dict[str, str] | list[str]
    if fn.BACKEND_VERSION >= (17,):
        # NOTE: `rename_columns` is documented as taking `list[str] | dict[str, str]`.
        # Build a real `dict` so any `Mapping` works, rather than `cast`-ing,
        # which only changes what the type checker sees.
        names = dict(mapping)
    else:  # pragma: no cover
        # Older versions only get a full, positional list of names;
        # columns missing from `mapping` keep their current name.
        names = [mapping.get(c, c) for c in self.columns]
    return self._with_native(self.native.rename_columns(names))
116+
109117
# NOTE: Use instead of `with_columns` for trivial cases
110118
def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self:
111119
native = self.native

narwhals/_plan/arrow/group_by.py

Lines changed: 171 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,183 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING, Any, Literal
44

55
import pyarrow as pa # ignore-banned-import
6+
import pyarrow.compute as pc # ignore-banned-import
67

8+
from narwhals._plan import expressions as ir
9+
from narwhals._plan.expressions import aggregation as agg
710
from narwhals._plan.protocols import DataFrameGroupBy
11+
from narwhals._utils import Implementation, requires
812

913
if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping

    from typing_extensions import Self, TypeAlias

    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        AggregateOptions,
        Aggregation,
    )
    from narwhals._compliant.typing import NarwhalsAggregation as _NarwhalsAggregation
    from narwhals._plan.arrow.dataframe import ArrowDataFrame
    from narwhals._plan.expressions import NamedIR
    from narwhals._plan.typing import Seq

    # The existing compliant-level aggregation names, extended with "first" and "last".
    NarwhalsAggregation: TypeAlias = Literal[_NarwhalsAggregation, "first", "last"]
    # Name of the column an aggregation reads from.
    InputName: TypeAlias = str
    # Name `pyarrow` auto-generates for an aggregated column.
    NativeName: TypeAlias = str
    # Name the aggregated column should have in the final result.
    OutputName: TypeAlias = str
    # One `(input column, native aggregation, options)` entry.
    NativeAggSpec: TypeAlias = tuple[InputName, Aggregation, AggregateOptions | None]
    # One `(auto-generated name, requested name)` pair.
    RenameSpec: TypeAlias = tuple[NativeName, OutputName]
1833

19-
class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]):
20-
"""What narwhals is doing.
2134

22-
- Keys are handled only at compliant
23-
- `ParseKeysGroupBy` does weird stuff
24-
- But has a fast path for all `str` keys
25-
- Aggs are handled in both levels
26-
- Some compliant have more restrictions
27-
"""
35+
# Version of the installed `pyarrow`, used for version-gating in this module.
BACKEND_VERSION = Implementation.PYARROW._backend_version()


# Maps each supported aggregation node type to the native grouped aggregation name.
# TODO @dangotbanned: Missing `nw.col("a").len()`
SUPPORTED_AGG: Mapping[type[agg.AggExpr], Aggregation] = {
    agg.Sum: "sum",
    agg.Mean: "mean",
    agg.Median: "approximate_median",
    agg.Max: "max",
    agg.Min: "min",
    agg.Std: "stddev",
    agg.Var: "variance",
    agg.Count: "count",
    agg.NUnique: "count_distinct",
    agg.First: "first",
    agg.Last: "last",
}


# Native aggregation name for the `Len` node.
SUPPORTED_IR: Mapping[type[ir.Len], Aggregation] = {ir.Len: "count"}
# Native aggregation names for the boolean reductions handled as function expressions.
SUPPORTED_FUNCTION: Mapping[type[ir.boolean.BooleanFunction], Aggregation] = {
    ir.boolean.All: "all",
    ir.boolean.Any: "any",
}

REMAINING: tuple[Aggregation, ...] = (
    "count_all",  # Count the number of rows in each group
    "distinct",  # Keep the distinct values in each group
    "first_last",  # Compute the first and last of values in each group
    "list",  # List all values in each group
    "min_max",  # Compute the minimum and maximum of values in each group
    "one",  # Get one value from each group
    "product",  # Compute the product of values in each group
    "tdigest",  # Compute approximate quantiles of values in each group
)
"""Available [native aggs] we haven't used (excluding `first`, `last`)

[native aggs]: https://arrow.apache.org/docs/python/compute.html#grouped-aggregations
"""


REQUIRES_PYARROW_20: tuple[
    Literal["kurtosis"], Literal["pivot_wider"], Literal["skew"]
] = (
    "kurtosis",  # Compute the kurtosis of values in each group
    "pivot_wider",  # Pivot values according to a pivot key column
    "skew",  # Compute the skewness of values in each group
)
"""https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations"""
84+
85+
86+
def _ensure_single_thread(
87+
grouped: pa.TableGroupBy, expr: ir.OrderableAggExpr, /
88+
) -> pa.TableGroupBy:
89+
"""First/last require disabling threading."""
90+
if BACKEND_VERSION >= (14, 0) and grouped._use_threads:
91+
# NOTE: Stubs say `_table` is a method, but at runtime it is a property
92+
grouped = pa.TableGroupBy(grouped._table, grouped.keys, use_threads=False) # type: ignore[arg-type]
93+
elif BACKEND_VERSION < (14, 0): # pragma: no cover
94+
msg = (
95+
f"Using `{expr!r}` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', "
96+
f"found version {requires._unparse_version(BACKEND_VERSION)!r}.\n\n"
97+
f"See https://github.com/apache/arrow/issues/36709"
98+
)
99+
raise NotImplementedError(msg)
100+
return grouped
101+
28102

103+
def group_by_error(
    expr: ArrowAggExpr,
    reason: Literal[
        "too complex",
        "unsupported aggregation",
        "unsupported function",
        "unsupported expression",
    ],
) -> NotImplementedError:
    """Build (but do not raise) the error for an aggregation we cannot translate.

    Arguments:
        expr: The wrapper whose `named_ir` is shown in the message.
        reason: Why translation failed; selects the first line of the message.

    Returns:
        A `NotImplementedError` for the caller to `raise`.
    """
    if reason == "too complex":
        summary = "Non-trivial complex aggregation found"
    else:
        # e.g. "unsupported function" -> "Unsupported Function"
        summary = reason.title()
    msg = f"{summary} in 'pyarrow.Table':\n\n{expr.named_ir!r}"
    return NotImplementedError(msg)
118+
119+
120+
class ArrowAggExpr:
    """Translate one aggregation `NamedIR` into a native group-by spec.

    Only aggregations applied directly to a single column are accepted;
    everything else raises the `NotImplementedError` built by `group_by_error`.
    """

    def __init__(self, named_ir: NamedIR, /) -> None:
        self.named_ir: NamedIR = named_ir

    @property
    def output_name(self) -> OutputName:
        """Name the aggregated column should have in the final result."""
        return self.named_ir.name

    def _parse_agg_expr(
        self, expr: agg.AggExpr, grouped: pa.TableGroupBy
    ) -> tuple[InputName, Aggregation, AggregateOptions | None, pa.TableGroupBy]:
        """Resolve an `AggExpr` into `(input name, native agg, options, grouped)`.

        `grouped` is handed back because first/last may replace it with a
        single-threaded `pa.TableGroupBy`.

        Raises:
            NotImplementedError: If the aggregation type is not in
                `SUPPORTED_AGG`, or its input is not a plain column.
        """
        agg_name = SUPPORTED_AGG.get(type(expr))
        if agg_name is None:
            raise group_by_error(self, "unsupported aggregation")
        option: AggregateOptions | None = None
        if isinstance(expr, (agg.Std, agg.Var)):
            # NOTE: Only branch which needs an instance (for `ddof`)
            option = pc.VarianceOptions(ddof=expr.ddof)
        elif isinstance(expr, agg.NUnique):
            option = pc.CountOptions(mode="all")
        elif isinstance(expr, agg.Count):
            option = pc.CountOptions(mode="only_valid")
        elif isinstance(expr, (agg.First, agg.Last)):
            option = pc.ScalarAggregateOptions(skip_nulls=False)
            # NOTE: Only branch which needs access to `pa.TableGroupBy`
            grouped = _ensure_single_thread(grouped, expr)
        if not isinstance(expr.expr, ir.Column):
            raise group_by_error(self, "too complex")
        return expr.expr.name, agg_name, option, grouped

    def _parse_function_expr(self, expr: ir.FunctionExpr) -> NativeAggSpec:
        """Resolve an all/any `FunctionExpr` into `(input name, native agg, options)`.

        Raises:
            NotImplementedError: If the function is not all/any, or it is not
                applied to exactly one plain column.
        """
        if not isinstance(expr.function, (ir.boolean.All, ir.boolean.Any)):
            raise group_by_error(self, "unsupported function")
        agg_name = SUPPORTED_FUNCTION[type(expr.function)]
        option = pc.ScalarAggregateOptions(min_count=0)
        if len(expr.input) == 1 and isinstance(expr.input[0], ir.Column):
            return expr.input[0].name, agg_name, option
        raise group_by_error(self, "too complex")

    def _rename_spec(self, input_name: InputName, agg_name: Aggregation, /) -> RenameSpec:
        """Pair the auto-generated native column name with the requested one."""
        # `pyarrow` auto-generates the lhs
        # we want to overwrite that later with rhs
        return f"{input_name}_{agg_name}", self.output_name

    def to_native(
        self, grouped: pa.TableGroupBy
    ) -> tuple[pa.TableGroupBy, NativeAggSpec, RenameSpec]:
        """Translate the wrapped expression.

        Returns:
            The (possibly replaced) `grouped`, the spec to pass to the native
            `aggregate`, and the rename to apply to its output.

        Raises:
            NotImplementedError: For `Len` (not handled yet) and for any
                expression that cannot be translated.
        """
        expr = self.named_ir.expr
        if isinstance(expr, agg.AggExpr):
            input_name, agg_name, option, grouped = self._parse_agg_expr(expr, grouped)
        elif isinstance(expr, ir.Len):
            msg = "Need to investigate https://github.com/narwhals-dev/narwhals/blob/0fb045536f5b56b978f354f8178b292301e9598c/narwhals/_arrow/group_by.py#L132-L141"
            raise NotImplementedError(msg)
        elif isinstance(expr, ir.FunctionExpr):
            input_name, agg_name, option = self._parse_function_expr(expr)
        else:
            raise group_by_error(self, "unsupported expression")
        agg_spec = input_name, agg_name, option
        return grouped, agg_spec, self._rename_spec(input_name, agg_name)
178+
179+
180+
class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]):
29181
_df: ArrowDataFrame
30182
_grouped: pa.TableGroupBy
31183
_keys: Seq[NamedIR]
@@ -52,4 +204,11 @@ def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
52204
raise NotImplementedError
53205

54206
def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame:
    """Run every aggregation in `irs` in one native `aggregate` call.

    Arguments:
        irs: The aggregation expressions, each carrying its output name.

    Returns:
        A frame wrapping the aggregated table, with the auto-generated
        aggregation column names replaced by the requested output names.

    Raises:
        NotImplementedError: If any expression cannot be translated
            (see `ArrowAggExpr.to_native`).
    """
    gb = self._grouped
    aggs: list[NativeAggSpec] = []
    renames: list[RenameSpec] = []
    for e in irs:
        gb, agg_spec, rename = ArrowAggExpr(e).to_native(gb)
        aggs.append(agg_spec)
        renames.append(rename)
    result = gb.aggregate(aggs)
    # NOTE: Two aggregations can share an auto-generated name (same input and
    # same native agg under different aliases). `dict(renames)` would keep only
    # the last alias and apply it to both columns, so hand out the output names
    # in request order instead.
    # Relies on `aggregate` emitting aggregation columns in the order requested.
    # A key column sharing an auto-generated name is not disambiguated here.
    pending: dict[NativeName, list[OutputName]] = {}
    for native_name, output_name in renames:
        pending.setdefault(native_name, []).append(output_name)
    remaining = {name: iter(outputs) for name, outputs in pending.items()}
    names = [
        next(remaining[name], name) if name in remaining else name
        for name in result.column_names
    ]
    return self.compliant._with_native(result.rename_columns(names))

narwhals/_plan/group_by.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,7 @@ def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFra
6767
else: # noqa: RET506
6868
# If not, we can just use the resolved key names as a fast-path
6969
grouped = compliant_gb.by_names(compliant, resolved.keys_names)
70-
msg = fmt_group_by_error(
71-
"`GroupBy.agg` needs a `CompliantGroupBy.agg` to dispatch to",
72-
resolved.keys,
73-
resolved.aggs,
74-
resolved.result_schema,
75-
)
76-
raise NotImplementedError(msg)
77-
return grouped.agg(resolved.aggs)
70+
return self._frame._from_compliant(grouped.agg(resolved.aggs))
7871

7972

8073
class _TempGroupByStuff(NamedTuple):

0 commit comments

Comments
 (0)