
Commit ac779dd

perf: Add an optimized path for single-column partition_by

Avoids the need for a temporary composite-key column by using `dictionary_encode` and generating boolean masks based on index position.

1 parent f17781a
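
The idea can be sketched with plain `pyarrow` (table and column names here are invented for illustration; the commit itself goes through the repo's `fn`/`acero` helpers):

```python
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"key": ["a", "b", "a", "c", "b"], "val": [1, 2, 3, 4, 5]})

# Dictionary-encode the partition column once: unique values land in
# `.dictionary`, and every row maps to an integer slot in `.indices`.
encoded = table.column("key").combine_chunks().dictionary_encode()

for idx in range(len(encoded.dictionary)):
    # Boolean mask over index positions, so no composite-key column is needed.
    mask = pc.equal(encoded.indices, pa.scalar(idx))
    print(encoded.dictionary[idx], table.filter(mask).num_rows)
```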

File tree

2 files changed: +75 −31 lines

narwhals/_plan/arrow/dataframe.py

Lines changed: 52 additions & 21 deletions
```diff
@@ -24,7 +24,7 @@
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Mapping, Sequence
 
-    from typing_extensions import Self
+    from typing_extensions import Self, TypeAlias
 
     from narwhals._arrow.typing import ChunkedArrayAny
     from narwhals._plan.arrow.namespace import ArrowNamespace
@@ -34,6 +34,8 @@
     from narwhals.dtypes import DType
     from narwhals.typing import IntoSchema
 
+    Incomplete: TypeAlias = Any
+
 
 class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]):
     implementation = Implementation.PYARROW
```
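
The `Incomplete: TypeAlias = Any` alias follows the typeshed convention of naming an `Any` alias for values the stubs cannot yet type precisely (here, the dictionary-encoded array handled in `_partition_by_one` below). A minimal sketch of the pattern, with an assumed function name:

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from typing_extensions import TypeAlias

    # Annotation-only alias: reads as "not fully typed yet",
    # but behaves as `Any` for the type checker.
    Incomplete: TypeAlias = Any

def dict_encoded(column) -> "Incomplete":
    # e.g. a DictionaryArray whose `.indices` / `.dictionary`
    # attributes the pyarrow stubs don't fully model
    return column.combine_chunks().dictionary_encode()
```
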
```diff
@@ -173,25 +175,54 @@ def filter(self, predicate: NamedIR) -> Self:
         mask = acero.lit(resolved.native)
         return self._with_native(self.native.filter(mask))
 
-    # TODO @dangotbanned: Clean this up after getting more tests in place
     def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[Self]:
-        original_names = self.columns
-        temp_name = temp.column_name(original_names)
-        native = self.native
-        composite_values = group_by.concat_str(acero.select_names_table(native, by))
-        re_keyed = native.add_column(0, temp_name, composite_values)
-        source = acero.table_source(re_keyed)
-        if include_key:
-            keep = original_names
-        else:
-            ignore = {*by, temp_name}
-            keep = [name for name in original_names if name not in ignore]
-        select = acero.select_names(keep)
-        key = acero.col(temp_name)
-        # Need to iterate over the whole thing, so py_list first should be faster
-        partitions = (
-            acero.declare(source, acero.filter(key == v), select)
-            for v in composite_values.unique().to_pylist()
-        )
         from_native = self._with_native
-        return [from_native(decl.to_table()) for decl in partitions]
+        partitions = partition_by(self.native, by, include_key=include_key)
+        return [from_native(df) for df in partitions]
+
+
+def partition_by(
+    native: pa.Table, by: Sequence[str], *, include_key: bool = True
+) -> Iterator[pa.Table]:
+    if len(by) == 1:
+        yield from _partition_by_one(native, by[0], include_key=include_key)
+    else:
+        yield from _partition_by_many(native, by, include_key=include_key)
+
+
+def _partition_by_one(
+    native: pa.Table, by: str, *, include_key: bool = True
+) -> Iterator[pa.Table]:
+    """Optimized path for single-column partition."""
+    arr_dict: Incomplete = fn.array(native.column(by).dictionary_encode("encode"))
+    indices: pa.Int32Array = arr_dict.indices
+    if not include_key:
+        native = native.remove_column(native.schema.get_field_index(by))
+    for idx in range(len(arr_dict.dictionary)):
+        # NOTE: Acero filter doesn't support `null_selection_behavior="emit_null"`.
+        # Is there any reasonable way to do this in Acero?
+        yield native.filter(pc.equal(pa.scalar(idx), indices))
+
+
+def _partition_by_many(
+    native: pa.Table, by: Sequence[str], *, include_key: bool = True
+) -> Iterator[pa.Table]:
+    original_names = native.column_names
+    temp_name = temp.column_name(original_names)
+    key = acero.col(temp_name)
+    composite_values = group_by.concat_str(acero.select_names_table(native, by))
+    # Need to iterate over the whole thing, so `to_pylist` first should be faster.
+    unique_py = composite_values.unique().to_pylist()
+    re_keyed = native.add_column(0, temp_name, composite_values)
+    source = acero.table_source(re_keyed)
+    if include_key:
+        keep = original_names
+    else:
+        ignore = {*by, temp_name}
+        keep = [name for name in original_names if name not in ignore]
+    select = acero.select_names(keep)
+    for v in unique_py:
+        # NOTE: May want to split the `Declaration` production iterator into its own
+        # function, e.g. to push down column selection to *before* collection.
+        # Not needed for this task though.
+        yield acero.collect(source, acero.filter(key == v), select)
```
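
The NOTE in `_partition_by_one` concerns null keys: with the default `null_encoding="mask"`, null rows get a null index, the equality mask is then null, and a filter's default `"drop"` behavior would exclude them from every partition; Acero's filter node offers no `"emit_null"` equivalent. Passing `"encode"` instead gives nulls their own dictionary slot. A small demonstration (data invented for illustration):

```python
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"key": ["a", None, "a"], "val": [1, 2, 3]})
column = table.column("key").combine_chunks()

masked = column.dictionary_encode()           # null_encoding="mask" (default)
print(masked.indices.to_pylist())             # [0, None, 0]
# Mask is [true, null, true]; the null row falls out of every partition.
print(table.filter(pc.equal(masked.indices, pa.scalar(0))).num_rows)   # 2

encoded = column.dictionary_encode("encode")  # nulls get a dictionary slot
print(encoded.indices.to_pylist())            # [0, 1, 0]: null is index 1
print(table.filter(pc.equal(encoded.indices, pa.scalar(1))).num_rows)  # 1
```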

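For the multi-column path, the composite key built by `group_by.concat_str` can be approximated in plain `pyarrow` with `pc.binary_join_element_wise` (separator and column names here are illustrative, not from the commit):

```python
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"a": ["x", "x", "y"], "b": ["1", "2", "1"], "val": [10, 20, 30]})

# Join the key columns element-wise into one string key; the last
# argument is the separator, chosen to avoid accidental collisions.
composite = pc.binary_join_element_wise(table["a"], table["b"], "\x1f")

for v in composite.unique().to_pylist():
    # One filter pass per distinct composite value, as in `_partition_by_many`.
    print(v.split("\x1f"), table.filter(pc.equal(composite, pa.scalar(v))).num_rows)
```
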
narwhals/_plan/arrow/functions.py

Lines changed: 23 additions & 10 deletions
```diff
@@ -4,7 +4,7 @@
 
 import typing as t
 from collections.abc import Callable
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, overload
 
 import pyarrow as pa  # ignore-banned-import
 import pyarrow.compute as pc  # ignore-banned-import
@@ -54,7 +54,7 @@
         StringType,
         UnaryFunction,
     )
-    from narwhals.typing import ClosedInterval, IntoArrowSchema
+    from narwhals.typing import ClosedInterval, IntoArrowSchema, PythonLiteral
 
 BACKEND_VERSION = Implementation.PYARROW._backend_version()
 
@@ -348,20 +348,33 @@ def lit(value: Any, dtype: DataType | None = None) -> NativeScalar:
     return pa.scalar(value) if dtype is None else pa.scalar(value, dtype)
 
 
+@overload
+def array(data: ArrowAny, /) -> ArrayAny: ...
+@overload
 def array(
-    value: NativeScalar | Iterable[Any], dtype: DataType | None = None, /
+    data: Iterable[PythonLiteral], dtype: DataType | None = None, /
+) -> ArrayAny: ...
+def array(
+    data: ArrowAny | Iterable[PythonLiteral], dtype: DataType | None = None, /
 ) -> ArrayAny:
-    return (
-        pa.array([value], value.type)
-        if isinstance(value, pa.Scalar)
-        else pa.array(value, dtype)
-    )
+    """Convert `data` into an Array instance.
+
+    Note:
+        `dtype` is not used for existing `pyarrow` data; use `cast` instead.
+    """
+    if isinstance(data, pa.ChunkedArray):
+        return data.combine_chunks()
+    if isinstance(data, pa.Array):
+        return data
+    if isinstance(data, pa.Scalar):
+        return pa.array([data], data.type)
+    return pa.array(data, dtype)
 
 
 def chunked_array(
-    arr: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, /
+    data: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, /
 ) -> ChunkedArrayAny:
-    return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype)
+    return _chunked_array(array(data) if isinstance(data, pa.Scalar) else data, dtype)
 
 
 def concat_vertical_chunked(
```
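
A standalone sketch of the new `array` dispatch (reimplemented here under an assumed name so it runs without the repo; the `@overload`s in the diff only add typing on top of this logic):

```python
import pyarrow as pa
from typing import Any

def to_array(data: Any, dtype: pa.DataType | None = None) -> pa.Array:
    # Same dispatch order as the patched `array` helper.
    if isinstance(data, pa.ChunkedArray):
        return data.combine_chunks()        # flatten chunks into one Array
    if isinstance(data, pa.Array):
        return data                         # already an Array; `dtype` is ignored
    if isinstance(data, pa.Scalar):
        return pa.array([data], data.type)  # wrap a scalar as a 1-element Array
    return pa.array(data, dtype)            # plain Python values

assert to_array(pa.scalar(1)).to_pylist() == [1]
assert to_array(pa.chunked_array([[1], [2]])).to_pylist() == [1, 2]
assert to_array(["a", "b"]).type == pa.string()
```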
