Draft
32 commits
1dfc6eb
feat(expr-ir): Support `over(*partition_by)`
dangotbanned Oct 18, 2025
56c6049
test: Start porting `over_test.py`
dangotbanned Oct 18, 2025
dd62ae8
test: hmm any multi-selection
dangotbanned Oct 18, 2025
ad47ac3
test: Add some failing cases
dangotbanned Oct 18, 2025
c4d4030
fix(expr-ir): Ensure nested nodes expand correctly
dangotbanned Oct 19, 2025
bf7fdc2
test: Update test that caught a different bug 😅
dangotbanned Oct 19, 2025
d21ae34
test: Another case
dangotbanned Oct 19, 2025
88dfdbc
feat(expr-ir): Add dedicated selectors for problem children
dangotbanned Oct 19, 2025
ecde261
fix(expr-ir): Ensure `by_dtype` handles bare parametric types
dangotbanned Oct 19, 2025
a2d5b2e
feat: Support `diff`, `shift`
dangotbanned Oct 19, 2025
5fc4075
feat(expr-ir): Support anonymous reductions in `over`
dangotbanned Oct 19, 2025
2a4acd7
feat(expr-ir): Partial `cum_sum` support
dangotbanned Oct 19, 2025
d97d047
feat(expr-ir): Rinse/repeat other `cum_*`
dangotbanned Oct 19, 2025
b49a4f4
Merge remote-tracking branch 'upstream/oh-nodes' into expr-ir/over-an…
dangotbanned Oct 20, 2025
34cac04
simply document existing issue
dangotbanned Oct 20, 2025
79056d7
diff -> kernel, add some typing
dangotbanned Oct 20, 2025
1f830bc
test: remove unused xfail
dangotbanned Oct 20, 2025
aedc330
shift -> kernel, add fancy test
dangotbanned Oct 20, 2025
040d377
test: Allow the exception difference *for now*
dangotbanned Oct 20, 2025
ab7330a
feat(expr-ir): Add a concrete impl for `cs.by_name`
dangotbanned Oct 21, 2025
15c87ea
feat(expr-ir): Add `meta.as_selector`, `parse_into_selector_ir`
dangotbanned Oct 21, 2025
a00dbb7
more `partition_by` prep
dangotbanned Oct 21, 2025
e0d1a00
feat(expr-ir): Implement `ArrowDataFrame.partition_by`
dangotbanned Oct 21, 2025
2bffdaa
test: Add `test_partition_by_multiple`
dangotbanned Oct 21, 2025
f17781a
test: Include `None` in partitions
dangotbanned Oct 22, 2025
ac779dd
perf: Add an optimized path for single-column `partition_by`
dangotbanned Oct 22, 2025
c4d494a
refactor: Re-use `partition_by` in `ArrowGroupBy.__iter__`
dangotbanned Oct 22, 2025
e55aeb0
refactor: Move `partition_by` impl to `group_by.py`
dangotbanned Oct 22, 2025
ae09fc1
refactor: Rename `concat_str` -> `_composite_key` and lightly doc
dangotbanned Oct 22, 2025
9810b73
test: Add some more targets for `polars`-parity
dangotbanned Oct 22, 2025
6d219f4
fix: raise on empty `by`
dangotbanned Oct 22, 2025
41d8cc2
feat(DRAFT): Add acero `union` wrapper
dangotbanned Oct 22, 2025
122 changes: 93 additions & 29 deletions narwhals/_plan/_expansion.py
@@ -40,10 +40,10 @@

from collections import deque
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from narwhals._plan import common, meta
from narwhals._plan._guards import is_horizontal_reduction
from narwhals._plan import common, expressions as ir, meta
from narwhals._plan._guards import is_horizontal_reduction, is_window_expr
from narwhals._plan._immutable import Immutable
from narwhals._plan.exceptions import (
column_index_error,
@@ -72,6 +72,7 @@
IntoFrozenSchema,
freeze_schema,
)
from narwhals._utils import check_column_names_are_unique
from narwhals.dtypes import DType
from narwhals.exceptions import ComputeError, InvalidOperationError

@@ -156,7 +157,7 @@ def with_multiple_columns(self) -> ExpansionFlags:
def prepare_projection(
exprs: Sequence[ExprIR], /, keys: GroupByKeys = (), *, schema: IntoFrozenSchema
) -> tuple[Seq[NamedIR], FrozenSchema]:
"""Expand IRs into named column selections.
"""Expand IRs into named column projections.
**Primary entry-point**, for `select`, `with_columns`,
and any other context that requires resolving expression names.
@@ -173,13 +174,33 @@
return named_irs, frozen_schema


def expand_selector_irs_names(
selectors: Sequence[SelectorIR],
/,
keys: GroupByKeys = (),
*,
schema: IntoFrozenSchema,
) -> OutputNames:
"""Expand selector-only input into the column names that match.
Similar to `prepare_projection`, but intended to allow a subset of `Expr` and all `Selector`s
to be used in more places, such as `DataFrame.{drop,sort,partition_by}`.
Arguments:
selectors: IRs that **only** contain subclasses of `SelectorIR`.
keys: Names of `group_by` columns.
schema: Scope to expand multi-column selectors in.
"""
frozen_schema = freeze_schema(schema)
names = tuple(_iter_expand_selector_names(selectors, keys, schema=frozen_schema))
return _ensure_valid_output_names(names, frozen_schema)


def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]:
if len(exprs) != len(names):
msg = f"zip length mismatch: {len(exprs)} != {len(names)}"
raise ValueError(msg)
return tuple(
NamedIR(expr=remove_alias(ir), name=name) for ir, name in zip(exprs, names)
)
return tuple(ir.named_ir(name, remove_alias(e)) for e, name in zip(exprs, names))


def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames:
@@ -191,19 +212,80 @@ def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames:
return output_names


def _ensure_valid_output_names(names: Seq[str], schema: FrozenSchema) -> OutputNames:
"""Selector-only variant of `ensure_valid_exprs`."""
check_column_names_are_unique(names)
output_names = names
if not (set(schema.names).issuperset(output_names)):
raise column_not_found_error(output_names, schema)
return output_names


def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames:
names = tuple(e.meta.output_name() for e in exprs)
if len(names) != len(set(names)):
raise duplicate_error(exprs)
return names


def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR:
def _ensure_columns(expr: ExprIR, /) -> Columns:
if not isinstance(expr, Columns):
msg = f"Expected only column selections here, but got {expr!r}"
raise NotImplementedError(msg)
return expr


def _iter_expand_selector_names(
selectors: Iterable[SelectorIR], /, keys: GroupByKeys = (), *, schema: FrozenSchema
) -> Iterator[str]:
for selector in selectors:
names = _ensure_columns(replace_selector(selector, schema=schema)).names
if keys:
yield from (name for name in names if name not in keys)
else:
yield from names


# NOTE: Recursive for all `input` expressions which themselves contain `Seq[ExprIR]`
def rewrite_projections(
input: Seq[ExprIR], /, keys: GroupByKeys = (), *, schema: FrozenSchema
) -> Seq[ExprIR]:
result: deque[ExprIR] = deque()
for expr in input:
expanded = _expand_nested_nodes(expr, schema=schema)
flags = ExpansionFlags.from_ir(expanded)
if flags.has_selector:
expanded = replace_selector(expanded, schema=schema)
flags = flags.with_multiple_columns()
result.extend(iter_replace(expanded, keys, col_names=schema.names, flags=flags))
return tuple(result)


def _expand_nested_nodes(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR:
"""Adapted from [`expand_function_inputs`].
Added additional cases for nodes that *also* need to be expanded in the same way.
[`expand_function_inputs`]: https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L557-L581
"""
rewrite = rewrite_projections

def fn(child: ExprIR, /) -> ExprIR:
if not isinstance(child, (ir.FunctionExpr, ir.WindowExpr, ir.SortBy)):
return child
expanded: dict[str, Any] = {}
if is_horizontal_reduction(child):
rewrites = rewrite_projections(child.input, schema=schema)
return common.replace(child, input=rewrites)
return child
expanded["input"] = rewrite(child.input, schema=schema)
elif is_window_expr(child):
if partition_by := child.partition_by:
expanded["partition_by"] = rewrite(partition_by, schema=schema)
if isinstance(child, ir.OrderedWindowExpr):
expanded["order_by"] = rewrite(child.order_by, schema=schema)
elif isinstance(child, ir.SortBy):
expanded["by"] = rewrite(child.by, schema=schema)
if not expanded:
return child
return common.replace(child, **expanded)

return origin.map_ir(fn)

@@ -271,24 +353,6 @@ def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns:
return cols(*(k for k, v in schema.items() if matches(selector, k, v)))


def rewrite_projections(
input: Seq[ExprIR], # `FunctionExpr.input`
/,
keys: GroupByKeys = (),
*,
schema: FrozenSchema,
) -> Seq[ExprIR]:
result: deque[ExprIR] = deque()
for expr in input:
expanded = expand_function_inputs(expr, schema=schema)
flags = ExpansionFlags.from_ir(expanded)
if flags.has_selector:
expanded = replace_selector(expanded, schema=schema)
flags = flags.with_multiple_columns()
result.extend(iter_replace(expanded, keys, col_names=schema.names, flags=flags))
return tuple(result)


def iter_replace(
origin: ExprIR,
/,
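
A rough sketch of how the new selector-only entry-point could be driven end-to-end. Everything here is illustrative: the schema, the multi-name `cs.by_name` call, and the call site are assumptions, not code from this PR.

import narwhals as nw
from narwhals._plan import selectors as cs
from narwhals._plan._expansion import expand_selector_irs_names
from narwhals._plan._parse import parse_into_seq_of_selector_ir

# assumed schema mapping accepted as `IntoFrozenSchema`
schema = {"a": nw.Int64(), "b": nw.Float64(), "g": nw.String()}
# strings and selectors are both accepted and normalised to `SelectorIR`s
irs = parse_into_seq_of_selector_ir(cs.by_name("a", "b"), "g")
# group-by keys are filtered out of the result, so this would yield ("a", "b")
names = expand_selector_irs_names(irs, keys=("g",), schema=schema)
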
8 changes: 8 additions & 0 deletions narwhals/_plan/_expr_ir.py
@@ -8,6 +8,7 @@
from narwhals._plan.common import replace
from narwhals._plan.options import ExprIROptions
from narwhals._plan.typing import ExprIRT
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import Version

if TYPE_CHECKING:
@@ -59,6 +60,10 @@ def to_narwhals(self, version: Version = Version.MAIN) -> Expr:
tp = expr.Expr if version is Version.MAIN else expr.ExprV1
return tp._from_ir(self)

def to_selector_ir(self) -> SelectorIR:
msg = f"cannot turn `{self!r}` into a selector"
raise InvalidOperationError(msg)

@property
def is_scalar(self) -> bool:
return False
@@ -201,6 +206,9 @@ def matches_column(self, name: str, dtype: DType) -> bool:
"""
raise NotImplementedError(type(self))

def to_selector_ir(self) -> Self:
return self


class NamedIR(Immutable, Generic[ExprIRT]):
"""Post-projection expansion wrapper for `ExprIR`.
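
The new `to_selector_ir` hook makes "can this IR be used as a selector?" a method on the node itself: the `ExprIR` base refuses, selector nodes return themselves. Presumably this is what `meta.as_selector` and `parse_into_selector_ir` (below) lean on. A minimal self-contained sketch of the pattern, with placeholder classes rather than the real hierarchy:

class InvalidOperationError(Exception): ...

class ExprIR:
    def to_selector_ir(self) -> "SelectorIR":
        # default: an arbitrary expression cannot be reinterpreted as a selector
        raise InvalidOperationError(f"cannot turn `{self!r}` into a selector")

class SelectorIR(ExprIR):
    def to_selector_ir(self) -> "SelectorIR":
        return self  # selectors already are selectors

assert isinstance(SelectorIR().to_selector_ir(), SelectorIR)
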
20 changes: 18 additions & 2 deletions narwhals/_plan/_guards.py
@@ -11,11 +11,16 @@
if TYPE_CHECKING:
from typing_extensions import TypeIs

from narwhals._plan import expressions as ir
from narwhals._plan import expr, expressions as ir
from narwhals._plan.compliant.series import CompliantSeries
from narwhals._plan.expr import Expr
from narwhals._plan.series import Series
from narwhals._plan.typing import IntoExprColumn, NativeSeriesT, Seq
from narwhals._plan.typing import (
ColumnNameOrSelector,
IntoExprColumn,
NativeSeriesT,
Seq,
)
from narwhals.typing import NonNestedLiteral

T = TypeVar("T")
@@ -58,6 +63,10 @@ def is_expr(obj: Any) -> TypeIs[Expr]:
return isinstance(obj, _expr().Expr)


def is_selector(obj: Any) -> TypeIs[expr.Selector]:
return isinstance(obj, _expr().Selector)


def is_column(obj: Any) -> TypeIs[Expr]:
"""Indicate if the given object is a basic/unaliased column."""
return is_expr(obj) and obj.meta.is_column()
@@ -71,6 +80,13 @@ def is_into_expr_column(obj: Any) -> TypeIs[IntoExprColumn]:
return isinstance(obj, (str, _expr().Expr, _series().Series))


def is_column_name_or_selector(
obj: Any, *, allow_expr: bool = False
) -> TypeIs[ColumnNameOrSelector]:
tps = (str, _expr().Selector) if not allow_expr else (str, _expr().Expr)
return isinstance(obj, tps)


def is_compliant_series(
obj: CompliantSeries[NativeSeriesT] | Any,
) -> TypeIs[CompliantSeries[NativeSeriesT]]:
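
For context on the two new guards: a `TypeIs` function lets a caller validate user input at runtime while the type-checker narrows the type on the successful branch. A hedged sketch of such a caller; the `drop` helper below is hypothetical, only the guard import comes from this PR.

from narwhals._plan._guards import is_column_name_or_selector

def drop(*columns: object) -> None:
    for c in columns:
        if is_column_name_or_selector(c):
            # `c` is narrowed to `str | Selector` here for the type-checker
            ...
        else:
            msg = f"expected a column name or selector, got {type(c).__name__!r}"
            raise TypeError(msg)
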
58 changes: 56 additions & 2 deletions narwhals/_plan/_parse.py
@@ -6,12 +6,19 @@
from itertools import chain
from typing import TYPE_CHECKING

from narwhals._plan._guards import is_expr, is_into_expr_column, is_iterable_reject
from narwhals._plan._guards import (
is_column_name_or_selector,
is_expr,
is_into_expr_column,
is_iterable_reject,
is_selector,
)
from narwhals._plan.exceptions import (
invalid_into_expr_error,
is_iterable_pandas_error,
is_iterable_polars_error,
)
from narwhals._utils import qualified_type_name
from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series
from narwhals.exceptions import InvalidOperationError

@@ -22,8 +29,10 @@
import polars as pl
from typing_extensions import TypeAlias, TypeIs

from narwhals._plan.expressions import ExprIR
from narwhals._plan.expr import Expr
from narwhals._plan.expressions import ExprIR, SelectorIR
from narwhals._plan.typing import (
ColumnNameOrSelector,
IntoExpr,
IntoExprColumn,
OneOrIterable,
@@ -124,6 +133,23 @@ def parse_into_expr_ir(
return expr._ir


# NOTE: Might need to add `require_all`, since selectors are created indirectly from `str`
# here, but use set semantics
def parse_into_selector_ir(input: ColumnNameOrSelector | Expr, /) -> SelectorIR:
if is_selector(input):
selector = input
elif isinstance(input, str):
from narwhals._plan import selectors as cs

selector = cs.by_name(input)
elif is_expr(input):
selector = input.meta.as_selector()
else:
msg = f"cannot turn {qualified_type_name(input)!r} into selector"
raise TypeError(msg)
return selector._ir


def parse_into_seq_of_expr_ir(
first_input: OneOrIterable[IntoExpr] = (),
*more_inputs: IntoExpr | _RaisesInvalidIntoExprError,
@@ -175,6 +201,34 @@ def _parse_sort_by_into_iter_expr_ir(
yield e


def parse_into_seq_of_selector_ir(
first_input: OneOrIterable[ColumnNameOrSelector], *more_inputs: ColumnNameOrSelector
) -> Seq[SelectorIR]:
return tuple(_parse_into_iter_selector_ir(first_input, more_inputs))


def _parse_into_iter_selector_ir(
first_input: OneOrIterable[ColumnNameOrSelector],
more_inputs: tuple[ColumnNameOrSelector, ...],
/,
) -> Iterator[SelectorIR]:
if is_column_name_or_selector(first_input) and not more_inputs:
yield parse_into_selector_ir(first_input)
return

if not _is_empty_sequence(first_input):
if _is_iterable(first_input) and not isinstance(first_input, str):
if more_inputs:
raise invalid_into_expr_error(first_input, more_inputs, {})
else:
for into in first_input: # type: ignore[var-annotated]
yield parse_into_selector_ir(into)
else:
yield parse_into_selector_ir(first_input)
for into in more_inputs:
yield parse_into_selector_ir(into)


def _parse_into_iter_expr_ir(
first_input: OneOrIterable[IntoExpr],
*more_inputs: IntoExpr | list[Any],
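
The selector parsers mirror the accepted shapes of the expression parsers: one name, one selector, a single iterable of either, or a variadic list, with the mixed iterable-plus-extras form rejected. A sketch of the accepted inputs; the multi-name `cs.by_name` call is an assumption.

from narwhals._plan import selectors as cs
from narwhals._plan._parse import parse_into_seq_of_selector_ir

parse_into_seq_of_selector_ir("a")                        # single column name
parse_into_seq_of_selector_ir(cs.by_name("a", "b"))       # single selector
parse_into_seq_of_selector_ir(["a", cs.by_name("b")])     # one iterable mixing both
parse_into_seq_of_selector_ir("a", "b", cs.by_name("c"))  # variadic form
# an iterable first argument plus further positional inputs raises, mirroring
# `parse_into_seq_of_expr_ir`:
# parse_into_seq_of_selector_ir(["a", "b"], "c")  # -> invalid_into_expr_error
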
9 changes: 9 additions & 0 deletions narwhals/_plan/arrow/acero.py
@@ -246,6 +246,15 @@ def prepend_column(native: pa.Table, name: str, values: IntoExpr) -> Decl:
return _add_column(native, 0, name, values)


def _union(declarations: Iterable[Decl], /) -> Decl:
"""[`union`] merges multiple data streams with the same schema into one, similar to a `SQL UNION ALL` clause.
[`union`]: https://arrow.apache.org/docs/cpp/acero/user_guide.html#union
"""
decls: Incomplete = declarations
return Decl("union", pac.ExecNodeOptions(), decls)


def _order_by(
sort_keys: Iterable[tuple[str, Order]] = (),
*,
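
Semantically, the union node is the streaming counterpart of `pa.concat_tables` for identically-typed inputs; presumably it exists here to stitch per-partition results back into a single table after `partition_by`. An illustration of that data flow with plain pyarrow, without the acero plumbing; the partitioning shown is hand-rolled, not `ArrowDataFrame.partition_by`.

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"g": ["a", "b", "a"], "x": [1, 2, 3]})
# hand-rolled stand-in for what partition_by("g") might hand back
partitions = [table.filter(pc.equal(table["g"], key)) for key in ("a", "b")]
# a union over same-schema streams is just their concatenation
print(pa.concat_tables(partitions))
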