11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Any
3+ from typing import TYPE_CHECKING , Any , Literal
44
55import pyarrow as pa # ignore-banned-import
6+ import pyarrow .compute as pc # ignore-banned-import
67
8+ from narwhals ._plan import expressions as ir
9+ from narwhals ._plan .expressions import aggregation as agg
710from narwhals ._plan .protocols import DataFrameGroupBy
11+ from narwhals ._utils import Implementation , requires
812
913if TYPE_CHECKING :
10- from collections .abc import Iterator
14+ from collections .abc import Iterator , Mapping
1115
12- from typing_extensions import Self
16+ from typing_extensions import Self , TypeAlias
1317
18+ from narwhals ._arrow .typing import ( # type: ignore[attr-defined]
19+ AggregateOptions ,
20+ Aggregation ,
21+ )
22+ from narwhals ._compliant .typing import NarwhalsAggregation as _NarwhalsAggregation
1423 from narwhals ._plan .arrow .dataframe import ArrowDataFrame
1524 from narwhals ._plan .expressions import NamedIR
1625 from narwhals ._plan .typing import Seq
1726
27+ NarwhalsAggregation : TypeAlias = Literal [_NarwhalsAggregation , "first" , "last" ]
28+ InputName : TypeAlias = str
29+ NativeName : TypeAlias = str
30+ OutputName : TypeAlias = str
31+ NativeAggSpec : TypeAlias = tuple [InputName , Aggregation , AggregateOptions | None ]
32+ RenameSpec : TypeAlias = tuple [NativeName , OutputName ]
1833
19- class ArrowGroupBy (DataFrameGroupBy ["ArrowDataFrame" ]):
20- """What narwhals is doing.
2134
22- - Keys are handled only at compliant
23- - `ParseKeysGroupBy` does weird stuff
24- - But has a fast path for all `str` keys
25- - Aggs are handled in both levels
26- - Some compliant have more restrictions
27- """
# Version of the `pyarrow` backend, as reported by `Implementation`.
BACKEND_VERSION = Implementation.PYARROW._backend_version()


# TODO @dangotbanned: Missing `nw.col("a").len()`
# Aggregation node type -> name of the native `pyarrow` grouped aggregation.
SUPPORTED_AGG: Mapping[type[agg.AggExpr], Aggregation] = {
    agg.Sum: "sum",
    agg.Mean: "mean",
    # NOTE: The native name used here is the *approximate* median
    agg.Median: "approximate_median",
    agg.Max: "max",
    agg.Min: "min",
    agg.Std: "stddev",
    agg.Var: "variance",
    agg.Count: "count",
    agg.NUnique: "count_distinct",
    agg.First: "first",
    agg.Last: "last",
}


# Non-`AggExpr` nodes that still have a native grouped aggregation.
SUPPORTED_IR: Mapping[type[ir.Len], Aggregation] = {
    ir.Len: "count",
}
# Boolean reductions, keyed by the function type.
SUPPORTED_FUNCTION: Mapping[type[ir.boolean.BooleanFunction], Aggregation] = {
    ir.boolean.All: "all",
    ir.boolean.Any: "any",
}

REMAINING: tuple[Aggregation, ...] = (
    # Count the number of rows in each group
    "count_all",
    # Keep the distinct values in each group
    "distinct",
    # Compute the first and last of values in each group
    "first_last",
    # List all values in each group
    "list",
    # Compute the minimum and maximum of values in each group
    "min_max",
    # Get one value from each group
    "one",
    # Compute the product of values in each group
    "product",
    # Compute approximate quantiles of values in each group
    "tdigest",
)
"""Available [native aggs] we haven't used (excluding `first`, `last`)

[native aggs]: https://arrow.apache.org/docs/python/compute.html#grouped-aggregations
"""


REQUIRES_PYARROW_20: tuple[
    Literal["kurtosis"], Literal["pivot_wider"], Literal["skew"]
] = (
    # Compute the kurtosis of values in each group
    "kurtosis",
    # Pivot values according to a pivot key column
    "pivot_wider",
    # Compute the skewness of values in each group
    "skew",
)
"""https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations"""
84+
85+
def _ensure_single_thread(
    grouped: pa.TableGroupBy, expr: ir.OrderableAggExpr, /
) -> pa.TableGroupBy:
    """Return a group-by that has threading disabled, as first/last require.

    Arguments:
        grouped: The native group-by about to be aggregated.
        expr: The first/last expression being translated; only used to build
            the error message.

    Returns:
        `grouped` unchanged when it is already single-threaded, otherwise a new
        `pa.TableGroupBy` over the same table and keys with `use_threads=False`.

    Raises:
        NotImplementedError: If the backend version is older than `(14, 0)`.
    """
    if BACKEND_VERSION < (14, 0):  # pragma: no cover
        msg = (
            f"Using `{expr!r}` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', "
            f"found version {requires._unparse_version(BACKEND_VERSION)!r}.\n\n"
            f"See https://github.com/apache/arrow/issues/36709"
        )
        raise NotImplementedError(msg)
    if not grouped._use_threads:
        # Already single-threaded, nothing to rebuild
        return grouped
    # NOTE: Stubs say `_table` is a method, but at runtime it is a property
    return pa.TableGroupBy(grouped._table, grouped.keys, use_threads=False)  # type: ignore[arg-type]
101+
28102
def group_by_error(
    expr: ArrowAggExpr,
    reason: Literal[
        "too complex",
        "unsupported aggregation",
        "unsupported function",
        "unsupported expression",
    ],
) -> NotImplementedError:
    """Build (but do not raise) the error for an expression we cannot aggregate.

    Arguments:
        expr: The wrapper whose `named_ir` is shown in the message.
        reason: Why the expression was rejected. `"too complex"` gets a
            dedicated sentence, every other reason is title-cased as-is.

    Returns:
        A `NotImplementedError` for the caller to `raise`.
    """
    prefix = (
        "Non-trivial complex aggregation found"
        if reason == "too complex"
        else reason.title()
    )
    return NotImplementedError(
        f"{prefix} in 'pyarrow.Table':\n\n{expr.named_ir!r}"
    )
118+
119+
class ArrowAggExpr:
    """Translate one `NamedIR` into the pieces `pa.TableGroupBy.aggregate` needs.

    Only aggregations applied directly to a single column are supported;
    anything else is rejected with a `NotImplementedError` built by
    `group_by_error`.
    """

    def __init__(self, named_ir: NamedIR, /) -> None:
        self.named_ir: NamedIR = named_ir

    @property
    def output_name(self) -> OutputName:
        """Name the aggregated column should end up with."""
        return self.named_ir.name

    def _parse_agg_expr(
        self, expr: agg.AggExpr, grouped: pa.TableGroupBy
    ) -> tuple[InputName, Aggregation, AggregateOptions | None, pa.TableGroupBy]:
        """Resolve an `AggExpr` to `(input column, native agg, options, grouped)`.

        `grouped` is returned because first/last may need to swap it for a
        single-threaded equivalent.

        Raises:
            NotImplementedError: If the aggregation type has no native
                counterpart, or if it is applied to anything other than a
                plain column.
        """
        agg_name = SUPPORTED_AGG.get(type(expr))
        if agg_name is None:
            raise group_by_error(self, "unsupported aggregation")
        # Validate the operand *before* building options or touching `grouped`,
        # so an expression we can never support is reported as "too complex"
        # rather than triggering the first/last version check (or a needless
        # rebuild of the `pa.TableGroupBy`).
        if not isinstance(expr.expr, ir.Column):
            raise group_by_error(self, "too complex")
        option: AggregateOptions | None = None
        if isinstance(expr, (agg.Std, agg.Var)):
            # NOTE: Only branch which needs an instance (for `ddof`)
            option = pc.VarianceOptions(ddof=expr.ddof)
        elif isinstance(expr, agg.NUnique):
            option = pc.CountOptions(mode="all")
        elif isinstance(expr, agg.Count):
            option = pc.CountOptions(mode="only_valid")
        elif isinstance(expr, (agg.First, agg.Last)):
            option = pc.ScalarAggregateOptions(skip_nulls=False)
            # NOTE: Only branch which needs access to `pa.TableGroupBy`
            grouped = _ensure_single_thread(grouped, expr)
        return expr.expr.name, agg_name, option, grouped

    def _parse_function_expr(self, expr: ir.FunctionExpr) -> NativeAggSpec:
        """Resolve a boolean `all`/`any` call to `(input column, native agg, options)`.

        Raises:
            NotImplementedError: If the function is not `All`/`Any`, or if it
                is not applied to exactly one plain column.
        """
        if not isinstance(expr.function, (ir.boolean.All, ir.boolean.Any)):
            raise group_by_error(self, "unsupported function")
        if not (len(expr.input) == 1 and isinstance(expr.input[0], ir.Column)):
            raise group_by_error(self, "too complex")
        agg_name = SUPPORTED_FUNCTION[type(expr.function)]
        option = pc.ScalarAggregateOptions(min_count=0)
        return expr.input[0].name, agg_name, option

    def _rename_spec(self, input_name: InputName, agg_name: Aggregation, /) -> RenameSpec:
        """Pair the column name `pyarrow` generates with the name we want."""
        # `pyarrow` auto-generates the lhs
        # we want to overwrite that later with rhs
        return f"{input_name}_{agg_name}", self.output_name

    def to_native(
        self, grouped: pa.TableGroupBy
    ) -> tuple[pa.TableGroupBy, NativeAggSpec, RenameSpec]:
        """Convert the wrapped expression for use in a native aggregation.

        Returns:
            A 3-tuple of the (possibly replaced) `grouped`, the spec to pass to
            `aggregate`, and the `(native name, output name)` rename pair.

        Raises:
            NotImplementedError: If the expression cannot be expressed as a
                native grouped aggregation (this currently includes `Len`).
        """
        expr = self.named_ir.expr
        if isinstance(expr, agg.AggExpr):
            input_name, agg_name, option, grouped = self._parse_agg_expr(expr, grouped)
        elif isinstance(expr, ir.Len):
            msg = "Need to investigate https://github.com/narwhals-dev/narwhals/blob/0fb045536f5b56b978f354f8178b292301e9598c/narwhals/_arrow/group_by.py#L132-L141"
            raise NotImplementedError(msg)
        elif isinstance(expr, ir.FunctionExpr):
            input_name, agg_name, option = self._parse_function_expr(expr)
        else:
            raise group_by_error(self, "unsupported expression")
        agg_spec = input_name, agg_name, option
        return grouped, agg_spec, self._rename_spec(input_name, agg_name)
178+
179+
180+ class ArrowGroupBy (DataFrameGroupBy ["ArrowDataFrame" ]):
29181 _df : ArrowDataFrame
30182 _grouped : pa .TableGroupBy
31183 _keys : Seq [NamedIR ]
@@ -52,4 +204,11 @@ def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
52204 raise NotImplementedError
53205
54206 def agg (self , irs : Seq [NamedIR ]) -> ArrowDataFrame :
55- raise NotImplementedError
207+ gb = self ._grouped
208+ aggs : list [NativeAggSpec ] = []
209+ renames : list [RenameSpec ] = []
210+ for e in irs :
211+ gb , agg_spec , rename = ArrowAggExpr (e ).to_native (gb )
212+ aggs .append (agg_spec )
213+ renames .append (rename )
214+ return self .compliant ._with_native (gb .aggregate (aggs )).rename (dict (renames ))
0 commit comments