66import pyarrow as pa
77import pyarrow .compute as pc
88
9- from narwhals ._arrow .utils import cast_to_comparable_string_types , extract_py_scalar
9+ from narwhals ._arrow .utils import (
10+ BACKEND_VERSION ,
11+ cast_to_comparable_string_types ,
12+ extract_py_scalar ,
13+ )
1014from narwhals ._compliant import EagerGroupBy
1115from narwhals ._expression_parsing import evaluate_output_names_and_aliases
1216from narwhals ._utils import generate_temporary_column_name , requires
@@ -71,12 +75,11 @@ def __init__(
7175 self ._df = df
7276 frame , self ._keys , self ._output_key_names = self ._parse_keys (df , keys = keys )
7377 self ._compliant_frame = frame .drop_nulls (self ._keys ) if drop_null_keys else frame
74- self ._grouped = pa .TableGroupBy (self .compliant .native , self ._keys )
7578 self ._drop_null_keys = drop_null_keys
7679
7780 def _configure_agg (
78- self , grouped : pa . TableGroupBy , expr : ArrowExpr , /
79- ) -> tuple [pa . TableGroupBy , Aggregation , AggregateOptions | None ]:
81+ self , expr : ArrowExpr , /
82+ ) -> tuple [Aggregation , AggregateOptions | None ]:
8083 option : AggregateOptions | None = None
8184 function_name = self ._leaf_name (expr )
8285 kwargs = self ._kwargs (expr )
@@ -91,50 +94,49 @@ def _configure_agg(
9194 option = pc .ScalarAggregateOptions (min_count = 0 )
9295 elif function_name in self ._OPTION_ORDERED :
9396 ignore_nulls = kwargs .get ("ignore_nulls" , False )
94- grouped , option = self ._ordered_agg (
95- grouped , function_name , ignore_nulls = ignore_nulls
96- )
97- return grouped , self ._remap_expr_name (function_name ), option
98-
99- def _ordered_agg (
100- self ,
101- grouped : pa .TableGroupBy ,
102- name : NarwhalsAggregation ,
103- / ,
104- * ,
105- ignore_nulls : bool ,
106- ) -> tuple [pa .TableGroupBy , AggregateOptions ]:
107- """The default behavior of `pyarrow` raises when `first` or `last` are used.
108-
109- You'd see an error like:
97+ option = pc .ScalarAggregateOptions (skip_nulls = ignore_nulls )
98+ return self ._remap_expr_name (function_name ), option
11099
111- ArrowNotImplementedError: Using ordered aggregator in multiple threaded execution is not supported
112-
113- We need to **disable** multi-threading to use them, but the ability to do so
114- wasn't possible before `14.0.0` ([pyarrow-36709])
115-
116- [pyarrow-36709]: https://github.com/apache/arrow/issues/36709
117- """
118- backend_version = self .compliant ._backend_version
119- if backend_version >= (14 , 0 ) and grouped ._use_threads :
120- native = self .compliant .native
121- grouped = pa .TableGroupBy (native , grouped .keys , use_threads = False )
122- elif backend_version < (14 , 0 ): # pragma: no cover
    def _configure_grouped(self, *exprs: ArrowExpr) -> pa.TableGroupBy:
        """Build the `pa.TableGroupBy` used to evaluate *exprs*.

        Scans the expressions for order-dependent aggregations (those whose
        leaf operation is in `self._OPTION_ORDERED`, e.g. first/last).
        PyArrow cannot run ordered aggregators multi-threaded, so their
        presence forces `use_threads=False`; any `order_by` such an
        aggregation carries is applied by pre-sorting the native table
        before grouping.

        Raises:
            NotImplementedError: If two expressions request conflicting
                `order_by` specifications, or if an ordered aggregation is
                used with pyarrow < 14.0.0 (disabling multi-threading for
                grouping only became possible with apache/arrow#36709).
        """
        order_by = ()
        use_threads = True
        for expr in exprs:
            # First node of the reversed op traversal — presumably the leaf
            # aggregation (cf. `_leaf_name` in `_configure_agg`).
            md = next(expr._metadata.op_nodes_reversed())
            if md.name not in self._OPTION_ORDERED:
                continue
            # [pyarrow-36709]: https://github.com/apache/arrow/issues/36709
            use_threads = False
            if _current_order_by := md.kwargs.get("order_by", ()):
                # All ordered aggregations must agree on a single ordering,
                # since we can only pre-sort the table one way.
                if order_by and _current_order_by != order_by:
                    msg = f"Only one `order_by` can be specified in `group_by`. Found both {order_by} and {_current_order_by}."
                    raise NotImplementedError(msg)
                order_by = _current_order_by
        if not use_threads and BACKEND_VERSION < (14,):  # pragma: no cover
            msg = (
                f"Using `first/last` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', "
                f"found version {requires._unparse_version(BACKEND_VERSION)!r}.\n\n"
                f"See https://github.com/apache/arrow/issues/36709"
            )
            raise NotImplementedError(msg)
        if order_by:
            # Pre-sort so the (single-threaded) ordered aggregation sees rows
            # in the requested order.
            return pa.TableGroupBy(
                self.compliant.sort(*order_by, descending=False, nulls_last=False).native,
                self._keys,
                use_threads=use_threads,
            )
        if not use_threads:
            return pa.TableGroupBy(self.compliant.native, self._keys, use_threads=False)
        # TODO(unassigned): combine with `return` above once PyArrow 15 is the minimum.
        return pa.TableGroupBy(self.compliant.native, self._keys)
130131
131132 def agg (self , * exprs : ArrowExpr ) -> ArrowDataFrame :
132133 self ._ensure_all_simple (exprs )
134+ grouped = self ._configure_grouped (* exprs )
135+
133136 aggs : list [tuple [str , Aggregation , AggregateOptions | None ]] = []
134137 expected_pyarrow_column_names : list [str ] = self ._keys .copy ()
135138 new_column_names : list [str ] = self ._keys .copy ()
136139 exclude = (* self ._keys , * self ._output_key_names )
137- grouped = self ._grouped
138140
139141 for expr in exprs :
140142 output_names , aliases = evaluate_output_names_and_aliases (
@@ -153,7 +155,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
153155 aggs .append ((self ._keys [0 ], "count" , pc .CountOptions (mode = "all" )))
154156 continue
155157
156- grouped , function_name , option = self ._configure_agg (grouped , expr )
158+ function_name , option = self ._configure_agg (expr )
157159 new_column_names .extend (aliases )
158160 expected_pyarrow_column_names .extend (
159161 [f"{ output_name } _{ function_name } " for output_name in output_names ]
0 commit comments