99from narwhals ._plan ._dispatch import get_dispatch_name
1010from narwhals ._plan ._guards import is_agg_expr , is_function_expr
1111from narwhals ._plan .arrow import acero , functions as fn , options
12+ from narwhals ._plan .common import temp
1213from narwhals ._plan .compliant .group_by import EagerDataFrameGroupBy
1314from narwhals ._plan .expressions import aggregation as agg
1415from narwhals ._utils import Implementation
1516from narwhals .exceptions import InvalidOperationError
1617
1718if TYPE_CHECKING :
18- from collections .abc import Iterator , Mapping
19+ from collections .abc import Iterator , Mapping , Sequence
1920
2021 from typing_extensions import Self , TypeAlias
2122
@@ -137,14 +138,6 @@ def group_by_error(
137138 return InvalidOperationError (msg )
138139
139140
140- def concat_str (native : pa .Table , * , separator : str = "" ) -> ChunkedArray :
141- dtype = fn .string_type (native .schema .types )
142- it = fn .cast_table (native , dtype ).itercolumns ()
143- concat : Incomplete = pc .binary_join_element_wise
144- join = options .join_replace_nulls ()
145- return concat (* it , fn .lit (separator , dtype ), options = join ) # type: ignore[no-any-return]
146-
147-
148141class ArrowGroupBy (EagerDataFrameGroupBy ["Frame" ]):
149142 _df : Frame
150143 _keys : Seq [NamedIR ]
@@ -156,8 +149,6 @@ def compliant(self) -> Frame:
156149 return self ._df
157150
158151 def __iter__ (self ) -> Iterator [tuple [Any , Frame ]]:
159- from narwhals ._plan .arrow .dataframe import partition_by
160-
161152 by = self .key_names
162153 from_native = self .compliant ._with_native
163154 for partition in partition_by (self .compliant .native , by ):
@@ -176,3 +167,58 @@ def agg(self, irs: Seq[NamedIR]) -> Frame:
176167 if original := self ._key_names_original :
177168 return result .rename (dict (zip (key_names , original )))
178169 return result
170+
171+
def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray:
    """Join every column of *native* element-wise into a single string column.

    All columns are cast to a common string type inferred from the schema,
    then joined with *separator* using the null-replacement join options
    from ``options.join_replace_nulls()``.
    """
    string_dtype = fn.string_type(native.schema.types)
    columns = fn.cast_table(native, string_dtype).itercolumns()
    join_opts = options.join_replace_nulls()
    joiner: Incomplete = pc.binary_join_element_wise
    return joiner(*columns, fn.lit(separator, string_dtype), options=join_opts)  # type: ignore[no-any-return]
178+
179+
def partition_by(
    native: pa.Table, by: Sequence[str], *, include_key: bool = True
) -> Iterator[pa.Table]:
    """Yield one sub-table of *native* per distinct key combination in *by*.

    Dispatches to an optimized single-column path when ``by`` has exactly
    one name, otherwise to the composite-key path.
    """
    if len(by) == 1:
        yield from _partition_by_one(native, by[0], include_key=include_key)
        return
    yield from _partition_by_many(native, by, include_key=include_key)
187+
188+
def _partition_by_one(
    native: pa.Table, by: str, *, include_key: bool = True
) -> Iterator[pa.Table]:
    """Optimized path for single-column partition.

    Dictionary-encodes the key column once, then filters the table by each
    dictionary code in turn.
    """
    encoded: Incomplete = fn.array(native.column(by).dictionary_encode("encode"))
    codes: pa.Int32Array = encoded.indices
    table = native
    if not include_key:
        table = table.remove_column(table.schema.get_field_index(by))
    n_groups = len(encoded.dictionary)
    for code in range(n_groups):
        # NOTE: Acero filter doesn't support `null_selection_behavior="emit_null"`
        # Is there any reasonable way to do this in Acero?
        yield table.filter(pc.equal(pa.scalar(code), codes))
201+
202+
def _partition_by_many(
    native: pa.Table, by: Sequence[str], *, include_key: bool = True
) -> Iterator[pa.Table]:
    """General path: partition on a temporary composite string key.

    Builds a single string key by concatenating the ``by`` columns, prepends
    it as a temp column, then emits one Acero filter-and-select per distinct
    key value.
    """
    names = native.column_names
    key_name = temp.column_name(names)
    key_col = acero.col(key_name)
    composite = concat_str(acero.select_names_table(native, by))
    # Need to iterate over the whole thing, so py_list first should be faster
    distinct_keys = composite.unique().to_pylist()
    keyed = native.add_column(0, key_name, composite)
    source = acero.table_source(keyed)
    if include_key:
        kept = names
    else:
        dropped = {*by, key_name}
        kept = [name for name in names if name not in dropped]
    project = acero.select_names(kept)
    for value in distinct_keys:
        # NOTE: May want to split the `Declaration` production iterator into its own function
        # E.g, to push down column selection to *before* collection
        # Not needed for this task though
        yield acero.collect(source, acero.filter(key_col == value), project)
0 commit comments