Skip to content

Commit 5aaceb7

Browse files
committed
refactor: Implement ArrowGroupBy
1 parent 6f68f3a commit 5aaceb7

File tree

2 files changed

+40
-52
lines changed

2 files changed

+40
-52
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
399399
def group_by(self: Self, *keys: str, drop_null_keys: bool) -> ArrowGroupBy:
400400
from narwhals._arrow.group_by import ArrowGroupBy
401401

402-
return ArrowGroupBy(self, list(keys), drop_null_keys=drop_null_keys)
402+
return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)
403403

404404
def join(
405405
self: Self,

narwhals/_arrow/group_by.py

Lines changed: 39 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@
44
import re
55
from typing import TYPE_CHECKING
66
from typing import Any
7+
from typing import ClassVar
78
from typing import Iterator
9+
from typing import Mapping
10+
from typing import Sequence
811

912
import pyarrow as pa
1013
import pyarrow.compute as pc
1114

15+
from narwhals._arrow.dataframe import ArrowDataFrame
1216
from narwhals._arrow.utils import cast_to_comparable_string_types
1317
from narwhals._arrow.utils import extract_py_scalar
18+
from narwhals._compliant import EagerGroupBy
1419
from narwhals._expression_parsing import evaluate_output_names_and_aliases
15-
from narwhals._expression_parsing import is_elementary_expression
1620
from narwhals.utils import generate_temporary_column_name
1721

1822
if TYPE_CHECKING:
@@ -22,62 +26,44 @@
2226
from narwhals._arrow.expr import ArrowExpr
2327
from narwhals._arrow.typing import Incomplete
2428

25-
POLARS_TO_ARROW_AGGREGATIONS = {
26-
"sum": "sum",
27-
"mean": "mean",
28-
"median": "approximate_median",
29-
"max": "max",
30-
"min": "min",
31-
"std": "stddev",
32-
"var": "variance",
33-
"len": "count",
34-
"n_unique": "count_distinct",
35-
"count": "count",
36-
}
37-
38-
39-
class ArrowGroupBy:
29+
30+
class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]):
31+
_NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] = {
32+
"sum": "sum",
33+
"mean": "mean",
34+
"median": "approximate_median",
35+
"max": "max",
36+
"min": "min",
37+
"std": "stddev",
38+
"var": "variance",
39+
"len": "count",
40+
"n_unique": "count_distinct",
41+
"count": "count",
42+
}
43+
4044
def __init__(
41-
self: Self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool
45+
self,
46+
compliant_frame: ArrowDataFrame,
47+
keys: Sequence[str],
48+
*,
49+
drop_null_keys: bool,
4250
) -> None:
4351
if drop_null_keys:
44-
self._df = df.drop_nulls(keys)
52+
self._compliant_frame = compliant_frame.drop_nulls(keys)
4553
else:
46-
self._df = df
47-
self._keys = keys.copy()
48-
self._grouped = pa.TableGroupBy(self._df._native_frame, self._keys)
54+
self._compliant_frame = compliant_frame
55+
self._keys: list[str] = list(keys)
56+
self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
4957

5058
def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
51-
all_simple_aggs = True
52-
for expr in exprs:
53-
if not (
54-
is_elementary_expression(expr)
55-
and re.sub(r"(\w+->)", "", expr._function_name)
56-
in POLARS_TO_ARROW_AGGREGATIONS
57-
):
58-
all_simple_aggs = False
59-
break
60-
61-
if not all_simple_aggs:
62-
msg = (
63-
"Non-trivial complex aggregation found.\n\n"
64-
"Hint: you were probably trying to apply a non-elementary aggregation with a "
65-
"pyarrow table.\n"
66-
"Please rewrite your query such that group-by aggregations "
67-
"are elementary. For example, instead of:\n\n"
68-
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
69-
"use:\n\n"
70-
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
71-
)
72-
raise ValueError(msg)
73-
59+
self._ensure_all_simple(exprs)
7460
aggs: list[tuple[str, str, Any]] = []
7561
expected_pyarrow_column_names: list[str] = self._keys.copy()
7662
new_column_names: list[str] = self._keys.copy()
7763

7864
for expr in exprs:
7965
output_names, aliases = evaluate_output_names_and_aliases(
80-
expr, self._df, self._keys
66+
expr, self.compliant, self._keys
8167
)
8268

8369
if expr._depth == 0:
@@ -102,7 +88,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
10288
else:
10389
option = None
10490

105-
function_name = POLARS_TO_ARROW_AGGREGATIONS[function_name]
91+
function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS[function_name]
10692

10793
new_column_names.extend(aliases)
10894
expected_pyarrow_column_names.extend(
@@ -133,18 +119,20 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
133119
]
134120
new_column_names = [new_column_names[i] for i in index_map]
135121
result_simple = result_simple.rename_columns(new_column_names)
136-
if self._df._backend_version < (12, 0, 0):
122+
if self.compliant._backend_version < (12, 0, 0):
137123
columns = result_simple.column_names
138124
result_simple = result_simple.select(
139125
[*self._keys, *[col for col in columns if col not in self._keys]]
140126
)
141-
return self._df._from_native_frame(result_simple)
127+
return self.compliant._from_native_frame(result_simple)
142128

143129
def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
144-
col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
130+
col_token = generate_temporary_column_name(
131+
n_bytes=8, columns=self.compliant.columns
132+
)
145133
null_token: str = "__null_token_value__" # noqa: S105
146134

147-
table = self._df._native_frame
135+
table = self.compliant.native
148136
# NOTE: stubs fail in multiple places for `ChunkedArray`
149137
it, separator_scalar = cast_to_comparable_string_types(
150138
*(table[key] for key in self._keys), separator=""
@@ -160,7 +148,7 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
160148
)
161149
table = table.add_column(i=0, field_=col_token, column=key_values)
162150
for v in pc.unique(key_values):
163-
t = self._df._from_native_frame(
151+
t = self.compliant._from_native_frame(
164152
table.filter(pc.equal(table[col_token], v)).drop([col_token])
165153
)
166154
row = t.simple_select(*self._keys).row(0)

0 commit comments

Comments
 (0)