Merged
Commits
64 commits
9542273
refactor: Move compatibility flags
dangotbanned Dec 24, 2025
f4b5778
feat(DRAFT): Start adding `ArrowDataFrame.pivot`
dangotbanned Dec 24, 2025
b96f9b2
feat(DRAFT): Handle `unnest`-ing the struct column
dangotbanned Dec 24, 2025
42e76db
feat(DRAFT): Add most of `DataFrame.pivot`
dangotbanned Dec 24, 2025
366e951
add the bits that work so far
dangotbanned Dec 24, 2025
efea337
test: start adding `pivot_test`
dangotbanned Dec 24, 2025
0dde622
test: Add example from docstring
dangotbanned Dec 24, 2025
2c69502
fix: Support multiple index columns
dangotbanned Dec 25, 2025
4079b74
feat: Support `ArrowDataFrame.pivot(values: list[str])`
dangotbanned Dec 25, 2025
3bd4e50
to-done
dangotbanned Dec 25, 2025
2091c91
test: Add most of the remaining tests
dangotbanned Dec 25, 2025
86469e4
test: Add all remaining tests
dangotbanned Dec 25, 2025
c1f86e4
test: Remove aggregation from `sort_columns` tests
dangotbanned Dec 25, 2025
8761c7c
feat: Support `DataFrame.pivot(sort_columns=True)`
dangotbanned Dec 25, 2025
aac49f8
test: Use fixtures instead of globals
dangotbanned Dec 25, 2025
cac30bc
test: Split up `test_pivot_on_multiple_names_out`
dangotbanned Dec 25, 2025
e298c06
test: Use more distinct names/values
dangotbanned Dec 25, 2025
1981613
test: Make `.columns`-only tests more visible
dangotbanned Dec 25, 2025
2e54da3
test: 🧹🧹🧹
dangotbanned Dec 25, 2025
8194b49
feat(DRAFT): Support `pivot(on: list[str], values: str)`
dangotbanned Dec 26, 2025
7d20578
oops, didn't mean to commit that
dangotbanned Dec 26, 2025
9795db8
"fix" `concat_str` typing
dangotbanned Dec 26, 2025
46d6a1d
fix: Short-circuit correctly in `sort` fastpath
dangotbanned Dec 26, 2025
7c77975
cov
dangotbanned Dec 26, 2025
30a118b
feat: Support `DataFrame.pivot(on: list[str], values: list[str])`
dangotbanned Dec 26, 2025
76fbe74
docs: Explain the funky stuff
dangotbanned Dec 26, 2025
4941455
refactor: Move cast->collect into `options.pivot_wider`
dangotbanned Dec 26, 2025
8e59131
todos
dangotbanned Dec 26, 2025
25c0b40
feat: Support `ArrowDataFrame.pivot_agg(on: str)`
dangotbanned Dec 26, 2025
b3d8c73
refactor: Add a dedicated `acero` wrapper
dangotbanned Dec 26, 2025
90c2132
feat: Support `ArrowDataFrame.pivot_agg(on: list[str])`
dangotbanned Dec 27, 2025
2fc1f05
refactor: Align `pivot_agg`, `pivot_on_multiple` a bit
dangotbanned Dec 27, 2025
8f998d5
refactor: Always pass down `on_columns` as a dataframe
dangotbanned Dec 27, 2025
f922fbf
refactor: Align unnest/renaming
dangotbanned Dec 27, 2025
ca8dea9
refactor: Temp move stuff to a common function
dangotbanned Dec 27, 2025
4078e8f
refactor: factor-in `pivot_on_single`
dangotbanned Dec 27, 2025
bc1396a
refactor: factor-in `pivot_on_multiple`
dangotbanned Dec 27, 2025
3b9b454
test: Split out `assert_names_match_polars`
dangotbanned Dec 27, 2025
62d3190
test: Port tests, discover new edge case
dangotbanned Dec 27, 2025
27e6600
refactor: Replace `implode -> index -> explode` w/ `index -> join`
dangotbanned Dec 27, 2025
8aa70e7
refactor: Fiddle some more
dangotbanned Dec 27, 2025
3797d8c
be stricter than pyarrow
dangotbanned Dec 27, 2025
10f3be3
test: xfail for older `pyarrow`
dangotbanned Dec 27, 2025
d97631b
ugh import
dangotbanned Dec 27, 2025
c74644f
test: Is `(15, 0, 1)` not less than `(20,)`?
dangotbanned Dec 27, 2025
993c98b
ah, they were conflicting marks!
dangotbanned Dec 27, 2025
65f3f37
Apply suggestions from code review
dangotbanned Dec 28, 2025
8cec922
refactor: Fully merge `pivot`, `pivot_agg` into 1 method
dangotbanned Dec 28, 2025
dabb266
perf: Huge simplify unnest/rename
dangotbanned Dec 28, 2025
6dec38b
doc, tidy, etc
dangotbanned Dec 28, 2025
d6a2c06
refactor: Just move everything to its own module
dangotbanned Dec 28, 2025
c4077d2
tweak renaming
dangotbanned Dec 28, 2025
e24c44a
refactor: mnove options too
dangotbanned Dec 28, 2025
211581b
test: Steal test from polars instead
dangotbanned Dec 28, 2025
58c985b
test: oof, what do we have here then
dangotbanned Dec 28, 2025
6f0a9c1
start adding `on_columns: Self`
dangotbanned Dec 28, 2025
90c238d
test: Cover errors in `on_columns: Series | DataFrame`
dangotbanned Dec 29, 2025
ad93dbb
test: Mark as a `pyarrow` bug
dangotbanned Dec 29, 2025
fc2e124
chore: Remove note
dangotbanned Dec 29, 2025
8b15f12
chore: document and cover wider type support
dangotbanned Dec 29, 2025
852c541
perf: Generate column names in a single `concat_str`
dangotbanned Dec 29, 2025
3d046a2
refactor: Revert `concat_str` changes, do weird bits inline
dangotbanned Dec 29, 2025
6392caa
docs: Link to `pyarrow` issue
dangotbanned Dec 29, 2025
859ae00
Merge branch 'oh-nodes' into expr-ir/pivot
dangotbanned Dec 29, 2025
34 changes: 32 additions & 2 deletions narwhals/_plan/arrow/acero.py
@@ -25,6 +25,7 @@
import pyarrow.compute as pc # ignore-banned-import
from pyarrow.acero import Declaration as Decl

from narwhals._plan.arrow import options as pa_options
from narwhals._plan.common import ensure_list_str, temp
from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq
from narwhals._utils import check_column_names_are_unique
@@ -47,7 +48,12 @@
Aggregation as _Aggregation,
)
from narwhals._plan.arrow.group_by import AggSpec
from narwhals._plan.arrow.typing import ArrowAny, JoinTypeSubset, ScalarAny
from narwhals._plan.arrow.typing import (
ArrowAny,
ChunkedOrArrayAny,
JoinTypeSubset,
ScalarAny,
)
from narwhals._plan.typing import OneOrIterable, Seq
from narwhals.typing import NonNestedLiteral

@@ -62,7 +68,15 @@

Target: TypeAlias = OneOrSeq[Field]
Aggregation: TypeAlias = Union[
"_Aggregation", Literal["hash_kurtosis", "hash_skew", "kurtosis", "skew"]
"_Aggregation",
Literal[
"hash_kurtosis",
"hash_skew",
"hash_pivot_wider",
"kurtosis",
"skew",
"pivot_wider",
],
]
AggregateOptions: TypeAlias = "_AggregateOptions"
Opts: TypeAlias = "AggregateOptions | None"
@@ -320,6 +334,22 @@ def group_by_table(
return collect(table_source(native), group_by(keys, aggs), use_threads=use_threads)


def pivot_table(
native: pa.Table,
on: str,
on_columns: ChunkedOrArrayAny | Sequence[Any],
/,
index: Sequence[str],
values: Sequence[str],
) -> pa.Table:
"""Partial `pivot` implementation."""
from narwhals._plan.arrow.group_by import AggSpec

options = pa_options.pivot_wider(on_columns)
specs = (AggSpec((on, name), "hash_pivot_wider", options, name) for name in values)
return group_by_table(native, index, specs)
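The `hash_pivot_wider` aggregation that `pivot_table` drives is easiest to see with a toy model. The sketch below is a hypothetical pure-Python stand-in (`pivot_wider_sketch` is not the acero code, and it ignores grouping by multiple index columns) showing the shape transformation: one output row per index key, one column per distinct `on` value.

```python
from collections import defaultdict


def pivot_wider_sketch(rows, index, on, value):
    """Toy wide pivot: one output row per `index` key, one column per `on` value."""
    on_values = []  # output column order = first-seen order of `on` values
    grouped = defaultdict(dict)
    for row in rows:
        if row[on] not in on_values:
            on_values.append(row[on])
        grouped[row[index]][row[on]] = row[value]
    return [
        {index: key, **{name: cells.get(name) for name in on_values}}
        for key, cells in grouped.items()
    ]


data = [
    {"ix": 1, "col": "a", "foo": 1},
    {"ix": 2, "col": "a", "foo": 3},
    {"ix": 1, "col": "b", "foo": 7},
    {"ix": 2, "col": "b", "foo": 9},
]
pivot_wider_sketch(data, "ix", "col", "foo")
# [{'ix': 1, 'a': 1, 'b': 7}, {'ix': 2, 'a': 3, 'b': 9}]
```

Missing (`index`, `on`) combinations come out as `None`, matching the null cells a real wide pivot produces.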


def filter_table(native: pa.Table, *predicates: Expr, **constraints: Any) -> pa.Table:
"""Selects rows where all expressions evaluate to True.

5 changes: 3 additions & 2 deletions narwhals/_plan/arrow/common.py
@@ -4,7 +4,8 @@

from typing import TYPE_CHECKING, Any, ClassVar, Generic

from narwhals._plan.arrow.functions import BACKEND_VERSION, random_indices
from narwhals._plan.arrow import compat
from narwhals._plan.arrow.functions import random_indices
from narwhals._typing_compat import TypeVar
from narwhals._utils import Implementation, Version, _StoresNative

@@ -47,7 +48,7 @@ def __len__(self) -> int:
msg = f"{type(self).__name__}.__len__"
raise NotImplementedError(msg)

if BACKEND_VERSION >= (18,):
if compat.TAKE_ACCEPTS_TUPLE:

def _gather(self, indices: Indices) -> NativeT:
return self.native.take(indices)
41 changes: 41 additions & 0 deletions narwhals/_plan/arrow/compat.py
@@ -0,0 +1,41 @@
"""Flags for features not available in all supported `pyarrow` versions."""

from __future__ import annotations

from typing import Final

from narwhals._utils import Implementation

BACKEND_VERSION = Implementation.PYARROW._backend_version()
"""Static backend version for `pyarrow`."""

RANK_ACCEPTS_CHUNKED: Final = BACKEND_VERSION >= (14,)

HAS_FROM_TO_STRUCT_ARRAY: Final = BACKEND_VERSION >= (15,)
"""`pyarrow.Table.{from,to}_struct_array` added in https://github.com/apache/arrow/pull/38520"""


TABLE_RENAME_ACCEPTS_DICT: Final = BACKEND_VERSION >= (17,)

TAKE_ACCEPTS_TUPLE: Final = BACKEND_VERSION >= (18,)

HAS_STRUCT_TYPE_FIELDS: Final = BACKEND_VERSION >= (18,)
"""`pyarrow.StructType.fields` added in https://github.com/apache/arrow/pull/43481"""

HAS_SCATTER: Final = BACKEND_VERSION >= (20,)
"""`pyarrow.compute.scatter` added in https://github.com/apache/arrow/pull/44394"""

HAS_KURTOSIS_SKEW: Final = BACKEND_VERSION >= (20,)
"""`pyarrow.compute.{kurtosis,skew}` added in https://github.com/apache/arrow/pull/45677"""

HAS_PIVOT_WIDER: Final = BACKEND_VERSION >= (20,)
"""`pyarrow.compute.pivot_wider` added in https://github.com/apache/arrow/pull/45562"""

HAS_ARANGE: Final = BACKEND_VERSION >= (21,)
"""`pyarrow.arange` added in https://github.com/apache/arrow/pull/46778"""

TO_STRUCT_ARRAY_ACCEPTS_EMPTY: Final = BACKEND_VERSION >= (21,)
"""`pyarrow.Table.to_struct_array` fixed in https://github.com/apache/arrow/pull/46357"""

HAS_ZFILL: Final = BACKEND_VERSION >= (21,)
"""`pyarrow.compute.utf8_zero_fill` added in https://github.com/apache/arrow/pull/46815"""
169 changes: 162 additions & 7 deletions narwhals/_plan/arrow/dataframe.py
@@ -2,14 +2,20 @@

import operator
from functools import reduce
from itertools import chain
from itertools import chain, product
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._plan.arrow import acero, functions as fn
from narwhals._plan.arrow import (
acero,
compat,
functions as fn,
group_by,
options as pa_options,
)
from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by
@@ -29,14 +35,20 @@
import polars as pl
from typing_extensions import Self, TypeAlias

from narwhals._plan.arrow.typing import ChunkedArrayAny, ChunkedOrArrayAny, Predicate
from narwhals._plan.arrow.typing import (
ChunkedArrayAny,
ChunkedOrArrayAny,
ChunkedStruct,
Predicate,
StructArray,
)
from narwhals._plan.compliant.group_by import GroupByResolver
from narwhals._plan.expressions import ExprIR, NamedIR
from narwhals._plan.options import ExplodeOptions, SortMultipleOptions
from narwhals._plan.typing import NonCrossJoinStrategy
from narwhals._typing import _LazyAllowedImpl
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema, UniqueKeepStrategy
from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy

Incomplete: TypeAlias = Any

@@ -218,9 +230,9 @@ def write_parquet(self, target: str | BytesIO, /) -> None:

def to_struct(self, name: str = "") -> Series:
native = self.native
if fn.TO_STRUCT_ARRAY_ACCEPTS_EMPTY:
if compat.TO_STRUCT_ARRAY_ACCEPTS_EMPTY:
struct = native.to_struct_array()
elif fn.HAS_FROM_TO_STRUCT_ARRAY:
elif compat.HAS_FROM_TO_STRUCT_ARRAY:
if len(native):
struct = native.to_struct_array()
else:
@@ -252,7 +264,7 @@ def explode(self, subset: Sequence[str], options: ExplodeOptions) -> Self:

def rename(self, mapping: Mapping[str, str]) -> Self:
names: dict[str, str] | list[str]
if fn.BACKEND_VERSION >= (17,):
if compat.TABLE_RENAME_ACCEPTS_DICT:
names = cast("dict[str, str]", mapping)
else: # pragma: no cover
names = [mapping.get(c, c) for c in self.columns]
@@ -325,6 +337,79 @@ def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[S
partitions = partition_by(self.native, by, include_key=include_key)
return [from_native(df) for df in partitions]

def pivot(
self,
on: Sequence[str],
on_columns: Self,
*,
index: Sequence[str],
values: Sequence[str],
separator: str = "_",
) -> Self:
native = self.native
on_columns_ = on_columns.native
if len(on) == 1:
pivot = acero.pivot_table(native, on[0], on_columns_.column(0), index, values)
else:
temp_name = temp.column_name(native.column_names)
on_columns_w_idx = on_columns.with_row_index(temp_name)
on_columns_encoded = on_columns_w_idx.get_column(temp_name).native
single_on = self.join_inner(on_columns_w_idx, list(on)).drop(on).native
# NOTE: Almost identical to `pivot_agg` now!
pivot = acero.pivot_table(
single_on, temp_name, on_columns_encoded, index, values
)
return self._finish_pivot(pivot, on_columns_, index, values, separator)

# TODO @dangotbanned: Align each of the impls more, then de-duplicate
def pivot_agg(
self,
on: Sequence[str],
on_columns: Self,
*,
index: Sequence[str],
values: Sequence[str],
aggregate_function: PivotAgg,
separator: str = "_",
) -> Self:
native = self.native
tp_agg = group_by.SUPPORTED_PIVOT_AGG[aggregate_function]
agg_func = group_by.SUPPORTED_AGG[tp_agg]
option = pa_options.AGG.get(tp_agg)
specs = (group_by.AggSpec(value, agg_func, option) for value in values)

if len(on) == 1:
pre_agg = acero.group_by_table(native, [*index, *on], specs)
return self._with_native(pre_agg).pivot(
on, on_columns, index=index, values=values, separator=separator
)
on_columns_ = on_columns.native
temp_name = temp.column_name(native.column_names)
on_columns_w_idx = on_columns.with_row_index(temp_name)
on_columns_encoded = on_columns_w_idx.get_column(temp_name).native
single_on = self.join_inner(on_columns_w_idx, list(on)).drop(on).native

pre_agg = acero.group_by_table(single_on, [*index, temp_name], specs)
# this part is the tricky one, since the pivot and the renaming use different reprs for `on_columns`
pivot = acero.pivot_table(pre_agg, temp_name, on_columns_encoded, index, values)
return self._finish_pivot(pivot, on_columns_, index, values, separator)

def _finish_pivot(
self,
pivot: pa.Table,
on_columns: pa.Table,
index: Sequence[str],
values: Sequence[str],
separator: str = "_",
) -> Self:
# Everything here should be moved to `acero.pivot_table` if possible
pivot_columns = pivot.columns
n_index = len(index)
unnested = structs_to_arrays(*pivot_columns[n_index:], flatten=True)
names = (*index, *_on_columns_names(on_columns, values, separator=separator))
result = fn.concat_horizontal((*pivot_columns[:n_index], *unnested), names)
return self._with_native(result)


def with_array(table: pa.Table, name: str, column: ChunkedOrArrayAny) -> pa.Table:
column_names = table.column_names
@@ -343,3 +428,73 @@ def with_arrays(
else:
table = table.append_column(name, column)
return table


def struct_to_arrays(native: ChunkedStruct | StructArray) -> Sequence[ChunkedOrArrayAny]:
"""Unnest the fields of a struct into one array per-struct-field.

Cheaper than `unnest`-ing into a `Table`, and very helpful if the names are going to be replaced.
"""
return cast("ChunkedStruct | pa.StructArray", native).flatten()


@overload
def structs_to_arrays(
*structs: ChunkedStruct | StructArray,
) -> Iterator[Sequence[ChunkedOrArrayAny]]: ...
@overload
def structs_to_arrays(
*structs: ChunkedStruct | StructArray, flatten: Literal[True]
) -> Iterator[ChunkedOrArrayAny]: ...
def structs_to_arrays(
*structs: ChunkedStruct | StructArray, flatten: bool = False
) -> Iterator[Sequence[ChunkedOrArrayAny] | ChunkedOrArrayAny]:
"""Unnest the fields of every struct into one array per-struct-field.

By default, yields the arrays of each struct *as a group*, configurable via `flatten`.

Arguments:
*structs: One or more Struct-typed arrow arrays.
flatten: Yield each array from each struct *without grouping*.
"""
if flatten:
for struct in structs:
yield from struct_to_arrays(struct)
else:
for struct in structs:
yield struct_to_arrays(struct)


def _on_columns_names(
on_columns: pa.Table, values: Sequence[str], *, separator: str = "_"
) -> Iterable[str]:
"""Alignment to polars pivot column naming conventions.

If we started with:

{'on_lower': ['b', 'a', 'b', 'a'], 'on_upper': ['X', 'X', 'Y', 'Y']}

Then this operation will return:

['{"b","X"}', '{"a","X"}', '{"b","Y"}', '{"a","Y"}']
"""
result: Iterable[Any]
if on_columns.num_columns == 1:
on_column = on_columns.column(0)
if len(values) == 1:
result = on_column.to_pylist()
else:
t_left = fn.to_table(fn.array(values))
# NOTE: still don't know why pyarrow outputs the cross join in reverse
t_right = fn.to_table(fn.reverse(on_column))
cross_joined = acero.join_cross_tables(t_left, t_right)
result = fn.concat_str(*cross_joined.columns, separator=separator).to_pylist()
else:
result = fn.concat_str(
'{"', fn.concat_str(*on_columns.columns, separator='","'), '"}'
).to_pylist()
if len(values) != 1:
return (
f"{value}{separator}{name}" for value, name in product(values, result)
)
return cast("list[str]", result)