Skip to content

Commit e55aeb0

Browse files
committed
refactor: Move partition_by impl to group_by.py
1 parent c4d494a commit e55aeb0

File tree

2 files changed

+59
-61
lines changed

2 files changed

+59
-61
lines changed

narwhals/_plan/arrow/dataframe.py

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@
99
import pyarrow.compute as pc # ignore-banned-import
1010

1111
from narwhals._arrow.utils import native_to_narwhals_dtype
12-
from narwhals._plan.arrow import acero, functions as fn, group_by
12+
from narwhals._plan.arrow import acero, functions as fn
1313
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
14-
from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy
14+
from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by
1515
from narwhals._plan.arrow.series import ArrowSeries as Series
16-
from narwhals._plan.common import temp
1716
from narwhals._plan.compliant.dataframe import EagerDataFrame
1817
from narwhals._plan.compliant.typing import namespace
1918
from narwhals._plan.expressions import NamedIR
@@ -179,50 +178,3 @@ def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[S
179178
from_native = self._with_native
180179
partitions = partition_by(self.native, by, include_key=include_key)
181180
return [from_native(df) for df in partitions]
182-
183-
184-
def partition_by(
185-
native: pa.Table, by: Sequence[str], *, include_key: bool = True
186-
) -> Iterator[pa.Table]:
187-
if len(by) == 1:
188-
yield from _partition_by_one(native, by[0], include_key=include_key)
189-
else:
190-
yield from _partition_by_many(native, by, include_key=include_key)
191-
192-
193-
def _partition_by_one(
194-
native: pa.Table, by: str, *, include_key: bool = True
195-
) -> Iterator[pa.Table]:
196-
"""Optimized path for single-column partition."""
197-
arr_dict: Incomplete = fn.array(native.column(by).dictionary_encode("encode"))
198-
indices: pa.Int32Array = arr_dict.indices
199-
if not include_key:
200-
native = native.remove_column(native.schema.get_field_index(by))
201-
for idx in range(len(arr_dict.dictionary)):
202-
# NOTE: Acero filter doesn't support `null_selection_behavior="emit_null"`
203-
# Is there any reasonable way to do this in Acero?
204-
yield native.filter(pc.equal(pa.scalar(idx), indices))
205-
206-
207-
def _partition_by_many(
208-
native: pa.Table, by: Sequence[str], *, include_key: bool = True
209-
) -> Iterator[pa.Table]:
210-
original_names = native.column_names
211-
temp_name = temp.column_name(original_names)
212-
key = acero.col(temp_name)
213-
composite_values = group_by.concat_str(acero.select_names_table(native, by))
214-
# Need to iterate over the whole thing, so py_list first should be faster
215-
unique_py = composite_values.unique().to_pylist()
216-
re_keyed = native.add_column(0, temp_name, composite_values)
217-
source = acero.table_source(re_keyed)
218-
if include_key:
219-
keep = original_names
220-
else:
221-
ignore = {*by, temp_name}
222-
keep = [name for name in original_names if name not in ignore]
223-
select = acero.select_names(keep)
224-
for v in unique_py:
225-
# NOTE: May want to split the `Declaration` production iterator into it's own function
226-
# E.g, to push down column selection to *before* collection
227-
# Not needed for this task though
228-
yield acero.collect(source, acero.filter(key == v), select)

narwhals/_plan/arrow/group_by.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
from narwhals._plan._dispatch import get_dispatch_name
1010
from narwhals._plan._guards import is_agg_expr, is_function_expr
1111
from narwhals._plan.arrow import acero, functions as fn, options
12+
from narwhals._plan.common import temp
1213
from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy
1314
from narwhals._plan.expressions import aggregation as agg
1415
from narwhals._utils import Implementation
1516
from narwhals.exceptions import InvalidOperationError
1617

1718
if TYPE_CHECKING:
18-
from collections.abc import Iterator, Mapping
19+
from collections.abc import Iterator, Mapping, Sequence
1920

2021
from typing_extensions import Self, TypeAlias
2122

@@ -137,14 +138,6 @@ def group_by_error(
137138
return InvalidOperationError(msg)
138139

139140

140-
def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray:
141-
dtype = fn.string_type(native.schema.types)
142-
it = fn.cast_table(native, dtype).itercolumns()
143-
concat: Incomplete = pc.binary_join_element_wise
144-
join = options.join_replace_nulls()
145-
return concat(*it, fn.lit(separator, dtype), options=join) # type: ignore[no-any-return]
146-
147-
148141
class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]):
149142
_df: Frame
150143
_keys: Seq[NamedIR]
@@ -156,8 +149,6 @@ def compliant(self) -> Frame:
156149
return self._df
157150

158151
def __iter__(self) -> Iterator[tuple[Any, Frame]]:
159-
from narwhals._plan.arrow.dataframe import partition_by
160-
161152
by = self.key_names
162153
from_native = self.compliant._with_native
163154
for partition in partition_by(self.compliant.native, by):
@@ -176,3 +167,58 @@ def agg(self, irs: Seq[NamedIR]) -> Frame:
176167
if original := self._key_names_original:
177168
return result.rename(dict(zip(key_names, original)))
178169
return result
170+
171+
172+
def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray:
    """Join all columns of *native* element-wise into a single string column.

    Every column is first cast to a common string type derived from the
    table's schema; nulls are replaced per the join options before joining.
    """
    dtype = fn.string_type(native.schema.types)
    columns = fn.cast_table(native, dtype).itercolumns()
    join_opts = options.join_replace_nulls()
    joiner: Incomplete = pc.binary_join_element_wise
    return joiner(*columns, fn.lit(separator, dtype), options=join_opts)  # type: ignore[no-any-return]
178+
179+
180+
def partition_by(
    native: pa.Table, by: Sequence[str], *, include_key: bool = True
) -> Iterator[pa.Table]:
    """Yield one sub-table of *native* per distinct key combination in *by*.

    Dispatches to an optimized single-column path when exactly one key
    column is given; otherwise partitions on a composite key.
    """
    if len(by) != 1:
        yield from _partition_by_many(native, by, include_key=include_key)
    else:
        yield from _partition_by_one(native, by[0], include_key=include_key)
187+
188+
189+
def _partition_by_one(
    native: pa.Table, by: str, *, include_key: bool = True
) -> Iterator[pa.Table]:
    """Optimized path for single-column partition.

    Dictionary-encodes the key column once, then filters the table by
    dictionary code for each distinct value. ``null_encoding="encode"``
    gives nulls their own dictionary entry, so a null partition is
    yielded too.
    """
    encoded: Incomplete = fn.array(native.column(by).dictionary_encode("encode"))
    codes: pa.Int32Array = encoded.indices
    n_groups = len(encoded.dictionary)
    if not include_key:
        native = native.remove_column(native.schema.get_field_index(by))
    # NOTE: Acero filter doesn't support `null_selection_behavior="emit_null"`
    # Is there any reasonable way to do this in Acero?
    for code in range(n_groups):
        yield native.filter(pc.equal(pa.scalar(code), codes))
201+
202+
203+
def _partition_by_many(
    native: pa.Table, by: Sequence[str], *, include_key: bool = True
) -> Iterator[pa.Table]:
    """Partition on a composite key built by concatenating the *by* columns.

    A temporary string column holding the concatenated key is prepended to
    the table, then each distinct composite value is materialized through
    an Acero ``source -> filter -> select`` pipeline.

    NOTE(review): the key columns are joined with an empty separator, so
    values like ("ab", "c") and ("a", "bc") would produce the same
    composite key — confirm upstream guarantees make collisions impossible.
    """
    names = native.column_names
    temp_name = temp.column_name(names)
    composite = concat_str(acero.select_names_table(native, by))
    # Need to iterate over the whole thing, so py_list first should be faster
    distinct_keys = composite.unique().to_pylist()
    source = acero.table_source(native.add_column(0, temp_name, composite))
    key = acero.col(temp_name)
    dropped = {*by, temp_name}
    keep = names if include_key else [name for name in names if name not in dropped]
    select = acero.select_names(keep)
    # NOTE: May want to split the `Declaration` production iterator into its own function
    # E.g, to push down column selection to *before* collection
    # Not needed for this task though
    for value in distinct_keys:
        yield acero.collect(source, acero.filter(key == value), select)

0 commit comments

Comments
 (0)