 import re
 from typing import TYPE_CHECKING
 from typing import Any
+from typing import ClassVar
 from typing import Iterator
+from typing import Mapping
+from typing import Sequence
 
 import pyarrow as pa
 import pyarrow.compute as pc
 
+from narwhals._arrow.dataframe import ArrowDataFrame
 from narwhals._arrow.utils import cast_to_comparable_string_types
 from narwhals._arrow.utils import extract_py_scalar
+from narwhals._compliant import EagerGroupBy
 from narwhals._expression_parsing import evaluate_output_names_and_aliases
-from narwhals._expression_parsing import is_elementary_expression
 from narwhals.utils import generate_temporary_column_name
 
 if TYPE_CHECKING:
     from narwhals._arrow.expr import ArrowExpr
     from narwhals._arrow.typing import Incomplete
 
-POLARS_TO_ARROW_AGGREGATIONS = {
-    "sum": "sum",
-    "mean": "mean",
-    "median": "approximate_median",
-    "max": "max",
-    "min": "min",
-    "std": "stddev",
-    "var": "variance",
-    "len": "count",
-    "n_unique": "count_distinct",
-    "count": "count",
-}
-
-
-class ArrowGroupBy:
+
+class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]):
+    _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] = {
+        "sum": "sum",
+        "mean": "mean",
+        "median": "approximate_median",
+        "max": "max",
+        "min": "min",
+        "std": "stddev",
+        "var": "variance",
+        "len": "count",
+        "n_unique": "count_distinct",
+        "count": "count",
+    }
+
     def __init__(
-        self: Self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool
+        self,
+        compliant_frame: ArrowDataFrame,
+        keys: Sequence[str],
+        *,
+        drop_null_keys: bool,
     ) -> None:
         if drop_null_keys:
-            self._df = df.drop_nulls(keys)
+            self._compliant_frame = compliant_frame.drop_nulls(keys)
         else:
-            self._df = df
-        self._keys = keys.copy()
-        self._grouped = pa.TableGroupBy(self._df._native_frame, self._keys)
+            self._compliant_frame = compliant_frame
+        self._keys: list[str] = list(keys)
+        self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
 
     def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
-        all_simple_aggs = True
-        for expr in exprs:
-            if not (
-                is_elementary_expression(expr)
-                and re.sub(r"(\w+->)", "", expr._function_name)
-                in POLARS_TO_ARROW_AGGREGATIONS
-            ):
-                all_simple_aggs = False
-                break
-
-        if not all_simple_aggs:
-            msg = (
-                "Non-trivial complex aggregation found.\n\n"
-                "Hint: you were probably trying to apply a non-elementary aggregation with a "
-                "pyarrow table.\n"
-                "Please rewrite your query such that group-by aggregations "
-                "are elementary. For example, instead of:\n\n"
-                "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
-                "use:\n\n"
-                "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
-            )
-            raise ValueError(msg)
-
+        self._ensure_all_simple(exprs)
         aggs: list[tuple[str, str, Any]] = []
         expected_pyarrow_column_names: list[str] = self._keys.copy()
         new_column_names: list[str] = self._keys.copy()
 
         for expr in exprs:
             output_names, aliases = evaluate_output_names_and_aliases(
-                expr, self._df, self._keys
+                expr, self.compliant, self._keys
             )
 
             if expr._depth == 0:
@@ -102,7 +88,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
             else:
                 option = None
 
-            function_name = POLARS_TO_ARROW_AGGREGATIONS[function_name]
+            function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS[function_name]
 
             new_column_names.extend(aliases)
             expected_pyarrow_column_names.extend(
@@ -133,18 +119,20 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
         ]
         new_column_names = [new_column_names[i] for i in index_map]
         result_simple = result_simple.rename_columns(new_column_names)
-        if self._df._backend_version < (12, 0, 0):
+        if self.compliant._backend_version < (12, 0, 0):
             columns = result_simple.column_names
             result_simple = result_simple.select(
                 [*self._keys, *[col for col in columns if col not in self._keys]]
             )
-        return self._df._from_native_frame(result_simple)
+        return self.compliant._from_native_frame(result_simple)
 
     def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
-        col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
+        col_token = generate_temporary_column_name(
+            n_bytes=8, columns=self.compliant.columns
+        )
         null_token: str = "__null_token_value__"  # noqa: S105
 
-        table = self._df._native_frame
+        table = self.compliant.native
         # NOTE: stubs fail in multiple places for `ChunkedArray`
         it, separator_scalar = cast_to_comparable_string_types(
             *(table[key] for key in self._keys), separator=""
@@ -160,7 +148,7 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
         )
         table = table.add_column(i=0, field_=col_token, column=key_values)
         for v in pc.unique(key_values):
-            t = self._df._from_native_frame(
+            t = self.compliant._from_native_frame(
                 table.filter(pc.equal(table[col_token], v)).drop([col_token])
             )
             row = t.simple_select(*self._keys).row(0)
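
For context, here is a minimal standalone sketch (not part of the diff) of the `pa.TableGroupBy` API that `ArrowGroupBy` delegates to; the table and column names below are made up for illustration:

import pyarrow as pa

# pa.TableGroupBy.aggregate takes (column, function) pairs using pyarrow's
# own function names, i.e. the values of _NARWHALS_TO_NATIVE_AGGREGATIONS
# ("median" maps to "approximate_median", "std" to "stddev", and so on).
table = pa.table({"a": ["x", "x", "y"], "b": [1, 2, 3]})
result = pa.TableGroupBy(table, ["a"]).aggregate(
    [("b", "sum"), ("b", "approximate_median")]
)

# pyarrow names output columns "<column>_<function>", which is why agg()
# builds expected_pyarrow_column_names and renames afterwards. The position
# of the key column(s) in the output has varied across pyarrow versions,
# which the `_backend_version < (12, 0, 0)` reordering above accounts for.
print(result.column_names)  # e.g. ['b_sum', 'b_approximate_median', 'a']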