99from narwhals ._arrow .utils import cast_to_comparable_string_types , extract_py_scalar
1010from narwhals ._compliant import EagerGroupBy
1111from narwhals ._expression_parsing import evaluate_output_names_and_aliases
12- from narwhals ._utils import generate_temporary_column_name
12+ from narwhals ._utils import generate_temporary_column_name , requires
1313
1414if TYPE_CHECKING :
1515 from collections .abc import Iterator , Mapping , Sequence
@@ -39,12 +39,23 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
3939 "count" : "count" ,
4040 "all" : "all" ,
4141 "any" : "any" ,
42+ "first" : "first" ,
43+ "last" : "last" ,
4244 }
    # Maps `unique(keep=...)` strategies to the pyarrow aggregation used to
    # pick the surviving row — presumably row-index min/max; verify at call site.
    _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
        "any": "min",
        "first": "min",
        "last": "max",
    }
    # Aggregations counted with `CountOptions(mode="all")` (nulls included).
    _OPTION_COUNT_ALL: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(
        ("len", "n_unique")
    )
    # Aggregations counted with `CountOptions(mode="only_valid")` (nulls excluded).
    _OPTION_COUNT_VALID: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("count",))
    # Order-dependent aggregations; pyarrow needs single-threaded execution for these
    # (see `_ordered_agg`).
    _OPTION_ORDERED: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(
        ("first", "last")
    )
    # Aggregations configured via `VarianceOptions` (take a `ddof`).
    _OPTION_VARIANCE: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("std", "var"))
    # Boolean reductions configured with `ScalarAggregateOptions(min_count=0)`.
    _OPTION_SCALAR: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("any", "all"))
4859
4960 def __init__ (
5061 self ,
@@ -60,12 +71,58 @@ def __init__(
6071 self ._grouped = pa .TableGroupBy (self .compliant .native , self ._keys )
6172 self ._drop_null_keys = drop_null_keys
6273
74+ def _configure_agg (
75+ self , grouped : pa .TableGroupBy , expr : ArrowExpr , /
76+ ) -> tuple [pa .TableGroupBy , Aggregation , AggregateOptions | None ]:
77+ option : AggregateOptions | None = None
78+ function_name = self ._leaf_name (expr )
79+ if function_name in self ._OPTION_VARIANCE :
80+ ddof = expr ._scalar_kwargs .get ("ddof" , 1 )
81+ option = pc .VarianceOptions (ddof = ddof )
82+ elif function_name in self ._OPTION_COUNT_ALL :
83+ option = pc .CountOptions (mode = "all" )
84+ elif function_name in self ._OPTION_COUNT_VALID :
85+ option = pc .CountOptions (mode = "only_valid" )
86+ elif function_name in self ._OPTION_SCALAR :
87+ option = pc .ScalarAggregateOptions (min_count = 0 )
88+ elif function_name in self ._OPTION_ORDERED :
89+ grouped , option = self ._ordered_agg (grouped , function_name )
90+ return grouped , self ._remap_expr_name (function_name ), option
91+
92+ def _ordered_agg (
93+ self , grouped : pa .TableGroupBy , name : NarwhalsAggregation , /
94+ ) -> tuple [pa .TableGroupBy , AggregateOptions ]:
95+ """The default behavior of `pyarrow` raises when `first` or `last` are used.
96+
97+ You'd see an error like:
98+
99+ ArrowNotImplementedError: Using ordered aggregator in multiple threaded execution is not supported
100+
101+ We need to **disable** multi-threading to use them, but the ability to do so
102+ wasn't possible before `14.0.0` ([pyarrow-36709])
103+
104+ [pyarrow-36709]: https://github.com/apache/arrow/issues/36709
105+ """
106+ backend_version = self .compliant ._backend_version
107+ if backend_version >= (14 , 0 ) and grouped ._use_threads :
108+ native = self .compliant .native
109+ grouped = pa .TableGroupBy (native , grouped .keys , use_threads = False )
110+ elif backend_version < (14 , 0 ): # pragma: no cover
111+ msg = (
112+ f"Using `{ name } ()` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', "
113+ f"found version { requires ._unparse_version (backend_version )!r} .\n \n "
114+ f"See https://github.com/apache/arrow/issues/36709"
115+ )
116+ raise NotImplementedError (msg )
117+ return grouped , pc .ScalarAggregateOptions (skip_nulls = False )
118+
63119 def agg (self , * exprs : ArrowExpr ) -> ArrowDataFrame :
64120 self ._ensure_all_simple (exprs )
65121 aggs : list [tuple [str , Aggregation , AggregateOptions | None ]] = []
66122 expected_pyarrow_column_names : list [str ] = self ._keys .copy ()
67123 new_column_names : list [str ] = self ._keys .copy ()
68124 exclude = (* self ._keys , * self ._output_key_names )
125+ grouped = self ._grouped
69126
70127 for expr in exprs :
71128 output_names , aliases = evaluate_output_names_and_aliases (
@@ -83,20 +140,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
83140 aggs .append ((self ._keys [0 ], "count" , pc .CountOptions (mode = "all" )))
84141 continue
85142
86- function_name = self ._leaf_name (expr )
87- if function_name in {"std" , "var" }:
88- assert "ddof" in expr ._scalar_kwargs # noqa: S101
89- option : Any = pc .VarianceOptions (ddof = expr ._scalar_kwargs ["ddof" ])
90- elif function_name in {"len" , "n_unique" }:
91- option = pc .CountOptions (mode = "all" )
92- elif function_name == "count" :
93- option = pc .CountOptions (mode = "only_valid" )
94- elif function_name in {"all" , "any" }:
95- option = pc .ScalarAggregateOptions (min_count = 0 )
96- else :
97- option = None
98-
99- function_name = self ._remap_expr_name (function_name )
143+ grouped , function_name , option = self ._configure_agg (grouped , expr )
100144 new_column_names .extend (aliases )
101145 expected_pyarrow_column_names .extend (
102146 [f"{ output_name } _{ function_name } " for output_name in output_names ]
@@ -105,7 +149,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
105149 [(output_name , function_name , option ) for output_name in output_names ]
106150 )
107151
108- result_simple = self . _grouped .aggregate (aggs )
152+ result_simple = grouped .aggregate (aggs )
109153
110154 # Rename columns, being very careful
111155 expected_old_names_indices : dict [str , list [int ]] = collections .defaultdict (list )
0 commit comments