Merged
Commits
64 commits
9542273
refactor: Move compatibility flags
dangotbanned Dec 24, 2025
f4b5778
feat(DRAFT): Start adding `ArrowDataFrame.pivot`
dangotbanned Dec 24, 2025
b96f9b2
feat(DRAFT): Handle `unnest`-ing the struct column
dangotbanned Dec 24, 2025
42e76db
feat(DRAFT): Add most of `DataFrame.pivot`
dangotbanned Dec 24, 2025
366e951
add the bits that work so far
dangotbanned Dec 24, 2025
efea337
test: start adding `pivot_test`
dangotbanned Dec 24, 2025
0dde622
test: Add example from docstring
dangotbanned Dec 24, 2025
2c69502
fix: Support multiple index columns
dangotbanned Dec 25, 2025
4079b74
feat: Support `ArrowDataFrame.pivot(values: list[str])`
dangotbanned Dec 25, 2025
3bd4e50
to-done
dangotbanned Dec 25, 2025
2091c91
test: Add most of the remaining tests
dangotbanned Dec 25, 2025
86469e4
test: Add all remaining tests
dangotbanned Dec 25, 2025
c1f86e4
test: Remove aggregation from `sort_columns` tests
dangotbanned Dec 25, 2025
8761c7c
feat: Support `DataFrame.pivot(sort_columns=True)`
dangotbanned Dec 25, 2025
aac49f8
test: Use fixtures instead of globals
dangotbanned Dec 25, 2025
cac30bc
test: Split up `test_pivot_on_multiple_names_out`
dangotbanned Dec 25, 2025
e298c06
test: Use more distinct names/values
dangotbanned Dec 25, 2025
1981613
test: Make `.columns`-only tests more visible
dangotbanned Dec 25, 2025
2e54da3
test: 🧹🧹🧹
dangotbanned Dec 25, 2025
8194b49
feat(DRAFT): Support `pivot(on: list[str], values: str)`
dangotbanned Dec 26, 2025
7d20578
oops, didn't mean to commit that
dangotbanned Dec 26, 2025
9795db8
"fix" `concat_str` typing
dangotbanned Dec 26, 2025
46d6a1d
fix: Short-circuit correctly in `sort` fastpath
dangotbanned Dec 26, 2025
7c77975
cov
dangotbanned Dec 26, 2025
30a118b
feat: Support `DataFrame.pivot(on: list[str], values: list[str])`
dangotbanned Dec 26, 2025
76fbe74
docs: Explain the funky stuff
dangotbanned Dec 26, 2025
4941455
refactor: Move cast->collect into `options.pivot_wider`
dangotbanned Dec 26, 2025
8e59131
todos
dangotbanned Dec 26, 2025
25c0b40
feat: Support `ArrowDataFrame.pivot_agg(on: str)`
dangotbanned Dec 26, 2025
b3d8c73
refactor: Add a dedicated `acero` wrapper
dangotbanned Dec 26, 2025
90c2132
feat: Support `ArrowDataFrame.pivot_agg(on: list[str])`
dangotbanned Dec 27, 2025
2fc1f05
refactor: Align `pivot_agg`, `pivot_on_multiple` a bit
dangotbanned Dec 27, 2025
8f998d5
refactor: Always pass down `on_columns` as a dataframe
dangotbanned Dec 27, 2025
f922fbf
refactor: Align unnest/renaming
dangotbanned Dec 27, 2025
ca8dea9
refactor: Temp move stuff to a common function
dangotbanned Dec 27, 2025
4078e8f
refactor: factor-in `pivot_on_single`
dangotbanned Dec 27, 2025
bc1396a
refactor: factor-in `pivot_on_multiple`
dangotbanned Dec 27, 2025
3b9b454
test: Split out `assert_names_match_polars`
dangotbanned Dec 27, 2025
62d3190
test: Port tests, discover new edge case
dangotbanned Dec 27, 2025
27e6600
refactor: Replace `implode -> index -> explode` w/ `index -> join`
dangotbanned Dec 27, 2025
8aa70e7
refactor: Fiddle some more
dangotbanned Dec 27, 2025
3797d8c
be stricter than pyarrow
dangotbanned Dec 27, 2025
10f3be3
test: xfail for older `pyarrow`
dangotbanned Dec 27, 2025
d97631b
ugh import
dangotbanned Dec 27, 2025
c74644f
test: Is `(15, 0, 1)` not less than `(20,)`?
dangotbanned Dec 27, 2025
993c98b
ah, they were conflicting marks!
dangotbanned Dec 27, 2025
65f3f37
Apply suggestions from code review
dangotbanned Dec 28, 2025
8cec922
refactor: Fully merge `pivot`, `pivot_agg` into 1 method
dangotbanned Dec 28, 2025
dabb266
perf: Huge simplify unnest/rename
dangotbanned Dec 28, 2025
6dec38b
doc, tidy, etc
dangotbanned Dec 28, 2025
d6a2c06
refactor: Just move everything to its own module
dangotbanned Dec 28, 2025
c4077d2
tweak renaming
dangotbanned Dec 28, 2025
e24c44a
refactor: mnove options too
dangotbanned Dec 28, 2025
211581b
test: Steal test from polars instead
dangotbanned Dec 28, 2025
58c985b
test: oof, what do we have here then
dangotbanned Dec 28, 2025
6f0a9c1
start adding `on_columns: Self`
dangotbanned Dec 28, 2025
90c238d
test: Cover errors in `on_columns: Series | DataFrame`
dangotbanned Dec 29, 2025
ad93dbb
test: Mark as a `pyarrow` bug
dangotbanned Dec 29, 2025
fc2e124
chore: Remove note
dangotbanned Dec 29, 2025
8b15f12
chore: document and cover wider type support
dangotbanned Dec 29, 2025
852c541
perf: Generate column names in a single `concat_str`
dangotbanned Dec 29, 2025
3d046a2
refactor: Revert `concat_str` changes, do weird bits inline
dangotbanned Dec 29, 2025
6392caa
docs: Link to `pyarrow` issue
dangotbanned Dec 29, 2025
859ae00
Merge branch 'oh-nodes' into expr-ir/pivot
dangotbanned Dec 29, 2025
10 changes: 9 additions & 1 deletion narwhals/_plan/arrow/acero.py
@@ -62,7 +62,15 @@

Target: TypeAlias = OneOrSeq[Field]
Aggregation: TypeAlias = Union[
"_Aggregation", Literal["hash_kurtosis", "hash_skew", "kurtosis", "skew"]
"_Aggregation",
Literal[
"hash_kurtosis",
"hash_skew",
"hash_pivot_wider",
"kurtosis",
"skew",
"pivot_wider",
],
]
AggregateOptions: TypeAlias = "_AggregateOptions"
Opts: TypeAlias = "AggregateOptions | None"
2 changes: 1 addition & 1 deletion narwhals/_plan/arrow/compat.py
@@ -20,7 +20,7 @@
TAKE_ACCEPTS_TUPLE: Final = BACKEND_VERSION >= (18,)

HAS_STRUCT_TYPE_FIELDS: Final = BACKEND_VERSION >= (18,)
"""`pyarrow.StructType.fields` added in https://github.com/apache/arrow/pull/43481"""
"""`pyarrow.StructType.{fields,names}` added in https://github.com/apache/arrow/pull/43481"""

HAS_SCATTER: Final = BACKEND_VERSION >= (20,)
"""`pyarrow.compute.scatter` added in https://github.com/apache/arrow/pull/44394"""
24 changes: 23 additions & 1 deletion narwhals/_plan/arrow/dataframe.py
@@ -13,6 +13,7 @@
from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by
from narwhals._plan.arrow.pivot import pivot_table
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.common import temp
from narwhals._plan.compliant.dataframe import EagerDataFrame
@@ -36,7 +37,7 @@
from narwhals._plan.typing import NonCrossJoinStrategy
from narwhals._typing import _LazyAllowedImpl
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema, UniqueKeepStrategy
from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy

Incomplete: TypeAlias = Any

@@ -325,6 +326,27 @@ def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[S
partitions = partition_by(self.native, by, include_key=include_key)
return [from_native(df) for df in partitions]

def pivot(
self,
on: Sequence[str],
on_columns: Self,
*,
index: Sequence[str],
values: Sequence[str],
aggregate_function: PivotAgg | None = None,
separator: str = "_",
) -> Self:
result = pivot_table(
self.native,
list(on),
on_columns.native,
index,
values,
aggregate_function,
separator,
)
return self._with_native(result)


def with_array(table: pa.Table, name: str, column: ChunkedOrArrayAny) -> pa.Table:
column_names = table.column_names
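The method above is the compliant-layer entry point; the semantics it targets are those of `polars.DataFrame.pivot` (the tests in this PR are ported from polars). As a minimal orientation sketch — using polars directly and made-up column names, not the internal API:

import polars as pl

df = pl.DataFrame(
    {
        "ix": [1, 1, 2, 2],  # index: one output row per unique value
        "col": ["a", "b", "a", "b"],  # on: one output column per unique value
        "val": [10, 20, 30, 40],  # values: fill the new wide columns
    }
)
wide = df.pivot("col", index="ix", values="val")
# shape (2, 3), columns ["ix", "a", "b"]:
#   ix=1 -> a=10, b=20
#   ix=2 -> a=30, b=40

`ArrowDataFrame.pivot` additionally takes `on_columns`, a frame of the distinct `on` values (presumably computed upstream by the caller), which `pivot_table` below uses both for joining and for naming the new columns.
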
7 changes: 7 additions & 0 deletions narwhals/_plan/arrow/functions.py
@@ -334,6 +334,12 @@ def struct_schema(native: Arrow[pa.StructScalar] | pa.StructType) -> pa.Schema:
return pa.schema(fields)


def struct_field_names(native: Arrow[pa.StructScalar] | pa.StructType) -> list[str]:
"""Get the names of all struct fields."""
tp = native.type if _is_arrow(native) else native
return tp.names if compat.HAS_STRUCT_TYPE_FIELDS else [f.name for f in tp]


@t.overload
def struct_field(native: ChunkedStruct, field: Field, /) -> ChunkedArrayAny: ...
@t.overload
@@ -1574,6 +1580,7 @@ def concat_str(
def concat_str(
*arrays: ArrowAny, separator: str = "", ignore_nulls: bool = False
) -> Arrow[StringScalar]:
"""Horizontally arrow data into a single string column."""
dtype = string_type(obj.type for obj in arrays)
it = (obj.cast(dtype) for obj in arrays)
concat: Incomplete = pc.binary_join_element_wise
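`struct_field_names` wraps the version gate documented in `compat.py`: `StructType.names` only exists on newer `pyarrow` (>= 18 per `HAS_STRUCT_TYPE_FIELDS`), while iterating the type's fields works everywhere. A standalone sketch of the two paths, in plain `pyarrow` with no narwhals imports:

import pyarrow as pa

tp = pa.struct([("a", pa.int64()), ("b", pa.string())])

# Fallback path: iterating a StructType yields its fields on any supported version.
assert [f.name for f in tp] == ["a", "b"]

# Fast path, only on pyarrow >= 18 (HAS_STRUCT_TYPE_FIELDS):
# assert tp.names == ["a", "b"]
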
155 changes: 155 additions & 0 deletions narwhals/_plan/arrow/pivot.py
@@ -0,0 +1,155 @@
from __future__ import annotations

import re
from itertools import chain
from typing import TYPE_CHECKING, Any, cast

import pyarrow.compute as pc

from narwhals._plan.arrow import (
acero,
compat,
functions as fn,
group_by,
options as pa_options,
)
from narwhals._plan.arrow.group_by import AggSpec
from narwhals._plan.common import temp
from narwhals._plan.expressions import aggregation as agg

if TYPE_CHECKING:
from collections.abc import Callable, Mapping, Sequence

import pyarrow as pa

from narwhals._plan.arrow.typing import ChunkedArray, StringScalar
from narwhals.typing import PivotAgg


SUPPORTED_PIVOT_AGG: Mapping[PivotAgg, type[agg.AggExpr]] = {
"min": agg.Min,
"max": agg.Max,
"first": agg.First,
"last": agg.Last,
"sum": agg.Sum,
"mean": agg.Mean,
"median": agg.Median,
"len": agg.Len,
}


def pivot_table(
native: pa.Table,
on: list[str],
on_columns: pa.Table,
/,
index: Sequence[str],
values: Sequence[str],
aggregate_function: PivotAgg | None,
separator: str,
) -> pa.Table:
"""Create a spreadsheet-style `pivot` table.

Supports multiple-`on` and aggregations.
"""
if len(on) == 1:
on_column = on_columns.column(0)
on_one = on[0]
target = native
else:
on_column = _format_on_columns_titles(on_columns)
on_one = temp.column_name(native.column_names)
target = acero.join_inner_tables(
native, on_columns.append_column(on_one, on_column), on
).drop(on)
if aggregate_function:
target = _aggregate(target, on_one, index, values, aggregate_function)
return _pivot(target, on_one, on_column.to_pylist(), index, values, separator)


def _format_on_columns_titles(on_columns: pa.Table, /) -> ChunkedArray[StringScalar]:
dtype = fn.string_type(on_columns.schema.types)
on_columns = fn.cast_table(on_columns, dtype)
parts = '{"', '"}', "", '","'
LB, RB, EMPTY, SEP = (fn.lit(s, dtype) for s in parts) # noqa: N806

# NOTE: Variation of https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse
seps = (SEP,) * on_columns.num_columns
interspersed = chain.from_iterable(zip(seps, on_columns.itercolumns()))
# skip the first separator, we just need the zip-terminating iterable to be the columns
next(interspersed)
func = "binary_join_element_wise"
args = [LB, *interspersed, RB, EMPTY]
opts = pa_options.join(ignore_nulls=False)
result: ChunkedArray[StringScalar] = pc.call_function(func, args, opts)
return result


def _replace_flatten_names(
column_names: list[str],
/,
on_columns_names: Sequence[str],
values: Sequence[str],
separator: str,
) -> list[str]:
"""Replace the separator used in unnested struct columns.

[`pa.Table.flatten`] *unconditionally* uses the separator `"."`, so we *likely* need to fix that here.

[`pa.Table.flatten`]: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.flatten
"""
if separator == ".":
return column_names
p_on_columns = "|".join(re.escape(name) for name in on_columns_names)
p_values = "|".join(re.escape(name) for name in values)
pattern = re.compile(rf"^(?P<on_column>{p_on_columns})\.(?P<value>{p_values})\Z")
repl = rf"\g<on_column>{separator}\g<value>"
return [pattern.sub(repl, s) for s in column_names]


def _pivot(
native: pa.Table,
on: str,
on_columns: Sequence[Any],
/,
index: Sequence[str],
values: Sequence[str],
separator: str,
) -> pa.Table:
"""Perform a single-`on`, non-aggregating `pivot`."""
options = _pivot_wider_options(on_columns)
specs = (AggSpec((on, name), "hash_pivot_wider", options, name) for name in values)
pivot = acero.group_by_table(native, index, specs)
flat = pivot.flatten()
if len(values) == 1:
names = [*index, *fn.struct_field_names(pivot.column(values[0]))]
else:
names = _replace_flatten_names(flat.column_names, values, on_columns, separator)
return flat.rename_columns(names)


def _aggregate(
native: pa.Table,
on: str,
/,
index: Sequence[str],
values: Sequence[str],
aggregate_function: PivotAgg,
) -> pa.Table:
tp_agg = SUPPORTED_PIVOT_AGG[aggregate_function]
agg_func = group_by.SUPPORTED_AGG[tp_agg]
option = pa_options.AGG.get(tp_agg)
specs = (AggSpec(value, agg_func, option) for value in values)
return acero.group_by_table(native, [*index, on], specs)


def _pivot_wider_options(on_columns: Sequence[Any]) -> pc.FunctionOptions:
"""Tries to wrap [`pc.PivotWiderOptions`], and raises if we're on an old `pyarrow`.

[`pc.PivotWiderOptions`]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.PivotWiderOptions.html
"""
if compat.HAS_PIVOT_WIDER and (tp := getattr(pc, "PivotWiderOptions")): # noqa: B009
tp_options = cast("Callable[..., pc.FunctionOptions]", tp)
return tp_options(on_columns, unexpected_key_behavior="raise")
msg = f"`pivot` requires `pyarrow>=20`, got {compat.BACKEND_VERSION!r}"
raise NotImplementedError(msg)
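The least obvious part of the new module is the column naming for multi-`on` pivots: `_format_on_columns_titles` builds polars-style `{"…","…"}` headers with a single `binary_join_element_wise` call, interleaving a `","` separator between the string-casted `on` columns. A standalone sketch of that composition, with a hypothetical `on_columns` table of key combinations:

import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical combinations of two `on` columns, already string-typed.
on_columns = pa.table({"c1": ["x", "x", "y"], "c2": ["1", "2", "1"]})

# Same argument shape as `_format_on_columns_titles` builds:
#   [LB, c1, SEP, c2, RB, EMPTY]  ->  '{"' + c1 + '","' + c2 + '"}'
# The trailing "" is binary_join_element_wise's element-wise separator.
titles = pc.binary_join_element_wise(
    '{"', on_columns["c1"], '","', on_columns["c2"], '"}', ""
)
assert titles.to_pylist() == ['{"x","1"}', '{"x","2"}', '{"y","1"}']

With multiple `values` columns, `pa.Table.flatten()` then produces names like `val.{"x","1"}`, which `_replace_flatten_names` rewrites to use the requested separator (`val_{"x","1"}` by default).
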
4 changes: 2 additions & 2 deletions narwhals/_plan/common.py
@@ -109,14 +109,14 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[Any]], /) -> Iterator[Any
yield element


def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: # pragma: no cover
def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError:
msg = f"Expected one or an iterable of strings, but got: {qualified_type_name(obj)!r}\n{obj!r}"
return TypeError(msg)


def ensure_seq_str(obj: OneOrIterable[str], /) -> Seq[str]:
if not isinstance(obj, Iterable):
raise _not_one_or_iterable_str_error(obj) # pragma: no cover
raise _not_one_or_iterable_str_error(obj)
return (obj,) if isinstance(obj, str) else tuple(obj)


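The removed `pragma: no cover` comments indicate the error path in `ensure_seq_str` is now exercised, presumably via the new `pivot` argument validation. Its behaviour, restated as a self-contained sketch:

from collections.abc import Iterable

def ensure_seq_str(obj):
    # Wrap a lone string, materialise any other iterable, reject everything else.
    if not isinstance(obj, Iterable):
        msg = f"Expected one or an iterable of strings, but got: {type(obj).__name__!r}\n{obj!r}"
        raise TypeError(msg)
    return (obj,) if isinstance(obj, str) else tuple(obj)

assert ensure_seq_str("a") == ("a",)
assert ensure_seq_str(["a", "b"]) == ("a", "b")
# ensure_seq_str(1)  ->  TypeError (the branch that previously carried the pragma)
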
16 changes: 15 additions & 1 deletion narwhals/_plan/compliant/dataframe.py
@@ -45,7 +45,7 @@
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals._utils import Implementation, Version
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema, UniqueKeepStrategy
from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy

Incomplete: TypeAlias = Any

@@ -152,6 +152,10 @@ def __narwhals_dataframe__(self) -> Self:
def lazy(self, backend: _LazyAllowedImpl | None, **kwds: Any) -> LazyFrameAny: ...
@property
def shape(self) -> tuple[int, int]: ...
@property
def width(self) -> int:
return self.shape[-1]

def __len__(self) -> int: ...
@property
def _group_by(self) -> type[DataFrameGroupBy[Self]]: ...
@@ -222,6 +226,16 @@ def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ...
def partition_by(
self, by: Sequence[str], *, include_key: bool = True
) -> list[Self]: ...
def pivot(
self,
on: Sequence[str],
on_columns: Self,
*,
index: Sequence[str],
values: Sequence[str],
aggregate_function: PivotAgg | None = None,
separator: str = "_",
) -> Self: ...
def row(self, index: int) -> tuple[Any, ...]: ...
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ...