From 3cbea9203cbdb291e2d2550dfa2fafd5d845cc38 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 11:02:41 +0100 Subject: [PATCH 001/368] add `_nodes` package - Not stressing on the name for now - Just want to make a start on modelling ops --- narwhals/_nodes/__init__.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 narwhals/_nodes/__init__.py diff --git a/narwhals/_nodes/__init__.py b/narwhals/_nodes/__init__.py new file mode 100644 index 0000000000..102e5e4b7f --- /dev/null +++ b/narwhals/_nodes/__init__.py @@ -0,0 +1,23 @@ +"""Brainstorming an `Expr` internal node represention. + +References: + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/mod.rs + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/options/mod.rs#L137-L172 + - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L35-L106 + - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L131-L236 + - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L240-L267 + +Related: + - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2866902903 + - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867331343 + - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867446959 + - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2869070157 + - https://github.com/narwhals-dev/narwhals/pull/2538/commits/a7eeb0d23e67cb70e7cfa73cec2c7b69a15c8bef#r2083562677 + - https://github.com/narwhals-dev/narwhals/issues/2225 + - https://github.com/narwhals-dev/narwhals/issues/1848 + - https://github.com/narwhals-dev/narwhals/issues/2534#issuecomment-2875676729 +""" + +from __future__ import annotations From 1672c2b71f554fce94ea0095c6617366bf65b86a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 12:01:43 +0100 Subject: [PATCH 002/368] docs: Add some notes --- narwhals/_nodes/__init__.py | 46 ++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/narwhals/_nodes/__init__.py b/narwhals/_nodes/__init__.py index 102e5e4b7f..e138a41b36 100644 --- a/narwhals/_nodes/__init__.py +++ b/narwhals/_nodes/__init__.py @@ -1,23 +1,37 @@ -"""Brainstorming an `Expr` internal node represention. +"""Brainstorming an `Expr` internal represention. 
+ +Notes: +- Each `Expr` method should be representable by a single node + - But the node does not need to be unique to the method +- A chain of `Expr` methods should form a plan of operations +- We must be able to enforce rules on what plans are permitted: + - Must be flexible to both eager/lazy and invdividual backends + - Must be flexible to a given context (select, with_columns, filter, group_by) +- Nodes & plans are: + - Immutable, but + - Can be extended/re-written at both the Narwhals & Compliant levels + - Introspectable, but + - Store as little-as-needed for the common case + - Provide properties/methods for computing the less frequent metadata References: - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/mod.rs - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/options/mod.rs#L137-L172 - - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L35-L106 - - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L131-L236 - - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L240-L267 +- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs +- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/mod.rs +- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/options/mod.rs#L137-L172 +- https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L35-L106 +- https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L131-L236 +- https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L240-L267 Related: - - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2866902903 - - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867331343 - - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867446959 - - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2869070157 - - https://github.com/narwhals-dev/narwhals/pull/2538/commits/a7eeb0d23e67cb70e7cfa73cec2c7b69a15c8bef#r2083562677 - - https://github.com/narwhals-dev/narwhals/issues/2225 - - https://github.com/narwhals-dev/narwhals/issues/1848 - - https://github.com/narwhals-dev/narwhals/issues/2534#issuecomment-2875676729 +- https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2866902903 +- https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867331343 +- https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867446959 +- https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2869070157 +- 
https://github.com/narwhals-dev/narwhals/pull/2538/commits/a7eeb0d23e67cb70e7cfa73cec2c7b69a15c8bef#r2083562677 +- https://github.com/narwhals-dev/narwhals/issues/2225 +- https://github.com/narwhals-dev/narwhals/issues/1848 +- https://github.com/narwhals-dev/narwhals/issues/2534#issuecomment-2875676729 """ from __future__ import annotations From c7bdd71394b6319ef0bf16dabf3cbe5684cca785 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 12:03:33 +0100 Subject: [PATCH 003/368] chore: Rename `_nodes` -> `_plan` Thinking it might go: `Op` -> `Node` -> `Plan` But who knows really --- narwhals/{_nodes => _plan}/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename narwhals/{_nodes => _plan}/__init__.py (100%) diff --git a/narwhals/_nodes/__init__.py b/narwhals/_plan/__init__.py similarity index 100% rename from narwhals/_nodes/__init__.py rename to narwhals/_plan/__init__.py From b7ecdaf0c5d7b65363ed229e47ff30421c138e1a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 14:17:52 +0100 Subject: [PATCH 004/368] feat: Start building out the ops Just trying to get the names & hierarchies done --- narwhals/_plan/aggregation.py | 44 +++++++++++++++++++++++++++ narwhals/_plan/boolean.py | 49 ++++++++++++++++++++++++++++++ narwhals/_plan/common.py | 11 +++++++ narwhals/_plan/expr.py | 36 ++++++++++++++++++++++ narwhals/_plan/operators.py | 48 ++++++++++++++++++++++++++++++ narwhals/_plan/strings.py | 56 +++++++++++++++++++++++++++++++++++ narwhals/_plan/temporal.py | 7 +++++ narwhals/_plan/window.py | 3 ++ 8 files changed, 254 insertions(+) create mode 100644 narwhals/_plan/aggregation.py create mode 100644 narwhals/_plan/boolean.py create mode 100644 narwhals/_plan/common.py create mode 100644 narwhals/_plan/expr.py create mode 100644 narwhals/_plan/operators.py create mode 100644 narwhals/_plan/strings.py create mode 100644 narwhals/_plan/temporal.py create mode 100644 narwhals/_plan/window.py diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py new file mode 100644 index 0000000000..974224e69b --- /dev/null +++ b/narwhals/_plan/aggregation.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from narwhals._plan.common import ExprIR + + +class AggExpr(ExprIR): ... + + +class Count(AggExpr): ... + + +class First(AggExpr): + """https://github.com/narwhals-dev/narwhals/issues/2526.""" + + +class Last(AggExpr): + """https://github.com/narwhals-dev/narwhals/issues/2526.""" + + +class Max(AggExpr): ... + + +class Mean(AggExpr): ... + + +class Median(AggExpr): ... + + +class Min(AggExpr): ... + + +class NUnique(AggExpr): ... + + +class Quantile(AggExpr): ... + + +class Std(AggExpr): ... + + +class Sum(AggExpr): ... + + +class Var(AggExpr): ... diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py new file mode 100644 index 0000000000..140a30f194 --- /dev/null +++ b/narwhals/_plan/boolean.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +# NOTE: Needed to avoid naming collisions +# - Any +import typing as t # noqa: F401 + +from narwhals._plan.common import Function + + +class BooleanFunction(Function): ... + + +class All(BooleanFunction): ... + + +class AllHorizontal(BooleanFunction): ... + + +class Any(BooleanFunction): ... + + +class AnyHorizontal(BooleanFunction): ... + + +class IsBetween(BooleanFunction): ... + + +class IsDuplicated(BooleanFunction): ... + + +class IsFinite(BooleanFunction): ... 
+ + +class IsFirstDistinct(BooleanFunction): ... + + +class IsIn(BooleanFunction): ... + + +class IsLastDistinct(BooleanFunction): ... + + +class IsNan(BooleanFunction): ... + + +class IsNull(BooleanFunction): ... + + +class IsUnique(BooleanFunction): ... diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py new file mode 100644 index 0000000000..4760647261 --- /dev/null +++ b/narwhals/_plan/common.py @@ -0,0 +1,11 @@ +from __future__ import annotations + + +class ExprIR: ... + + +class Function(ExprIR): + """Shared by expr functions and namespace functions.""" + + +class FunctionExpr(ExprIR): ... diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py new file mode 100644 index 0000000000..e77056c495 --- /dev/null +++ b/narwhals/_plan/expr.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +# NOTE: Needed to avoid naming collisions +# - Literal +import typing as t # noqa: F401 + +from narwhals._plan.common import ExprIR + + +class Alias(ExprIR): ... + + +class Column(ExprIR): ... + + +class Literal(ExprIR): ... + + +class BinaryExpr(ExprIR): + """Seems like the application of two exprs via an `Operator`.""" + + +class Cast(ExprIR): ... + + +class Sort(ExprIR): ... + + +class SortBy(ExprIR): + """https://github.com/narwhals-dev/narwhals/issues/2534.""" + + +class Filter(ExprIR): ... + + +class Len(ExprIR): ... diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py new file mode 100644 index 0000000000..c355115f49 --- /dev/null +++ b/narwhals/_plan/operators.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from narwhals._plan.common import ExprIR + + +class Operator(ExprIR): ... + + +class Eq(Operator): ... + + +class NotEq(Operator): ... + + +class Lt(Operator): ... + + +class LtEq(Operator): ... + + +class Gt(Operator): ... + + +class GtEq(Operator): ... + + +class Add(Operator): ... + + +class Sub(Operator): ... + + +class Multiply(Operator): ... + + +class TrueDivide(Operator): ... + + +class FloorDivide(Operator): ... + + +class Modulus(Operator): ... + + +class And(Operator): ... + + +class Or(Operator): ... diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py new file mode 100644 index 0000000000..57781b5e2f --- /dev/null +++ b/narwhals/_plan/strings.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from narwhals._plan.common import Function + + +class StringFunction(Function): ... + + +class Contains(StringFunction): ... + + +class EndsWith(StringFunction): ... + + +class Head(StringFunction): ... + + +class LenChars(StringFunction): ... + + +class Replace(StringFunction): ... + + +class ReplaceAll(StringFunction): + """`polars` uses a single node for this and `Replace`. + + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L65-L70 + """ + + +class Slice(StringFunction): ... + + +class Split(StringFunction): ... + + +class StartsWith(StringFunction): ... + + +class StripChars(StringFunction): ... + + +class Tail(StringFunction): ... + + +class ToDatetime(StringFunction): + """`polars` uses `Strptime`. + + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L112 + """ + + +class ToLowercase(StringFunction): ... + + +class ToUppercase(StringFunction): ... 
diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py new file mode 100644 index 0000000000..348342db5a --- /dev/null +++ b/narwhals/_plan/temporal.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from narwhals._plan.common import Function + + +# TODO @dangotbanned: Fill out +class TemporalFunction(Function): ... diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py new file mode 100644 index 0000000000..e782f3306c --- /dev/null +++ b/narwhals/_plan/window.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +# TODO @dangotbanned: Investigate From 8166191d34d7ff3b428939e11cf1f3996ee46b2c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 14:29:06 +0100 Subject: [PATCH 005/368] feat: Fill out `temporal` --- narwhals/_plan/temporal.py | 67 +++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 348342db5a..4b1f873412 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -3,5 +3,70 @@ from narwhals._plan.common import Function -# TODO @dangotbanned: Fill out class TemporalFunction(Function): ... + + +class Date(TemporalFunction): ... + + +class Year(TemporalFunction): ... + + +class Month(TemporalFunction): ... + + +class Day(TemporalFunction): ... + + +class Hour(TemporalFunction): ... + + +class Minute(TemporalFunction): ... + + +class Second(TemporalFunction): ... + + +class Millisecond(TemporalFunction): ... + + +class Microsecond(TemporalFunction): ... + + +class Nanosecond(TemporalFunction): ... + + +class OrdinalDay(TemporalFunction): ... + + +class WeekDay(TemporalFunction): ... + + +class TotalMinutes(TemporalFunction): ... + + +class TotalSeconds(TemporalFunction): ... + + +class TotalMilliseconds(TemporalFunction): ... + + +class TotalMicroseconds(TemporalFunction): ... + + +class TotalNanoseconds(TemporalFunction): ... + + +class ToString(TemporalFunction): ... + + +class ReplaceTimeZone(TemporalFunction): ... + + +class ConvertTimeZone(TemporalFunction): ... + + +class Timestamp(TemporalFunction): ... + + +class Truncate(TemporalFunction): ... From b1d8c2ebb707315f56605e21e03005c233195348 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 14:58:38 +0100 Subject: [PATCH 006/368] feat: Add `functions` + misc missing --- narwhals/_plan/functions.py | 83 +++++++++++++++++++++++++++++++++++++ narwhals/_plan/operators.py | 4 ++ narwhals/_plan/strings.py | 4 ++ 3 files changed, 91 insertions(+) create mode 100644 narwhals/_plan/functions.py diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py new file mode 100644 index 0000000000..f8a2ced9ba --- /dev/null +++ b/narwhals/_plan/functions.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from narwhals._plan.common import FunctionExpr + + +class Abs(FunctionExpr): ... + + +class Hist(FunctionExpr): + """Only supported for `Series` so far.""" + + +class NullCount(FunctionExpr): ... + + +class Pow(FunctionExpr): ... + + +class FillNull(FunctionExpr): ... + + +class FillNullWithStrategy(FunctionExpr): + """We don't support this variant in a lot of backends, so worth keeping it split out.""" + + +class Shift(FunctionExpr): ... + + +class DropNulls(FunctionExpr): ... + + +class Mode(FunctionExpr): ... + + +class Skew(FunctionExpr): ... + + +class Rank(FunctionExpr): ... + + +class Clip(FunctionExpr): ... 
+ + +class CumCount(FunctionExpr): ... + + +class CumMin(FunctionExpr): ... + + +class CumMax(FunctionExpr): ... + + +class CumProd(FunctionExpr): ... + + +class Diff(FunctionExpr): ... + + +class Unique(FunctionExpr): ... + + +class Round(FunctionExpr): ... + + +class SumHorizontal(FunctionExpr): ... + + +class MinHorizontal(FunctionExpr): ... + + +class MaxHorizontal(FunctionExpr): ... + + +class MeanHorizontal(FunctionExpr): ... + + +class EwmMean(FunctionExpr): ... + + +class ReplaceStrict(FunctionExpr): ... + + +class GatherEvery(FunctionExpr): ... diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index c355115f49..d11ce2cc23 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -46,3 +46,7 @@ class And(Operator): ... class Or(Operator): ... + + +class Not(Operator): + """`__invert__`.""" diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 57781b5e2f..54524892b0 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -6,6 +6,10 @@ class StringFunction(Function): ... +class ConcatHorizontal(StringFunction): + """`concat_str`.""" + + class Contains(StringFunction): ... From 6c8b7db661f4e877ec091efdd778200e5dc79c5f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 15:07:48 +0100 Subject: [PATCH 007/368] feat: Add some `window` --- narwhals/_plan/window.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index e782f3306c..2651bc3a89 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -1,3 +1,24 @@ from __future__ import annotations -# TODO @dangotbanned: Investigate +from narwhals._plan.common import ExprIR + + +class Window(ExprIR): ... + + +class OverWindow(Window): ... + + +class RollingWindow(Window): ... + + +class RollingSum(RollingWindow): ... + + +class RollingMean(RollingWindow): ... + + +class RollingVar(RollingWindow): ... + + +class RollingStd(RollingWindow): ... From 534c3f799e467c32f600cd11316e5aa2f981cf69 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 15:20:43 +0100 Subject: [PATCH 008/368] feat: Add `OrderableAgg` --- narwhals/_plan/aggregation.py | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 974224e69b..2b0857f45d 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -3,42 +3,51 @@ from narwhals._plan.common import ExprIR -class AggExpr(ExprIR): ... +class Agg(ExprIR): ... -class Count(AggExpr): ... +class Count(Agg): ... -class First(AggExpr): +class First(Agg): """https://github.com/narwhals-dev/narwhals/issues/2526.""" -class Last(AggExpr): +class Last(Agg): """https://github.com/narwhals-dev/narwhals/issues/2526.""" -class Max(AggExpr): ... +class Max(Agg): ... -class Mean(AggExpr): ... +class Mean(Agg): ... -class Median(AggExpr): ... +class Median(Agg): ... -class Min(AggExpr): ... +class Min(Agg): ... -class NUnique(AggExpr): ... +class NUnique(Agg): ... -class Quantile(AggExpr): ... +class Quantile(Agg): ... -class Std(AggExpr): ... +class Std(Agg): ... -class Sum(AggExpr): ... +class Sum(Agg): ... -class Var(AggExpr): ... +class Var(Agg): ... + + +class OrderableAgg(Agg): ... + + +class ArgMin(OrderableAgg): ... + + +class ArgMax(OrderableAgg): ... 
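The aggregation nodes above are still bare stubs, but the intended one-node-per-`Expr`-method
shape is already visible. A minimal sketch of that correspondence (the dispatch table is
illustrative only and assumes the `Expr` method names line up with the node names):

    from narwhals._plan import aggregation as agg

    # One node class per `Expr` aggregation method (assumed mapping).
    METHOD_TO_NODE = {
        "count": agg.Count,
        "mean": agg.Mean,
        "std": agg.Std,
        "arg_max": agg.ArgMax,  # orderable: the result depends on row order
    }

    assert issubclass(agg.ArgMax, agg.OrderableAgg)
    assert all(issubclass(node, agg.Agg) for node in METHOD_TO_NODE.values())
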
From f5ef0afc8c755be54c818307ec5ee7c3a1e7a7fe Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 15:29:33 +0100 Subject: [PATCH 009/368] feat: Start adding col/selectors --- narwhals/_plan/expr.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index e77056c495..f1c33fc78f 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -34,3 +34,16 @@ class Filter(ExprIR): ... class Len(ExprIR): ... + + +class Exclude(ExprIR): ... + + +class Nth(ExprIR): ... + + +class All(ExprIR): ... + + +# NOTE: by_dtype, matches, numeric, boolean, string, categorical, datetime, all +class Selector(ExprIR): ... From 770d770b9a53be0af233a638d61b4ee6eb2bd6f8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 15:31:30 +0100 Subject: [PATCH 010/368] add `MapBatches` --- narwhals/_plan/functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index f8a2ced9ba..4ff5317253 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -81,3 +81,6 @@ class ReplaceStrict(FunctionExpr): ... class GatherEvery(FunctionExpr): ... + + +class MapBatches(FunctionExpr): ... From 18c00b46ccd52345ce705a944af608f9d3127923 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 15:44:02 +0100 Subject: [PATCH 011/368] feat: Add some namespaces --- narwhals/_plan/categorical.py | 9 +++++++++ narwhals/_plan/lists.py | 9 +++++++++ narwhals/_plan/struct.py | 9 +++++++++ 3 files changed, 27 insertions(+) create mode 100644 narwhals/_plan/categorical.py create mode 100644 narwhals/_plan/lists.py create mode 100644 narwhals/_plan/struct.py diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py new file mode 100644 index 0000000000..fc43b522ce --- /dev/null +++ b/narwhals/_plan/categorical.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from narwhals._plan.common import Function + + +class CategoricalFunction(Function): ... + + +class GetCategories(CategoricalFunction): ... diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py new file mode 100644 index 0000000000..7368d41010 --- /dev/null +++ b/narwhals/_plan/lists.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from narwhals._plan.common import Function + + +class ListFunction(Function): ... + + +class Len(ListFunction): ... diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py new file mode 100644 index 0000000000..cbcbfb3652 --- /dev/null +++ b/narwhals/_plan/struct.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from narwhals._plan.common import Function + + +class StructFunction(Function): ... + + +class FieldByName(StructFunction): ... 
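Patches 009-011 sketch how selection helpers and namespace methods would each get a node of
their own. A rough guess at the intended correspondence (none of this wiring exists yet):

    from narwhals._plan import expr, lists, struct

    # col("a")               -> expr.Column
    # nth(0)                 -> expr.Nth
    # an exclude(...) call   -> expr.Exclude
    # Expr.list.len()        -> lists.Len
    # Expr.struct.field(...) -> struct.FieldByName

    assert issubclass(expr.Exclude, expr.ExprIR)
    assert issubclass(lists.Len, lists.ListFunction)
    assert issubclass(struct.FieldByName, struct.StructFunction)
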
From f637ecd8c9d69e6979f07f06b89f24bf531aefdf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 16:08:15 +0100 Subject: [PATCH 012/368] feat: Add `name` --- narwhals/_plan/name.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 narwhals/_plan/name.py diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py new file mode 100644 index 0000000000..393e6e37dc --- /dev/null +++ b/narwhals/_plan/name.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from narwhals._plan.common import Function + + +class NameFunction(Function): + """`polars` version doesn't represent in the same way here. + + https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/dsl/name.rs + """ + + +class Keep(NameFunction): ... + + +class Map(NameFunction): ... + + +class Prefix(NameFunction): ... + + +class Suffix(NameFunction): ... + + +class ToLowercase(NameFunction): ... + + +class ToUppercase(NameFunction): ... From 6e6b2414d19adf73b1cadc5aa5a0c62974ff0a75 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 16:34:46 +0100 Subject: [PATCH 013/368] feat: Add `Immutable` - Will probably want to set up the `object.__setattr__` part somewhere - Not decided on how initialization should work --- narwhals/_plan/common.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 4760647261..c60218e684 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -1,7 +1,20 @@ from __future__ import annotations +from typing import TYPE_CHECKING -class ExprIR: ... +if TYPE_CHECKING: + from typing_extensions import Never + + +class Immutable: + __slots__ = () + + def __setattr__(self, name: str, value: Never) -> Never: + msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." + raise AttributeError(msg) + + +class ExprIR(Immutable): ... class Function(ExprIR): From b6eb233d9ade9d4e434635c5c866a3eb38d93eed Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 16:36:28 +0100 Subject: [PATCH 014/368] link city --- narwhals/_plan/__init__.py | 1 + narwhals/_plan/expr.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index e138a41b36..f1210a8ec0 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -22,6 +22,7 @@ - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L35-L106 - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L131-L236 - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L240-L267 +- https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/plans/aexpr/mod.rs Related: - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2866902903 diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index f1c33fc78f..9c894d46d5 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -17,7 +17,16 @@ class Literal(ExprIR): ... 
class BinaryExpr(ExprIR): - """Seems like the application of two exprs via an `Operator`.""" + """Seems like the application of two exprs via an `Operator`. + + This ✅ + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs#L271-L279 + - https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/plans/aexpr/mod.rs#L152-L155 + - https://github.com/pola-rs/polars/blob/da27decd9a1adabe0498b786585287eb730d1d91/crates/polars-expr/src/expressions/binary.rs + + Not this ❌ + - https://github.com/pola-rs/polars/blob/da27decd9a1adabe0498b786585287eb730d1d91/crates/polars-plan/src/dsl/function_expr/mod.rs#L127 + """ class Cast(ExprIR): ... From 05e0823decde1b6ddef43fbe9091062d932d2c8a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 19:45:07 +0100 Subject: [PATCH 015/368] docs: Link another related issue --- narwhals/_plan/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index f1210a8ec0..1c803c9254 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -33,6 +33,7 @@ - https://github.com/narwhals-dev/narwhals/issues/2225 - https://github.com/narwhals-dev/narwhals/issues/1848 - https://github.com/narwhals-dev/narwhals/issues/2534#issuecomment-2875676729 +- https://github.com/narwhals-dev/narwhals/issues/2291 """ from __future__ import annotations From e8344e27dadfcb521f50a54611332e92028ec29a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 19:49:50 +0100 Subject: [PATCH 016/368] feat: Port over `FunctionOptions` --- narwhals/_plan/options.py | 95 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 narwhals/_plan/options.py diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py new file mode 100644 index 0000000000..ce7caf2b92 --- /dev/null +++ b/narwhals/_plan/options.py @@ -0,0 +1,95 @@ +"""`ExprMetadata` but less god object. + +- https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs +""" + +from __future__ import annotations + +import enum + +from narwhals._plan.common import Immutable + + +class FunctionFlags(enum.Flag): + ALLOW_GROUP_AWARE = 1 << 0 + """> Raise if use in group by + + Not sure where this is disabled. + """ + + RETURNS_SCALAR = 1 << 5 + """Automatically explode on unit length if it ran as final aggregation.""" + + ROW_SEPARABLE = 1 << 8 + """Not sure lol. 
+ + https://github.com/pola-rs/polars/pull/22573 + """ + + LENGTH_PRESERVING = 1 << 9 + """mutually exclusive with `RETURNS_SCALAR`""" + + def is_elementwise(self) -> bool: + return self in (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING) + + def returns_scalar(self) -> bool: + return self in FunctionFlags.RETURNS_SCALAR + + def is_length_preserving(self) -> bool: + return self in FunctionFlags.LENGTH_PRESERVING + + @staticmethod + def default() -> FunctionFlags: + return FunctionFlags.ALLOW_GROUP_AWARE + + +class FunctionOptions(Immutable): + flags: FunctionFlags + + def is_elementwise(self) -> bool: + return self.flags.is_elementwise() + + def returns_scalar(self) -> bool: + return self.flags.returns_scalar() + + def is_length_preserving(self) -> bool: + return self.flags.is_length_preserving() + + def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: + if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: + msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." + raise TypeError(msg) + obj = FunctionOptions.__new__(FunctionOptions) + object.__setattr__(obj, "flags", self.flags | flags) + return obj + + def with_elementwise(self) -> FunctionOptions: + return self.with_flags( + FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING + ) + + @staticmethod + def default() -> FunctionOptions: + obj = FunctionOptions.__new__(FunctionOptions) + object.__setattr__(obj, "flags", FunctionFlags.default()) + return obj + + @staticmethod + def elementwise() -> FunctionOptions: + return FunctionOptions.default().with_elementwise() + + @staticmethod + def row_separable() -> FunctionOptions: + return FunctionOptions.groupwise().with_flags(FunctionFlags.ROW_SEPARABLE) + + @staticmethod + def length_preserving() -> FunctionOptions: + return FunctionOptions.default().with_flags(FunctionFlags.LENGTH_PRESERVING) + + @staticmethod + def groupwise() -> FunctionOptions: + return FunctionOptions.default() + + @staticmethod + def aggregation() -> FunctionOptions: + return FunctionOptions.groupwise().with_flags(FunctionFlags.RETURNS_SCALAR) From 9e8e526a659af0c551b9a4d50b16a5dc5e89d34b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 14 May 2025 21:02:25 +0100 Subject: [PATCH 017/368] node node node your boat --- narwhals/_plan/aggregation.py | 18 ++++++++--- narwhals/_plan/expr.py | 60 ++++++++++++++++++++++++++++------- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 2b0857f45d..25235a2058 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -1,9 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from narwhals._plan.common import ExprIR +if TYPE_CHECKING: + from narwhals.typing import RollingInterpolationMethod + -class Agg(ExprIR): ... +class Agg(ExprIR): + expr: ExprIR class Count(Agg): ... @@ -32,16 +38,20 @@ class Min(Agg): ... class NUnique(Agg): ... -class Quantile(Agg): ... +class Quantile(Agg): + quantile: ExprIR + interpolation: RollingInterpolationMethod -class Std(Agg): ... +class Std(Agg): + ddof: int class Sum(Agg): ... -class Var(Agg): ... +class Var(Agg): + ddof: int class OrderableAgg(Agg): ... 
diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 9c894d46d5..0604720688 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -2,53 +2,89 @@ # NOTE: Needed to avoid naming collisions # - Literal -import typing as t # noqa: F401 +import typing as t from narwhals._plan.common import ExprIR +if t.TYPE_CHECKING: + from typing_extensions import TypeAlias -class Alias(ExprIR): ... + from narwhals._plan.operators import Operator + from narwhals.dtypes import DType + from narwhals.typing import PythonLiteral + SortOptions: TypeAlias = t.Any + SortMultipleOptions: TypeAlias = t.Any -class Column(ExprIR): ... +class Alias(ExprIR): + expr: ExprIR + name: str -class Literal(ExprIR): ... + +class Column(ExprIR): + name: str + + +class Columns(ExprIR): + names: t.Sequence[str] + + +class Literal(ExprIR): + value: PythonLiteral class BinaryExpr(ExprIR): - """Seems like the application of two exprs via an `Operator`. + """Application of two exprs via an `Operator`. This ✅ - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs#L271-L279 - https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/plans/aexpr/mod.rs#L152-L155 - - https://github.com/pola-rs/polars/blob/da27decd9a1adabe0498b786585287eb730d1d91/crates/polars-expr/src/expressions/binary.rs Not this ❌ - https://github.com/pola-rs/polars/blob/da27decd9a1adabe0498b786585287eb730d1d91/crates/polars-plan/src/dsl/function_expr/mod.rs#L127 """ + left: ExprIR + op: Operator + right: ExprIR -class Cast(ExprIR): ... +class Cast(ExprIR): + expr: ExprIR + dtype: DType -class Sort(ExprIR): ... + +class Sort(ExprIR): + expr: ExprIR + options: SortOptions class SortBy(ExprIR): """https://github.com/narwhals-dev/narwhals/issues/2534.""" + expr: ExprIR + by: t.Sequence[ExprIR] + options: SortMultipleOptions + -class Filter(ExprIR): ... +class Filter(ExprIR): + expr: ExprIR + by: ExprIR class Len(ExprIR): ... -class Exclude(ExprIR): ... +class Exclude(ExprIR): + names: t.Sequence[str] + + +class Nth(ExprIR): + index: int -class Nth(ExprIR): ... +class IndexColumns(ExprIR): + indices: t.Sequence[int] class All(ExprIR): ... 
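With typed fields on the nodes, a chained expression can now be pictured as a small tree.
Initialization is still undecided (see patch 013's message), so this sketch borrows the
`__new__` + `object.__setattr__` pattern that `FunctionOptions.default()` already uses; the
`_node` helper and the lowering shown are illustrative, not part of the patches:

    from narwhals._plan.aggregation import Std
    from narwhals._plan.expr import Alias, Column

    def _node(cls, **fields):
        # Bypass the frozen `__setattr__` the same way `FunctionOptions.default()` does.
        obj = cls.__new__(cls)
        for name, value in fields.items():
            object.__setattr__(obj, name, value)
        return obj

    # nw.col("a").std(ddof=1).alias("a_std") could lower to:
    plan = _node(Alias, name="a_std", expr=_node(Std, expr=_node(Column, name="a"), ddof=1))
    assert isinstance(plan.expr, Std) and plan.expr.ddof == 1
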
From e37f56280546c380392e23acd962e8a1277cdb06 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 15 May 2025 17:27:21 +0100 Subject: [PATCH 018/368] ci: Omit from cov for now --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 1a644577a5..63c2b4c474 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -262,6 +262,8 @@ omit = [ 'narwhals/_ibis/typing.py', # Remove after finishing eager sub-protocols 'narwhals/_compliant/namespace.py', + # Doesn't have a full impl yet + 'narwhals/_plan/*' ] exclude_also = [ "if sys.version_info() <", From efd2437f6169d629895a38edb36a9e63bf163ddb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 15 May 2025 19:24:04 +0100 Subject: [PATCH 019/368] more modelling w/ notes - Gonna need to revisit `Function` vs `FunctionExpr` - `functions.py` is confusing me now --- narwhals/_plan/common.py | 19 +++++++++++++-- narwhals/_plan/expr.py | 47 +++++++++++++++++++++++++++++++++---- narwhals/_plan/functions.py | 2 +- narwhals/_plan/window.py | 6 ++++- 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index c60218e684..9e6ac63892 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -3,7 +3,16 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Any + from typing_extensions import Never + from typing_extensions import TypeAlias + + from narwhals._plan.options import FunctionOptions + + SortOptions: TypeAlias = Any + SortMultipleOptions: TypeAlias = Any + WindowType: TypeAlias = Any class Immutable: @@ -18,7 +27,13 @@ class ExprIR(Immutable): ... class Function(ExprIR): - """Shared by expr functions and namespace functions.""" + """Shared by expr functions and namespace functions. + + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 + """ + @property + def function_options(self) -> FunctionOptions: + from narwhals._plan.options import FunctionOptions -class FunctionExpr(ExprIR): ... + return FunctionOptions.default() diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 0604720688..e521c53007 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -7,15 +7,15 @@ from narwhals._plan.common import ExprIR if t.TYPE_CHECKING: - from typing_extensions import TypeAlias - + from narwhals._plan.common import Function + from narwhals._plan.common import SortMultipleOptions + from narwhals._plan.common import SortOptions from narwhals._plan.operators import Operator + from narwhals._plan.options import FunctionOptions + from narwhals._plan.window import Window from narwhals.dtypes import DType from narwhals.typing import PythonLiteral - SortOptions: TypeAlias = t.Any - SortMultipleOptions: TypeAlias = t.Any - class Alias(ExprIR): expr: ExprIR @@ -67,11 +67,41 @@ class SortBy(ExprIR): options: SortMultipleOptions +class FunctionExpr(ExprIR): + """Polars uses seemingly for namespacing, but maybe I'll use for traversal? 
+ + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 + """ + + input: t.Sequence[ExprIR] + function: Function + options: FunctionOptions + + class Filter(ExprIR): expr: ExprIR by: ExprIR +class WindowExpr(ExprIR): + """https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136.""" + + expr: ExprIR + """Renamed from `function`.""" + + partition_by: t.Sequence[ExprIR] + order_by: tuple[ExprIR, SortOptions] | None + options: Window + """Little confused on the nesting. + + - We don't allow choosing `WindowMapping` kinds + - Haven't ventured into rolling much yet + + Expr::Window { options: WindowType::Over(WindowMapping) } + Expr::Window { options: WindowType::Rolling(RollingGroupOptions) } + """ + + class Len(ExprIR): ... @@ -84,6 +114,13 @@ class Nth(ExprIR): class IndexColumns(ExprIR): + """Renamed from `IndexColumn`. + + `Nth` provides the single variant. + + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L80 + """ + indices: t.Sequence[int] diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 4ff5317253..2e425f9f07 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from narwhals._plan.common import FunctionExpr +from narwhals._plan.expr import FunctionExpr class Abs(FunctionExpr): ... diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 2651bc3a89..bb2fe5aef1 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -3,7 +3,11 @@ from narwhals._plan.common import ExprIR -class Window(ExprIR): ... +class Window(ExprIR): + """Renamed from `WindowType`. + + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/options/mod.rs#L139 + """ class OverWindow(Window): ... 
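A guess at how a namespace method would flow through the new `FunctionExpr` node (the lowering
below is assumed, not something the patches define yet):

    # nw.col("a").str.to_uppercase() might become
    #   FunctionExpr(input=[Column(name="a")], function=ToUppercase(), options=...)
    # with `options` taken from the new `Function.function_options` property.
    from narwhals._plan.strings import ToUppercase

    opts = ToUppercase().function_options
    assert not opts.returns_scalar()  # string functions still report the plain defaults here
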
From 1dbeabd1ce3d1af0ec007ec2088cb63cc4802677 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 15 May 2025 21:16:27 +0100 Subject: [PATCH 020/368] Add `SortOptions`, `SortMultipleOptions` --- narwhals/_plan/common.py | 2 -- narwhals/_plan/expr.py | 4 ++-- narwhals/_plan/options.py | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 9e6ac63892..8492f5dcea 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -10,8 +10,6 @@ from narwhals._plan.options import FunctionOptions - SortOptions: TypeAlias = Any - SortMultipleOptions: TypeAlias = Any WindowType: TypeAlias = Any diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index e521c53007..c7a0cbb3e4 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -8,10 +8,10 @@ if t.TYPE_CHECKING: from narwhals._plan.common import Function - from narwhals._plan.common import SortMultipleOptions - from narwhals._plan.common import SortOptions from narwhals._plan.operators import Operator from narwhals._plan.options import FunctionOptions + from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.options import SortOptions from narwhals._plan.window import Window from narwhals.dtypes import DType from narwhals.typing import PythonLiteral diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index ce7caf2b92..e86f5828d9 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -6,9 +6,13 @@ from __future__ import annotations import enum +from typing import TYPE_CHECKING from narwhals._plan.common import Immutable +if TYPE_CHECKING: + from typing import Sequence + class FunctionFlags(enum.Flag): ALLOW_GROUP_AWARE = 1 << 0 @@ -93,3 +97,14 @@ def groupwise() -> FunctionOptions: @staticmethod def aggregation() -> FunctionOptions: return FunctionOptions.groupwise().with_flags(FunctionFlags.RETURNS_SCALAR) + + +# TODO @dangotbanned: spec these out +class SortOptions(Immutable): + descending: bool + nulls_last: bool + + +class SortMultipleOptions(Immutable): + descending: Sequence[bool] + nulls_last: Sequence[bool] From 5526830012e8820275bfff63d9aad2a59fc26724 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 15 May 2025 21:16:57 +0100 Subject: [PATCH 021/368] Make `Immutable` spicier --- narwhals/_plan/common.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 8492f5dcea..4775a0ca68 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -20,6 +20,28 @@ def __setattr__(self, name: str, value: Never) -> Never: msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." raise AttributeError(msg) + def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: + super().__init_subclass__(*args, **kwds) + if cls.__slots__: + ... + else: + cls.__slots__ = () + + def __hash__(self) -> int: + empty = object() + return hash(tuple(getattr(self, name, empty) for name in self.__slots__)) + + def __eq__(self, other: object) -> bool: + if self is other: + return True + elif type(self) is not type(other): + return False + empty = object() + return all( + getattr(self, name, empty) == getattr(other, name, empty) + for name in self.__slots__ + ) + class ExprIR(Immutable): ... 
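The `Immutable` base now gives every node structural equality, hashing, and a frozen
`__setattr__`. A small illustration of that behaviour (the `Point` class is made up for the
example; the real nodes only gain their `__slots__` in later commits):

    from narwhals._plan.common import Immutable

    class Point(Immutable):
        __slots__ = ("x", "y")

        def __init__(self, x: int, y: int) -> None:
            # `object.__setattr__` sidesteps the frozen `__setattr__`.
            object.__setattr__(self, "x", x)
            object.__setattr__(self, "y", y)

    p, q = Point(1, 2), Point(1, 2)
    assert p == q and hash(p) == hash(q)  # equality/hash are computed over `__slots__`
    try:
        p.x = 3
    except AttributeError:
        pass  # mutation is rejected by `Immutable.__setattr__`
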
From c91864f219cf70801e2610b32300ea6e0c331c97 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 15 May 2025 21:31:27 +0100 Subject: [PATCH 022/368] chore: Filling out some slots --- narwhals/_plan/aggregation.py | 8 ++++++++ narwhals/_plan/common.py | 7 ++++--- narwhals/_plan/expr.py | 28 ++++++++++++++++++++++++++++ narwhals/_plan/options.py | 8 +++++++- 4 files changed, 47 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 25235a2058..87e409c104 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -9,6 +9,8 @@ class Agg(ExprIR): + __slots__ = ("expr",) + expr: ExprIR @@ -39,11 +41,15 @@ class NUnique(Agg): ... class Quantile(Agg): + __slots__ = ("expr", "interpolation", "quantile") + quantile: ExprIR interpolation: RollingInterpolationMethod class Std(Agg): + __slots__ = ("ddof", "expr") + ddof: int @@ -51,6 +57,8 @@ class Sum(Agg): ... class Var(Agg): + __slots__ = ("ddof", "expr") + ddof: int diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 4775a0ca68..e434e2b816 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -29,7 +29,8 @@ def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: def __hash__(self) -> int: empty = object() - return hash(tuple(getattr(self, name, empty) for name in self.__slots__)) + slots: tuple[str, ...] = self.__slots__ + return hash(tuple(getattr(self, name, empty) for name in slots)) def __eq__(self, other: object) -> bool: if self is other: @@ -37,9 +38,9 @@ def __eq__(self, other: object) -> bool: elif type(self) is not type(other): return False empty = object() + slots: tuple[str, ...] = self.__slots__ return all( - getattr(self, name, empty) == getattr(other, name, empty) - for name in self.__slots__ + getattr(self, name, empty) == getattr(other, name, empty) for name in slots ) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index c7a0cbb3e4..5d3608d0ea 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -18,19 +18,27 @@ class Alias(ExprIR): + __slots__ = ("expr", "name") + expr: ExprIR name: str class Column(ExprIR): + __slots__ = ("name",) + name: str class Columns(ExprIR): + __slots__ = ("names",) + names: t.Sequence[str] class Literal(ExprIR): + __slots__ = ("value",) + value: PythonLiteral @@ -44,17 +52,23 @@ class BinaryExpr(ExprIR): - https://github.com/pola-rs/polars/blob/da27decd9a1adabe0498b786585287eb730d1d91/crates/polars-plan/src/dsl/function_expr/mod.rs#L127 """ + __slots__ = ("left", "op", "right") + left: ExprIR op: Operator right: ExprIR class Cast(ExprIR): + __slots__ = ("dtype", "expr") + expr: ExprIR dtype: DType class Sort(ExprIR): + __slots__ = ("expr", "options") + expr: ExprIR options: SortOptions @@ -62,6 +76,8 @@ class Sort(ExprIR): class SortBy(ExprIR): """https://github.com/narwhals-dev/narwhals/issues/2534.""" + __slots__ = ("by", "expr", "options") + expr: ExprIR by: t.Sequence[ExprIR] options: SortMultipleOptions @@ -73,12 +89,16 @@ class FunctionExpr(ExprIR): https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 """ + __slots__ = ("function", "input", "options") + input: t.Sequence[ExprIR] function: Function options: FunctionOptions class Filter(ExprIR): + __slots__ = ("by", "expr") + expr: ExprIR by: ExprIR @@ -86,6 +106,8 @@ class Filter(ExprIR): class WindowExpr(ExprIR): 
"""https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136.""" + __slots__ = ("expr", "options", "order_by", "partition_by") + expr: ExprIR """Renamed from `function`.""" @@ -106,10 +128,14 @@ class Len(ExprIR): ... class Exclude(ExprIR): + __slots__ = ("names",) + names: t.Sequence[str] class Nth(ExprIR): + __slots__ = ("index",) + index: int @@ -121,6 +147,8 @@ class IndexColumns(ExprIR): https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L80 """ + __slots__ = ("indices",) + indices: t.Sequence[int] diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index e86f5828d9..c2912605af 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -48,6 +48,8 @@ def default() -> FunctionFlags: class FunctionOptions(Immutable): + __slots__ = ("flags",) + flags: FunctionFlags def is_elementwise(self) -> bool: @@ -99,12 +101,16 @@ def aggregation() -> FunctionOptions: return FunctionOptions.groupwise().with_flags(FunctionFlags.RETURNS_SCALAR) -# TODO @dangotbanned: spec these out +# TODO @dangotbanned: Decide on constructors class SortOptions(Immutable): + __slots__ = ("descending", "nulls_last") + descending: bool nulls_last: bool class SortMultipleOptions(Immutable): + __slots__ = ("descending", "nulls_last") + descending: Sequence[bool] nulls_last: Sequence[bool] From 3a6d61fb3d7ede1d7d9b307de0843c37a3bdf823 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 15 May 2025 21:34:16 +0100 Subject: [PATCH 023/368] docs: Align with (#2547) version --- narwhals/_plan/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index 1c803c9254..ea8952c8a0 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -5,7 +5,7 @@ - But the node does not need to be unique to the method - A chain of `Expr` methods should form a plan of operations - We must be able to enforce rules on what plans are permitted: - - Must be flexible to both eager/lazy and invdividual backends + - Must be flexible to both eager/lazy and individual backends - Must be flexible to a given context (select, with_columns, filter, group_by) - Nodes & plans are: - Immutable, but @@ -29,7 +29,7 @@ - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867331343 - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867446959 - https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2869070157 -- https://github.com/narwhals-dev/narwhals/pull/2538/commits/a7eeb0d23e67cb70e7cfa73cec2c7b69a15c8bef#r2083562677 +- (https://github.com/narwhals-dev/narwhals/pull/2538/commits/a7eeb0d23e67cb70e7cfa73cec2c7b69a15c8bef#r2083562677) - https://github.com/narwhals-dev/narwhals/issues/2225 - https://github.com/narwhals-dev/narwhals/issues/1848 - https://github.com/narwhals-dev/narwhals/issues/2534#issuecomment-2875676729 From b6e4d1e9cbb7bfcd6d4382483bf5aa770f9d7148 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 15 May 2025 21:39:41 +0100 Subject: [PATCH 024/368] docs: add todos --- narwhals/_plan/boolean.py | 2 ++ narwhals/_plan/categorical.py | 2 ++ narwhals/_plan/expr.py | 2 +- narwhals/_plan/functions.py | 2 ++ narwhals/_plan/lists.py | 2 ++ narwhals/_plan/name.py | 2 ++ narwhals/_plan/strings.py | 2 ++ narwhals/_plan/struct.py | 2 ++ 
narwhals/_plan/temporal.py | 2 ++ narwhals/_plan/window.py | 2 ++ 10 files changed, 19 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 140a30f194..98378df42a 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations # NOTE: Needed to avoid naming collisions diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index fc43b522ce..4d2751a99b 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations from narwhals._plan.common import Function diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 5d3608d0ea..29a1a01ceb 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -84,7 +84,7 @@ class SortBy(ExprIR): class FunctionExpr(ExprIR): - """Polars uses seemingly for namespacing, but maybe I'll use for traversal? + """Polars uses *seemingly* for namespacing, but maybe I'll use for traversal? https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 """ diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 2e425f9f07..972b504001 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations from narwhals._plan.expr import FunctionExpr diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index 7368d41010..7c9834185a 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations from narwhals._plan.common import Function diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 393e6e37dc..c7ef85c807 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations from narwhals._plan.common import Function diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 54524892b0..6fb41321db 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations from narwhals._plan.common import Function diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index cbcbfb3652..0dda592538 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations from narwhals._plan.common import Function diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 4b1f873412..86f0de25ad 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations from narwhals._plan.common import Function diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index bb2fe5aef1..172697182e 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -1,3 +1,5 @@ +"""TODO: Attributes.""" + from __future__ import annotations from narwhals._plan.common import ExprIR From 56007d596b7d7a1afaaee117af3a2970a408cdc8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 15 May 2025 21:41:03 +0100 Subject: [PATCH 025/368] revert: remove unused --- narwhals/_plan/common.py | 3 --- 1 file changed, 3 deletions(-) diff --git 
a/narwhals/_plan/common.py b/narwhals/_plan/common.py index e434e2b816..8676e7afc3 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -6,12 +6,9 @@ from typing import Any from typing_extensions import Never - from typing_extensions import TypeAlias from narwhals._plan.options import FunctionOptions - WindowType: TypeAlias = Any - class Immutable: __slots__ = () From 854f6b43b423e949b64b55e683f6b1248197a412 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 11:07:10 +0100 Subject: [PATCH 026/368] Mock up `ExprIR` conversion --- narwhals/_plan/common.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 8676e7afc3..d7c0e4632e 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -6,6 +6,7 @@ from typing import Any from typing_extensions import Never + from typing_extensions import Self from narwhals._plan.options import FunctionOptions @@ -41,7 +42,36 @@ def __eq__(self, other: object) -> bool: ) -class ExprIR(Immutable): ... +class ExprIR(Immutable): + """Anything that can be a node on a graph of expressions.""" + + def to_narwhals(self) -> DummyExpr: + return DummyExpr._from_ir(self) + + def to_compliant(self) -> DummyCompliantExpr: + return DummyCompliantExpr._from_ir(self) + + +# NOTE: Overly simplified placeholders for mocking typing +# Entirely ignoring namespace + function binding +class DummyExpr: + _ir: ExprIR + + @classmethod + def _from_ir(cls, ir: ExprIR, /) -> Self: + obj = cls.__new__(cls) + obj._ir = ir + return obj + + +class DummyCompliantExpr: + _ir: ExprIR + + @classmethod + def _from_ir(cls, ir: ExprIR, /) -> Self: + obj = cls.__new__(cls) + obj._ir = ir + return obj class Function(ExprIR): From 7ff6d3c130e5358c7e9967aed86c1f506a8ef0b6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 11:10:05 +0100 Subject: [PATCH 027/368] feat: Fill out `struct` --- narwhals/_plan/struct.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index 0dda592538..f30e11aa64 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -1,11 +1,19 @@ -"""TODO: Attributes.""" - from __future__ import annotations from narwhals._plan.common import Function +from narwhals._plan.options import FunctionOptions class StructFunction(Function): ... -class FieldByName(StructFunction): ... 
+class FieldByName(StructFunction): + """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/struct_.rs#L11.""" + + __slots__ = ("name",) + + name: str + + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() From f27e5901f2519ba08be57760b31186132ae4adb1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 11:13:56 +0100 Subject: [PATCH 028/368] feat: Fill out `lists` --- narwhals/_plan/lists.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index 7c9834185a..97d55c982f 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -1,11 +1,15 @@ -"""TODO: Attributes.""" - from __future__ import annotations from narwhals._plan.common import Function +from narwhals._plan.options import FunctionOptions class ListFunction(Function): ... -class Len(ListFunction): ... +class Len(ListFunction): + """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/list.rs#L32.""" + + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() From ba69ae361d6153030d0321835473cfd551177575 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 11:16:36 +0100 Subject: [PATCH 029/368] feat: Fill out `categorical` --- narwhals/_plan/categorical.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 4d2751a99b..834211c63e 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -1,11 +1,16 @@ -"""TODO: Attributes.""" - from __future__ import annotations from narwhals._plan.common import Function +from narwhals._plan.options import FunctionOptions class CategoricalFunction(Function): ... -class GetCategories(CategoricalFunction): ... 
+class GetCategories(CategoricalFunction): + """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/cat.rs#L7.""" + + @property + def function_options(self) -> FunctionOptions: + """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/cat.rs#L41.""" + return FunctionOptions.groupwise() From b87836008263ae6a21cbd8e52e505fd0f0e529f7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 11:43:27 +0100 Subject: [PATCH 030/368] feat: Fill out `boolean` Starting understand `Function` vs `FunctionExpr` a bit better now --- narwhals/_plan/boolean.py | 94 ++++++++++++++++++++++++++++++++------- narwhals/_plan/options.py | 16 ++++--- 2 files changed, 89 insertions(+), 21 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 98378df42a..4ce751facc 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -1,51 +1,113 @@ -"""TODO: Attributes.""" - from __future__ import annotations # NOTE: Needed to avoid naming collisions # - Any -import typing as t # noqa: F401 +import typing as t from narwhals._plan.common import Function +from narwhals._plan.options import FunctionFlags +from narwhals._plan.options import FunctionOptions + +if t.TYPE_CHECKING: + from narwhals.typing import ClosedInterval class BooleanFunction(Function): ... -class All(BooleanFunction): ... +class All(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.aggregation() + + +class AllHorizontal(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise().with_flags( + FunctionFlags.INPUT_WILDCARD_EXPANSION + ) + + +class Any(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.aggregation() + + +class AnyHorizontal(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise().with_flags( + FunctionFlags.INPUT_WILDCARD_EXPANSION + ) + +class IsBetween(BooleanFunction): + """`lower_bound`, `upper_bound` aren't spec'd in the function enum. -class AllHorizontal(BooleanFunction): ... + Assuming the `FunctionExpr.input` becomes `s` in the impl + https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/boolean.rs#L225-L237 + """ -class Any(BooleanFunction): ... + __slots__ = ("closed",) + closed: ClosedInterval -class AnyHorizontal(BooleanFunction): ... + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() -class IsBetween(BooleanFunction): ... +class IsDuplicated(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.length_preserving() -class IsDuplicated(BooleanFunction): ... +class IsFinite(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() -class IsFinite(BooleanFunction): ... +class IsFirstDistinct(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.length_preserving() -class IsFirstDistinct(BooleanFunction): ... +class IsIn(BooleanFunction): + """``other` isn't spec'd in the function enum. + See `IsBetween` comment. + """ -class IsIn(BooleanFunction): ... 
+ @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() -class IsLastDistinct(BooleanFunction): ... +class IsLastDistinct(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.length_preserving() -class IsNan(BooleanFunction): ... +class IsNan(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() -class IsNull(BooleanFunction): ... +class IsNull(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() -class IsUnique(BooleanFunction): ... +class IsUnique(BooleanFunction): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.length_preserving() diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index c2912605af..d0b6ec1941 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -1,8 +1,3 @@ -"""`ExprMetadata` but less god object. - -- https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs -""" - from __future__ import annotations import enum @@ -21,6 +16,12 @@ class FunctionFlags(enum.Flag): Not sure where this is disabled. """ + INPUT_WILDCARD_EXPANSION = 1 << 4 + """Appears on all the horizontal aggs. + + https://github.com/pola-rs/polars/blob/e8ad1059721410e65a3d5c1d84055fb22a4d6d43/crates/polars-plan/src/plans/options.rs#L49-L58 + """ + RETURNS_SCALAR = 1 << 5 """Automatically explode on unit length if it ran as final aggregation.""" @@ -48,6 +49,11 @@ def default() -> FunctionFlags: class FunctionOptions(Immutable): + """ExprMetadata` but less god object. + + https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs + """ + __slots__ = ("flags",) flags: FunctionFlags From db0e35d7c4a5710917297e47ffbdcee4debf53db Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 12:58:22 +0100 Subject: [PATCH 031/368] feat: Fill out `strings` - Probably going to loop back and explcitly define more input fields - At least until we accept `Expr` everywhere, these should more closely reflect `narwhals` than `polars` --- narwhals/_plan/strings.py | 69 ++++++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 6fb41321db..4ac3531d4b 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -1,30 +1,45 @@ -"""TODO: Attributes.""" - from __future__ import annotations from narwhals._plan.common import Function +from narwhals._plan.options import FunctionFlags +from narwhals._plan.options import FunctionOptions -class StringFunction(Function): ... +class StringFunction(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() class ConcatHorizontal(StringFunction): - """`concat_str`.""" + """`nw.functions.concat_str`.""" + __slots__ = ("ignore_nulls", "separator") -class Contains(StringFunction): ... + separator: str + ignore_nulls: bool + @property + def function_options(self) -> FunctionOptions: + return super().function_options.with_flags(FunctionFlags.INPUT_WILDCARD_EXPANSION) -class EndsWith(StringFunction): ... +class Contains(StringFunction): + __slots__ = ("literal",) + + literal: bool -class Head(StringFunction): ... + +class EndsWith(StringFunction): ... 
class LenChars(StringFunction): ... -class Replace(StringFunction): ... +class Replace(StringFunction): + __slots__ = ("literal",) + + literal: bool class ReplaceAll(StringFunction): @@ -33,8 +48,35 @@ class ReplaceAll(StringFunction): https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L65-L70 """ + __slots__ = ("literal",) + + literal: bool + + +class Slice(StringFunction): + """We're using for `Head`, `Tail` as well. + + https://github.com/dangotbanned/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L87-L89 + + I don't think it's likely we'll support `Expr` as inputs for this any time soon. + """ + + __slots__ = ("length", "offset") + + offset: int + length: int | None + + +class Head(StringFunction): + __slots__ = ("n",) + + n: int + -class Slice(StringFunction): ... +class Tail(StringFunction): + __slots__ = ("n",) + + n: int class Split(StringFunction): ... @@ -46,15 +88,18 @@ class StartsWith(StringFunction): ... class StripChars(StringFunction): ... -class Tail(StringFunction): ... - - class ToDatetime(StringFunction): """`polars` uses `Strptime`. + We've got a fairly different representation. + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L112 """ + __slots__ = ("format",) + + format: str | None + class ToLowercase(StringFunction): ... From 311351dda73ec074bc35dc07e0adf0152b67b8e9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 12:59:24 +0100 Subject: [PATCH 032/368] docs: Add note on nested `options` --- narwhals/_plan/expr.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 29a1a01ceb..4d8b915168 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -94,6 +94,11 @@ class FunctionExpr(ExprIR): input: t.Sequence[ExprIR] function: Function options: FunctionOptions + """Assuming this is **either**: + + 1. `function.function_options` + 2. The union of (1) and any `FunctionOptions` in `inputs` + """ class Filter(ExprIR): From e59f0eefbc1bec2d0ac286c19aaedd5c11d63c2d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 13:07:41 +0100 Subject: [PATCH 033/368] feat: Fill out `temporal` --- narwhals/_plan/temporal.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 86f0de25ad..3f50f25f46 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,11 +1,18 @@ -"""TODO: Attributes.""" - from __future__ import annotations +from typing import TYPE_CHECKING + from narwhals._plan.common import Function +from narwhals._plan.options import FunctionOptions +if TYPE_CHECKING: + from narwhals.typing import TimeUnit -class TemporalFunction(Function): ... + +class TemporalFunction(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() class Date(TemporalFunction): ... @@ -59,16 +66,31 @@ class TotalMicroseconds(TemporalFunction): ... class TotalNanoseconds(TemporalFunction): ... -class ToString(TemporalFunction): ... 
+class ToString(TemporalFunction): + __slots__ = ("format",) + + format: str + + +class ReplaceTimeZone(TemporalFunction): + __slots__ = ("time_zone",) + + time_zone: str | None + +class ConvertTimeZone(TemporalFunction): + __slots__ = ("time_zone",) -class ReplaceTimeZone(TemporalFunction): ... + time_zone: str -class ConvertTimeZone(TemporalFunction): ... +class Timestamp(TemporalFunction): + __slots__ = ("time_unit",) + time_unit: TimeUnit -class Timestamp(TemporalFunction): ... +class Truncate(TemporalFunction): + __slots__ = ("every",) -class Truncate(TemporalFunction): ... + every: str From 29741adf5471b1f5784d580dc2db6ef302a45898 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 13:15:58 +0100 Subject: [PATCH 034/368] feat: Fill out `name` --- narwhals/_plan/name.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index c7ef85c807..3d4e6dd7e0 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -1,8 +1,12 @@ -"""TODO: Attributes.""" - from __future__ import annotations +from typing import TYPE_CHECKING + from narwhals._plan.common import Function +from narwhals._plan.options import FunctionOptions + +if TYPE_CHECKING: + from narwhals._compliant.typing import AliasName class NameFunction(Function): @@ -11,17 +15,30 @@ class NameFunction(Function): https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/dsl/name.rs """ + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() + class Keep(NameFunction): ... -class Map(NameFunction): ... +class Map(NameFunction): + __slots__ = ("function",) + + function: AliasName + + +class Prefix(NameFunction): + __slots__ = ("prefix",) + prefix: str -class Prefix(NameFunction): ... +class Suffix(NameFunction): + __slots__ = ("suffix",) -class Suffix(NameFunction): ... + suffix: str class ToLowercase(NameFunction): ... From 8bdd3cddd340341f8a47d4b060832ba153333937 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 13:49:50 +0100 Subject: [PATCH 035/368] feat: Add `ExprIR.is_scalar` - Will need to investigate a lot of the others - May need subtypes for `Literal` --- narwhals/_plan/common.py | 4 ++++ narwhals/_plan/expr.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index d7c0e4632e..7b83d8371d 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -51,6 +51,10 @@ def to_narwhals(self) -> DummyExpr: def to_compliant(self) -> DummyCompliantExpr: return DummyCompliantExpr._from_ir(self) + @property + def is_scalar(self) -> bool: + return False + # NOTE: Overly simplified placeholders for mocking typing # Entirely ignoring namespace + function binding diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 4d8b915168..543038cb6d 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -23,6 +23,10 @@ class Alias(ExprIR): expr: ExprIR name: str + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + class Column(ExprIR): __slots__ = ("name",) @@ -41,6 +45,10 @@ class Literal(ExprIR): value: PythonLiteral + @property + def is_scalar(self) -> bool: + return True + class BinaryExpr(ExprIR): """Application of two exprs via an `Operator`. 
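# NOTE: rough sanity sketch for `is_scalar` propagation (not part of the diff);
# assumes the demo keyword constructor added later in this series. As of this
# commit `Literal` still wraps a plain Python value.
from narwhals._plan.expr import Alias, BinaryExpr, Column, Literal
from narwhals._plan.operators import Add

col = Column(name="a")
lit = Literal(value=1)
assert lit.is_scalar
assert not Alias(expr=col, name="b").is_scalar  # alias keeps the input's shape
assert not BinaryExpr(left=col, op=Add(), right=lit).is_scalar  # scalar only if both sides are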
@@ -58,6 +66,10 @@ class BinaryExpr(ExprIR): op: Operator right: ExprIR + @property + def is_scalar(self) -> bool: + return self.left.is_scalar and self.right.is_scalar + class Cast(ExprIR): __slots__ = ("dtype", "expr") @@ -65,6 +77,10 @@ class Cast(ExprIR): expr: ExprIR dtype: DType + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + class Sort(ExprIR): __slots__ = ("expr", "options") @@ -72,6 +88,10 @@ class Sort(ExprIR): expr: ExprIR options: SortOptions + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + class SortBy(ExprIR): """https://github.com/narwhals-dev/narwhals/issues/2534.""" @@ -82,6 +102,10 @@ class SortBy(ExprIR): by: t.Sequence[ExprIR] options: SortMultipleOptions + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + class FunctionExpr(ExprIR): """Polars uses *seemingly* for namespacing, but maybe I'll use for traversal? @@ -107,6 +131,10 @@ class Filter(ExprIR): expr: ExprIR by: ExprIR + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar and self.by.is_scalar + class WindowExpr(ExprIR): """https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136.""" @@ -129,7 +157,10 @@ class WindowExpr(ExprIR): """ -class Len(ExprIR): ... +class Len(ExprIR): + @property + def is_scalar(self) -> bool: + return True class Exclude(ExprIR): From e51eba891719a5eb1f7ce91c02a477af39c0baee Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 14:51:34 +0100 Subject: [PATCH 036/368] fix: Remove `Operator` from `ExprIR` hierarchy - Entirely separate from expressions - Realised the mistake while doing `is_scalar` --- narwhals/_plan/operators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index d11ce2cc23..2a4a904d73 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -1,9 +1,9 @@ from __future__ import annotations -from narwhals._plan.common import ExprIR +from narwhals._plan.common import Immutable -class Operator(ExprIR): ... +class Operator(Immutable): ... class Eq(Operator): ... From 2ab0f6840b88b2cff44127ff39848f01ef064b88 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 15:42:48 +0100 Subject: [PATCH 037/368] feat: Add `literal` The `Literal` -> `LiteralValue` wrapping mirrors what the other top-level `expr` nodes do --- narwhals/_plan/common.py | 3 +++ narwhals/_plan/expr.py | 8 ++++--- narwhals/_plan/literal.py | 49 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 narwhals/_plan/literal.py diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 7b83d8371d..55131dfc03 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -78,6 +78,9 @@ def _from_ir(cls, ir: ExprIR, /) -> Self: return obj +class DummySeries: ... + + class Function(ExprIR): """Shared by expr functions and namespace functions. 
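# NOTE: rough sketch of the `Literal` -> `LiteralValue` wrapping introduced in
# this commit (not part of the diff); assumes the demo keyword constructor
# added later in this series.
from narwhals._plan.expr import Literal
from narwhals._plan.literal import ScalarLiteral

lit = Literal(value=ScalarLiteral(value=1))
assert lit.is_scalar  # delegated to the wrapped `LiteralValue`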
diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 543038cb6d..81ca6ac2bf 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -8,13 +8,13 @@ if t.TYPE_CHECKING: from narwhals._plan.common import Function + from narwhals._plan.literal import LiteralValue from narwhals._plan.operators import Operator from narwhals._plan.options import FunctionOptions from narwhals._plan.options import SortMultipleOptions from narwhals._plan.options import SortOptions from narwhals._plan.window import Window from narwhals.dtypes import DType - from narwhals.typing import PythonLiteral class Alias(ExprIR): @@ -41,13 +41,15 @@ class Columns(ExprIR): class Literal(ExprIR): + """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" + __slots__ = ("value",) - value: PythonLiteral + value: LiteralValue @property def is_scalar(self) -> bool: - return True + return self.value.is_scalar class BinaryExpr(ExprIR): diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py new file mode 100644 index 0000000000..bbf3c3a093 --- /dev/null +++ b/narwhals/_plan/literal.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan.common import ExprIR + +if TYPE_CHECKING: + from narwhals._plan.common import DummySeries + from narwhals.dtypes import DType + from narwhals.typing import PythonLiteral + + +class LiteralValue(ExprIR): + """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/lit.rs#L67-L73.""" + + +class ScalarLiteral(LiteralValue): + __slots__ = ("value",) + + value: PythonLiteral + + @property + def is_scalar(self) -> bool: + return True + + +class SeriesLiteral(LiteralValue): + """We already need this. + + https://github.com/narwhals-dev/narwhals/blob/e51eba891719a5eb1f7ce91c02a477af39c0baee/narwhals/_expression_parsing.py#L96-L97 + """ + + __slots__ = ("value",) + + value: DummySeries + + +class RangeLiteral(LiteralValue): + """Don't need yet, but might push forward the discussions. + + - https://github.com/narwhals-dev/narwhals/issues/2463#issuecomment-2844654064 + - https://github.com/narwhals-dev/narwhals/issues/2307#issuecomment-2832422364. 
+ """ + + __slots__ = ("dtype", "high", "low") + + low: int + high: int + dtype: DType From bacf7ddb112d4be09e374d1761c0bb3b42d26d1c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 15:43:55 +0100 Subject: [PATCH 038/368] feat: Link `FunctionOptions.returns_scalar` and `Function.is_scalar` --- narwhals/_plan/common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 55131dfc03..0960c23105 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -92,3 +92,7 @@ def function_options(self) -> FunctionOptions: from narwhals._plan.options import FunctionOptions return FunctionOptions.default() + + @property + def is_scalar(self) -> bool: + return self.function_options.returns_scalar() From 2721a74bbafda1357bf572a73830915c3afc3545 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 15:59:36 +0100 Subject: [PATCH 039/368] chore: More clearly separate `Function` vs `FunctionExpr` --- narwhals/_plan/expr.py | 11 +++++++- narwhals/_plan/functions.py | 56 ++++++++++++++++++------------------- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 81ca6ac2bf..5aef1ee458 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -110,7 +110,9 @@ def is_scalar(self) -> bool: class FunctionExpr(ExprIR): - """Polars uses *seemingly* for namespacing, but maybe I'll use for traversal? + """**Representing `Expr::Function`**. + + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 """ @@ -119,6 +121,13 @@ class FunctionExpr(ExprIR): input: t.Sequence[ExprIR] function: Function + """Enum type is named `FunctionExpr` in `polars`. + + Mirroring *exactly* doesn't make much sense in OOP. + + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 + """ + options: FunctionOptions """Assuming this is **either**: diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 972b504001..9e2f118b9a 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -2,87 +2,87 @@ from __future__ import annotations -from narwhals._plan.expr import FunctionExpr +from narwhals._plan.common import Function -class Abs(FunctionExpr): ... +class Abs(Function): ... -class Hist(FunctionExpr): +class Hist(Function): """Only supported for `Series` so far.""" -class NullCount(FunctionExpr): ... +class NullCount(Function): ... -class Pow(FunctionExpr): ... +class Pow(Function): ... -class FillNull(FunctionExpr): ... +class FillNull(Function): ... -class FillNullWithStrategy(FunctionExpr): +class FillNullWithStrategy(Function): """We don't support this variant in a lot of backends, so worth keeping it split out.""" -class Shift(FunctionExpr): ... +class Shift(Function): ... -class DropNulls(FunctionExpr): ... +class DropNulls(Function): ... -class Mode(FunctionExpr): ... +class Mode(Function): ... -class Skew(FunctionExpr): ... +class Skew(Function): ... -class Rank(FunctionExpr): ... +class Rank(Function): ... -class Clip(FunctionExpr): ... +class Clip(Function): ... -class CumCount(FunctionExpr): ... +class CumCount(Function): ... 
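# NOTE: rough sketch of the `Function` / `FunctionExpr` split (not part of the
# diff); assumes the demo keyword constructor added later in this series.
from narwhals._plan.expr import Column, FunctionExpr
from narwhals._plan.functions import Abs

abs_ = Abs()
node = FunctionExpr(input=(Column(name="a"),), function=abs_, options=abs_.function_options)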
-class CumMin(FunctionExpr): ... +class CumMin(Function): ... -class CumMax(FunctionExpr): ... +class CumMax(Function): ... -class CumProd(FunctionExpr): ... +class CumProd(Function): ... -class Diff(FunctionExpr): ... +class Diff(Function): ... -class Unique(FunctionExpr): ... +class Unique(Function): ... -class Round(FunctionExpr): ... +class Round(Function): ... -class SumHorizontal(FunctionExpr): ... +class SumHorizontal(Function): ... -class MinHorizontal(FunctionExpr): ... +class MinHorizontal(Function): ... -class MaxHorizontal(FunctionExpr): ... +class MaxHorizontal(Function): ... -class MeanHorizontal(FunctionExpr): ... +class MeanHorizontal(Function): ... -class EwmMean(FunctionExpr): ... +class EwmMean(Function): ... -class ReplaceStrict(FunctionExpr): ... +class ReplaceStrict(Function): ... -class GatherEvery(FunctionExpr): ... +class GatherEvery(Function): ... -class MapBatches(FunctionExpr): ... +class MapBatches(Function): ... From 451931c7c3b3bfc8209ed86def510cd124a37923 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 15:59:47 +0100 Subject: [PATCH 040/368] typo --- narwhals/_plan/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 5aef1ee458..325215df76 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -189,7 +189,7 @@ class Nth(ExprIR): class IndexColumns(ExprIR): """Renamed from `IndexColumn`. - `Nth` provides the single variant. + `Nth` provides the singlular variant. https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L80 """ From 88bc70f282a80eabe1d481be5cf4069e4a903701 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 17:22:03 +0100 Subject: [PATCH 041/368] feat: Fill out a big chunk of `functions` Interesting to see another link with (#2522) --- narwhals/_plan/__init__.py | 1 + narwhals/_plan/expr.py | 15 ++++++ narwhals/_plan/functions.py | 94 +++++++++++++++++++++++++++++++++---- narwhals/_plan/options.py | 34 ++++++++++++++ 4 files changed, 134 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index ea8952c8a0..0346e1d912 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -34,6 +34,7 @@ - https://github.com/narwhals-dev/narwhals/issues/1848 - https://github.com/narwhals-dev/narwhals/issues/2534#issuecomment-2875676729 - https://github.com/narwhals-dev/narwhals/issues/2291 +- https://github.com/narwhals-dev/narwhals/issues/2522 """ from __future__ import annotations diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 325215df76..92d9ff0164 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -8,6 +8,7 @@ if t.TYPE_CHECKING: from narwhals._plan.common import Function + from narwhals._plan.functions import MapBatches from narwhals._plan.literal import LiteralValue from narwhals._plan.operators import Operator from narwhals._plan.options import FunctionOptions @@ -136,6 +137,20 @@ class FunctionExpr(ExprIR): """ +class AnonymousFunctionExpr(ExprIR): + """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" + + __slots__ = ("function", "input", "options") + + input: t.Sequence[ExprIR] + function: MapBatches + options: FunctionOptions + + @property + def is_scalar(self) -> bool: + return 
self.function.function_options.returns_scalar() + + class Filter(ExprIR): __slots__ = ("by", "expr") diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 9e2f118b9a..7d6670639c 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -2,7 +2,21 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +from narwhals._plan.common import ExprIR from narwhals._plan.common import Function +from narwhals._plan.options import FunctionFlags +from narwhals._plan.options import FunctionOptions + +if TYPE_CHECKING: + from typing import Any + from typing import Sequence + + from narwhals._plan.options import EWMOptions + from narwhals._plan.options import RankOptions + from narwhals.dtypes import DType + from narwhals.typing import FillNullStrategy class Abs(Function): ... @@ -11,6 +25,24 @@ class Abs(Function): ... class Hist(Function): """Only supported for `Series` so far.""" + __slots__ = ("include_breakpoint",) + + include_breakpoint: bool + + +class HistBins(Hist): + """Subclasses for each variant.""" + + __slots__ = (*Hist.__slots__, "bins") + + bins: Sequence[float] + + +class HistBinCount(Hist): + __slots__ = (*Hist.__slots__, "bin_count") + + bin_count: int + class NullCount(Function): ... @@ -18,12 +50,20 @@ class NullCount(Function): ... class Pow(Function): ... -class FillNull(Function): ... +class FillNull(Function): + __slots__ = ("value",) + + value: ExprIR class FillNullWithStrategy(Function): """We don't support this variant in a lot of backends, so worth keeping it split out.""" + __slots__ = ("limit", "strategy") + + strategy: FillNullStrategy + limit: int | None + class Shift(Function): ... @@ -37,22 +77,31 @@ class Mode(Function): ... class Skew(Function): ... -class Rank(Function): ... +class Rank(Function): + __slots__ = ("options",) + + options: RankOptions class Clip(Function): ... -class CumCount(Function): ... +class CumAgg(Function): + __slots__ = ("reverse",) + + reverse: bool + +class CumCount(CumAgg): ... -class CumMin(Function): ... +class CumMin(CumAgg): ... -class CumMax(Function): ... +class CumMax(CumAgg): ... -class CumProd(Function): ... + +class CumProd(CumAgg): ... class Diff(Function): ... @@ -61,7 +110,10 @@ class Diff(Function): ... class Unique(Function): ... -class Round(Function): ... +class Round(Function): + __slots__ = ("decimals",) + + decimals: int class SumHorizontal(Function): ... @@ -76,13 +128,35 @@ class MaxHorizontal(Function): ... class MeanHorizontal(Function): ... -class EwmMean(Function): ... +class EwmMean(Function): + __slots__ = ("options",) + + options: EWMOptions -class ReplaceStrict(Function): ... +class ReplaceStrict(Function): + __slots__ = ("return_dtype",) + + return_dtype: DType | type[DType] | None class GatherEvery(Function): ... -class MapBatches(Function): ... 
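# NOTE: rough sketch (not part of the diff) of how `MapBatches` (added just
# below) folds its arguments into `FunctionOptions`; assumes the demo keyword
# constructor added later in this series.
from narwhals._plan.functions import MapBatches
from narwhals._plan.options import FunctionFlags

f = MapBatches(function=lambda s: s, return_dtype=None, is_elementwise=False, returns_scalar=True)
assert f.function_options.flags & FunctionFlags.RETURNS_SCALAR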
+class MapBatches(Function): + __slots__ = ("function", "is_elementwise", "return_dtype", "returns_scalar") + + function: Any + return_dtype: DType | None + is_elementwise: bool + returns_scalar: bool + + @property + def function_options(self) -> FunctionOptions: + """https://github.com/narwhals-dev/narwhals/issues/2522.""" + options = super().function_options + if self.is_elementwise: + options = options.with_elementwise() + if self.returns_scalar: + options = options.with_flags(FunctionFlags.RETURNS_SCALAR) + return options diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index d0b6ec1941..1db9539994 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from typing import Sequence + from narwhals.typing import RankMethod + class FunctionFlags(enum.Flag): ALLOW_GROUP_AWARE = 1 << 0 @@ -120,3 +122,35 @@ class SortMultipleOptions(Immutable): descending: Sequence[bool] nulls_last: Sequence[bool] + + +class RankOptions(Immutable): + __slots__ = ("descending", "method") + + method: RankMethod + descending: bool + + +class EWMOptions(Immutable): + """Deviates from polars, since we aren't pre-computing alpha. + + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs#L14-L20 + """ + + __slots__ = ( + "adjust", + "alpha", + "com", + "half_life", + "ignore_nulls", + "min_samples", + "span", + ) + + com: float | None + span: float | None + half_life: float | None + alpha: float | None + adjust: bool + min_samples: int + ignore_nulls: bool From 2e4f1d049f504c39fee74a5becae4892314ea40c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 18:23:44 +0100 Subject: [PATCH 042/368] chore: Add dummy constructor to `Immutable` Purely for demonstrations --- narwhals/_plan/common.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 0960c23105..33beeca761 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -41,6 +41,31 @@ def __eq__(self, other: object) -> bool: getattr(self, name, empty) == getattr(other, name, empty) for name in slots ) + def __init__(self, **kwds: Any) -> None: + # NOTE: DUMMY CONSTRUCTOR - don't use beyond prototyping! + # Just need a quick way to demonstrate `ExprIR` and interactions + slots: set[str] = set(self.__slots__) + if not slots and not kwds: + # NOTE: Fastpath for empty slots + ... 
+ elif slots == set(kwds): + # NOTE: Everything is as expected + for name, value in kwds.items(): + object.__setattr__(self, name, value) + elif missing := slots.difference(kwds): + msg = ( + f"{type(self).__name__!r} requires attributes {sorted(slots)!r}, \n" + f"but missing values for {sorted(missing)!r}" + ) + raise TypeError(msg) + else: + extra = set(kwds).difference(slots) + msg = ( + f"{type(self).__name__!r} only supports attributes {sorted(slots)!r}, \n" + f"but got unknown arguments {sorted(extra)!r}" + ) + raise TypeError(msg) + class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" From e62c96d71a5f4bbea5ad1c7b65f42c1c9d3fe136 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 18:59:51 +0100 Subject: [PATCH 043/368] move `Not` --- narwhals/_plan/boolean.py | 8 ++++++++ narwhals/_plan/operators.py | 4 ---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 4ce751facc..c8e281ac3b 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -111,3 +111,11 @@ class IsUnique(BooleanFunction): @property def function_options(self) -> FunctionOptions: return FunctionOptions.length_preserving() + + +class Not(BooleanFunction): + """`__invert__`.""" + + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 2a4a904d73..e69fa19dbf 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -46,7 +46,3 @@ class And(Operator): ... class Or(Operator): ... - - -class Not(Operator): - """`__invert__`.""" From 071f8979c6640f698ad0a58cb339c960090a8d95 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 19:56:49 +0100 Subject: [PATCH 044/368] feat: Add a bunch of reprs --- narwhals/_plan/expr.py | 46 ++++++++++++++++++++++++++++++++++++- narwhals/_plan/literal.py | 8 +++++++ narwhals/_plan/operators.py | 23 ++++++++++++++++++- 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 92d9ff0164..3c171c08ae 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -28,18 +28,27 @@ class Alias(ExprIR): def is_scalar(self) -> bool: return self.expr.is_scalar + def __repr__(self) -> str: + return f"{self.expr!r}.alias({self.name!r})" + class Column(ExprIR): __slots__ = ("name",) name: str + def __repr__(self) -> str: + return f"col({self.name!r})" + class Columns(ExprIR): __slots__ = ("names",) names: t.Sequence[str] + def __repr__(self) -> str: + return f"cols({self.names!r})" + class Literal(ExprIR): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" @@ -52,6 +61,9 @@ class Literal(ExprIR): def is_scalar(self) -> bool: return self.value.is_scalar + def __repr__(self) -> str: + return f"lit({self.value!r})" + class BinaryExpr(ExprIR): """Application of two exprs via an `Operator`. 
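# NOTE: rough sketch of how the new reprs compose (not part of the diff);
# assumes the demo keyword constructor added earlier in this series.
from narwhals._plan.expr import Alias, Column

print(repr(Alias(expr=Column(name="a"), name="b")))  # col('a').alias('b')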
@@ -73,6 +85,9 @@ class BinaryExpr(ExprIR): def is_scalar(self) -> bool: return self.left.is_scalar and self.right.is_scalar + def __repr__(self) -> str: + return f"[({self.left!r}) {self.op!r} ({self.right!r})]" + class Cast(ExprIR): __slots__ = ("dtype", "expr") @@ -84,6 +99,9 @@ class Cast(ExprIR): def is_scalar(self) -> bool: return self.expr.is_scalar + def __repr__(self) -> str: + return f"{self.expr!r}.cast({self.dtype!r})" + class Sort(ExprIR): __slots__ = ("expr", "options") @@ -95,6 +113,10 @@ class Sort(ExprIR): def is_scalar(self) -> bool: return self.expr.is_scalar + def __repr__(self) -> str: + direction = "desc" if self.options.descending else "asc" + return f"{self.expr!r}.sort({direction})" + class SortBy(ExprIR): """https://github.com/narwhals-dev/narwhals/issues/2534.""" @@ -109,6 +131,9 @@ class SortBy(ExprIR): def is_scalar(self) -> bool: return self.expr.is_scalar + def __repr__(self) -> str: + return f"{self.expr!r}.sort_by(by={self.by!r}, options={self.options!r})" + class FunctionExpr(ExprIR): """**Representing `Expr::Function`**. @@ -161,6 +186,9 @@ class Filter(ExprIR): def is_scalar(self) -> bool: return self.expr.is_scalar and self.by.is_scalar + def __repr__(self) -> str: + return f"{self.expr!r}.filter({self.by!r})" + class WindowExpr(ExprIR): """https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136.""" @@ -188,6 +216,9 @@ class Len(ExprIR): def is_scalar(self) -> bool: return True + def __repr__(self) -> str: + return "len()" + class Exclude(ExprIR): __slots__ = ("names",) @@ -200,6 +231,9 @@ class Nth(ExprIR): index: int + def __repr__(self) -> str: + return f"nth({self.index})" + class IndexColumns(ExprIR): """Renamed from `IndexColumn`. @@ -213,8 +247,18 @@ class IndexColumns(ExprIR): indices: t.Sequence[int] + def __repr__(self) -> str: + return f"index_columns({self.indices!r})" + + +class All(ExprIR): + """Aka Wildcard (`pl.all()` or `pl.col("*")`). + + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L137 + """ -class All(ExprIR): ... + def __repr__(self) -> str: + return "*" # NOTE: by_dtype, matches, numeric, boolean, string, categorical, datetime, all diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index bbf3c3a093..7cd484201c 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -23,6 +23,11 @@ class ScalarLiteral(LiteralValue): def is_scalar(self) -> bool: return True + def __repr__(self) -> str: + if self.value is not None: + return f"{type(self.value).__name__}: {self.value}" + return "null" + class SeriesLiteral(LiteralValue): """We already need this. @@ -34,6 +39,9 @@ class SeriesLiteral(LiteralValue): value: DummySeries + def __repr__(self) -> str: + return "Series" + class RangeLiteral(LiteralValue): """Don't need yet, but might push forward the discussions. diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index e69fa19dbf..d1d000734a 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -3,7 +3,28 @@ from narwhals._plan.common import Immutable -class Operator(Immutable): ... 
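# NOTE: rough sketch combining the `BinaryExpr` and `Operator` reprs (not part
# of the diff); assumes the demo keyword constructor added earlier in this
# series.
from narwhals._plan.expr import BinaryExpr, Column, Literal
from narwhals._plan.literal import ScalarLiteral
from narwhals._plan.operators import Gt

expr = BinaryExpr(left=Column(name="a"), op=Gt(), right=Literal(value=ScalarLiteral(value=1)))
print(repr(expr))  # [(col('a')) > (lit(int: 1))]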
+class Operator(Immutable): + def __repr__(self) -> str: + tp = type(self) + if tp is Operator: + return "Operator" + m = { + Eq: "==", + NotEq: "!=", + Lt: "<", + LtEq: "<=", + Gt: ">", + GtEq: ">=", + Add: "+", + Sub: "-", + Multiply: "*", + TrueDivide: "/", + FloorDivide: "//", + Modulus: "%", + And: "&", + Or: "|", + } + return m[tp] class Eq(Operator): ... From ecab1d887b816396408461818afce7c42234175f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 20:27:49 +0100 Subject: [PATCH 045/368] fix: Mark al `Agg` as scalar --- narwhals/_plan/aggregation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 87e409c104..4836431150 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -13,6 +13,10 @@ class Agg(ExprIR): expr: ExprIR + @property + def is_scalar(self) -> bool: + return True + class Count(Agg): ... From 89d5d15144442ede4334181d2941a32f1a33c7ba Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 20:28:17 +0100 Subject: [PATCH 046/368] refactor: Inherit slot names --- narwhals/_plan/aggregation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 4836431150..1ae3363036 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -45,14 +45,14 @@ class NUnique(Agg): ... class Quantile(Agg): - __slots__ = ("expr", "interpolation", "quantile") + __slots__ = (*Agg.__slots__, "interpolation", "quantile") quantile: ExprIR interpolation: RollingInterpolationMethod class Std(Agg): - __slots__ = ("ddof", "expr") + __slots__ = (*Agg.__slots__, "ddof") ddof: int @@ -61,7 +61,7 @@ class Sum(Agg): ... class Var(Agg): - __slots__ = ("ddof", "expr") + __slots__ = (*Agg.__slots__, "ddof") ddof: int From 39705026bf6087b8c02dcba239791b2fa5cb44d3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 20:51:24 +0100 Subject: [PATCH 047/368] feat: Add all `functions` options --- narwhals/_plan/functions.py | 121 +++++++++++++++++++++++++++++++----- 1 file changed, 106 insertions(+), 15 deletions(-) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 7d6670639c..dd63672367 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -19,7 +19,10 @@ from narwhals.typing import FillNullStrategy -class Abs(Function): ... +class Abs(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() class Hist(Function): @@ -29,6 +32,10 @@ class Hist(Function): include_breakpoint: bool + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.groupwise() + class HistBins(Hist): """Subclasses for each variant.""" @@ -44,10 +51,16 @@ class HistBinCount(Hist): bin_count: int -class NullCount(Function): ... +class NullCount(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.aggregation() -class Pow(Function): ... 
+class Pow(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() class FillNull(Function): @@ -55,6 +68,10 @@ class FillNull(Function): value: ExprIR + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() + class FillNullWithStrategy(Function): """We don't support this variant in a lot of backends, so worth keeping it split out.""" @@ -64,17 +81,39 @@ class FillNullWithStrategy(Function): strategy: FillNullStrategy limit: int | None + @property + def function_options(self) -> FunctionOptions: + # NOTE: We don't support these strategies yet + # but might be good to encode this difference now + return ( + FunctionOptions.elementwise() + if self.strategy in {"one", "zero"} + else FunctionOptions.groupwise() + ) -class Shift(Function): ... +class Shift(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.length_preserving() -class DropNulls(Function): ... +class DropNulls(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.row_separable() -class Mode(Function): ... +class Mode(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.groupwise() -class Skew(Function): ... + +class Skew(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.aggregation() class Rank(Function): @@ -82,8 +121,15 @@ class Rank(Function): options: RankOptions + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.groupwise() + -class Clip(Function): ... +class Clip(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() class CumAgg(Function): @@ -91,6 +137,10 @@ class CumAgg(Function): reverse: bool + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.length_preserving() + class CumCount(CumAgg): ... @@ -104,10 +154,16 @@ class CumMax(CumAgg): ... class CumProd(CumAgg): ... -class Diff(Function): ... +class Diff(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.length_preserving() -class Unique(Function): ... +class Unique(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.groupwise() class Round(Function): @@ -115,17 +171,41 @@ class Round(Function): decimals: int + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() -class SumHorizontal(Function): ... +class SumHorizontal(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise().with_flags( + FunctionFlags.INPUT_WILDCARD_EXPANSION + ) -class MinHorizontal(Function): ... +class MinHorizontal(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise().with_flags( + FunctionFlags.INPUT_WILDCARD_EXPANSION + ) -class MaxHorizontal(Function): ... +class MaxHorizontal(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise().with_flags( + FunctionFlags.INPUT_WILDCARD_EXPANSION + ) -class MeanHorizontal(Function): ... 
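# NOTE: rough sketch (not part of the diff): the horizontal reductions all add
# the wildcard-expansion flag on top of elementwise; assumes the demo keyword
# constructor added earlier in this series.
from narwhals._plan.functions import SumHorizontal
from narwhals._plan.options import FunctionFlags

assert SumHorizontal().function_options.flags & FunctionFlags.INPUT_WILDCARD_EXPANSION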
+ +class MeanHorizontal(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise().with_flags( + FunctionFlags.INPUT_WILDCARD_EXPANSION + ) class EwmMean(Function): @@ -133,14 +213,25 @@ class EwmMean(Function): options: EWMOptions + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.length_preserving() + class ReplaceStrict(Function): __slots__ = ("return_dtype",) return_dtype: DType | type[DType] | None + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() + -class GatherEvery(Function): ... +class GatherEvery(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.groupwise() class MapBatches(Function): From c4cec97836f5ff0ad2c7e83c79f1be97aacbbb90 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 21:21:14 +0100 Subject: [PATCH 048/368] feat: Add even more reprs --- narwhals/_plan/aggregation.py | 8 ++++++++ narwhals/_plan/boolean.py | 23 ++++++++++++++++++++++- narwhals/_plan/categorical.py | 3 +++ narwhals/_plan/lists.py | 3 +++ narwhals/_plan/literal.py | 2 +- narwhals/_plan/struct.py | 3 +++ 6 files changed, 40 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 1ae3363036..d03034fb0a 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -17,6 +17,14 @@ class Agg(ExprIR): def is_scalar(self) -> bool: return True + def __repr__(self) -> str: + tp = type(self) + if tp in {Agg, OrderableAgg}: + return tp.__name__ + m = {ArgMin: "arg_min", ArgMax: "arg_max", NUnique: "n_unique"} + name = m.get(tp, tp.__name__.lower()) + return f"{self.expr!r}.{name}()" + class Count(Agg): ... diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index c8e281ac3b..0db8a55b66 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -12,7 +12,28 @@ from narwhals.typing import ClosedInterval -class BooleanFunction(Function): ... 
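# NOTE: rough sketch of the aggregation repr added in this commit (not part of
# the diff); assumes the demo keyword constructor added earlier in this series.
from narwhals._plan.aggregation import Sum
from narwhals._plan.expr import Column

print(repr(Sum(expr=Column(name="a"))))  # col('a').sum()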
+class BooleanFunction(Function): + def __repr__(self) -> str: + tp = type(self) + if tp is BooleanFunction: + return tp.__name__ + m = { + All: "all", + Any: "any", + AllHorizontal: "all_horizontal", + AnyHorizontal: "any_horizontal", + IsBetween: "is_between", + IsDuplicated: "is_duplicated", + IsFinite: "is_finite", + IsNan: "is_nan", + IsNull: "is_null", + IsFirstDistinct: "is_first_distinct", + IsLastDistinct: "is_last_distinct", + IsUnique: "is_unique", + IsIn: "is_in", + Not: "not", + } + return m[tp] class All(BooleanFunction): diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 834211c63e..44c2c7023c 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -14,3 +14,6 @@ class GetCategories(CategoricalFunction): def function_options(self) -> FunctionOptions: """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/cat.rs#L41.""" return FunctionOptions.groupwise() + + def __repr__(self) -> str: + return "cat.get_categories" diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index 97d55c982f..88da2223ab 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -13,3 +13,6 @@ class Len(ListFunction): @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + + def __repr__(self) -> str: + return "list.len" diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 7cd484201c..a6cebe4609 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -25,7 +25,7 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: if self.value is not None: - return f"{type(self.value).__name__}: {self.value}" + return f"{type(self.value).__name__}: {self.value!s}" return "null" diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index f30e11aa64..fa7dd3dc07 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -17,3 +17,6 @@ class FieldByName(StructFunction): @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + + def __repr__(self) -> str: + return f"struct.field_by_name({self.name!r})" From 05decf1502167c61391fe83bed5edb0bee888f79 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 16 May 2025 21:35:24 +0100 Subject: [PATCH 049/368] feat(typing): Get more pedantic bye bye mutability --- narwhals/_plan/common.py | 15 +++++++++++++++ narwhals/_plan/expr.py | 15 ++++++++------- narwhals/_plan/functions.py | 9 ++++----- narwhals/_plan/literal.py | 4 ++-- narwhals/_plan/options.py | 7 +++---- 5 files changed, 32 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 33beeca761..69e4f59d8a 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -1,16 +1,31 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import TypeVar if TYPE_CHECKING: from typing import Any + from typing import Callable from typing_extensions import Never from typing_extensions import Self + from typing_extensions import TypeAlias from narwhals._plan.options import FunctionOptions +T = TypeVar("T") + +Seq: TypeAlias = "tuple[T,...]" +"""Immutable Sequence. + +Using instead of `Sequence`, as a `list` can be passed there (can't break immutability promise). 
+""" + +Udf: TypeAlias = "Callable[[Any], Any]" +"""Placeholder for `map_batches(function=...)`.""" + + class Immutable: __slots__ = () diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 3c171c08ae..29e2b1f8db 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -8,6 +8,7 @@ if t.TYPE_CHECKING: from narwhals._plan.common import Function + from narwhals._plan.common import Seq from narwhals._plan.functions import MapBatches from narwhals._plan.literal import LiteralValue from narwhals._plan.operators import Operator @@ -44,7 +45,7 @@ def __repr__(self) -> str: class Columns(ExprIR): __slots__ = ("names",) - names: t.Sequence[str] + names: Seq[str] def __repr__(self) -> str: return f"cols({self.names!r})" @@ -124,7 +125,7 @@ class SortBy(ExprIR): __slots__ = ("by", "expr", "options") expr: ExprIR - by: t.Sequence[ExprIR] + by: Seq[ExprIR] options: SortMultipleOptions @property @@ -145,7 +146,7 @@ class FunctionExpr(ExprIR): __slots__ = ("function", "input", "options") - input: t.Sequence[ExprIR] + input: Seq[ExprIR] function: Function """Enum type is named `FunctionExpr` in `polars`. @@ -167,7 +168,7 @@ class AnonymousFunctionExpr(ExprIR): __slots__ = ("function", "input", "options") - input: t.Sequence[ExprIR] + input: Seq[ExprIR] function: MapBatches options: FunctionOptions @@ -198,7 +199,7 @@ class WindowExpr(ExprIR): expr: ExprIR """Renamed from `function`.""" - partition_by: t.Sequence[ExprIR] + partition_by: Seq[ExprIR] order_by: tuple[ExprIR, SortOptions] | None options: Window """Little confused on the nesting. @@ -223,7 +224,7 @@ def __repr__(self) -> str: class Exclude(ExprIR): __slots__ = ("names",) - names: t.Sequence[str] + names: Seq[str] class Nth(ExprIR): @@ -245,7 +246,7 @@ class IndexColumns(ExprIR): __slots__ = ("indices",) - indices: t.Sequence[int] + indices: Seq[int] def __repr__(self) -> str: return f"index_columns({self.indices!r})" diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index dd63672367..fc9b05b657 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -10,9 +10,8 @@ from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from typing import Any - from typing import Sequence - + from narwhals._plan.common import Seq + from narwhals._plan.common import Udf from narwhals._plan.options import EWMOptions from narwhals._plan.options import RankOptions from narwhals.dtypes import DType @@ -42,7 +41,7 @@ class HistBins(Hist): __slots__ = (*Hist.__slots__, "bins") - bins: Sequence[float] + bins: Seq[float] class HistBinCount(Hist): @@ -237,7 +236,7 @@ def function_options(self) -> FunctionOptions: class MapBatches(Function): __slots__ = ("function", "is_elementwise", "return_dtype", "returns_scalar") - function: Any + function: Udf return_dtype: DType | None is_elementwise: bool returns_scalar: bool diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index a6cebe4609..08d03d7cb6 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from narwhals._plan.common import DummySeries from narwhals.dtypes import DType - from narwhals.typing import PythonLiteral + from narwhals.typing import NonNestedLiteral class LiteralValue(ExprIR): @@ -17,7 +17,7 @@ class LiteralValue(ExprIR): class ScalarLiteral(LiteralValue): __slots__ = ("value",) - value: PythonLiteral + value: NonNestedLiteral @property def is_scalar(self) -> bool: diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 1db9539994..a5727efc6c 
100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -6,8 +6,7 @@ from narwhals._plan.common import Immutable if TYPE_CHECKING: - from typing import Sequence - + from narwhals._plan.common import Seq from narwhals.typing import RankMethod @@ -120,8 +119,8 @@ class SortOptions(Immutable): class SortMultipleOptions(Immutable): __slots__ = ("descending", "nulls_last") - descending: Sequence[bool] - nulls_last: Sequence[bool] + descending: Seq[bool] + nulls_last: Seq[bool] class RankOptions(Immutable): From a8bf8d28fa5764028a0aca2f44be1502c299a124 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 17 May 2025 11:34:48 +0100 Subject: [PATCH 050/368] chore(typing): IDE support for demo constructors --- narwhals/_plan/common.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 69e4f59d8a..bab3d8c5e2 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -10,8 +10,35 @@ from typing_extensions import Never from typing_extensions import Self from typing_extensions import TypeAlias + from typing_extensions import dataclass_transform from narwhals._plan.options import FunctionOptions +else: + # NOTE: This isn't important to the proposal, just wanted IDE support + # for the **temporary** constructors. + # It is interesting how much boilerplate this avoids though 🤔 + # https://docs.python.org/3/library/typing.html#typing.dataclass_transform + def dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + frozen_default: bool = False, + field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), + **kwargs: Any, + ) -> Callable[[T], T]: + def decorator(cls_or_fn: T) -> T: + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + + return decorator T = TypeVar("T") @@ -26,6 +53,7 @@ """Placeholder for `map_batches(function=...)`.""" +@dataclass_transform(kw_only_default=True, frozen_default=True) class Immutable: __slots__ = () From 6678fa6ff04ecbfc79c537a04650b2ca9bb97cd9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 17 May 2025 11:38:41 +0100 Subject: [PATCH 051/368] fix(typing): help `mypy` --- narwhals/_plan/boolean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 0db8a55b66..c92d032754 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -17,7 +17,7 @@ def __repr__(self) -> str: tp = type(self) if tp is BooleanFunction: return tp.__name__ - m = { + m: dict[type[BooleanFunction], str] = { All: "all", Any: "any", AllHorizontal: "all_horizontal", From 290d5234acfe62f41466d44aeaa019e8477a93a5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 17 May 2025 11:47:37 +0100 Subject: [PATCH 052/368] docs: Mark up `scalar_kwargs` from (#2555) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since we're already doing some extra handling for these, switching to explicit nodes that **always** contain them would be helpful 🙂 --- narwhals/_plan/__init__.py | 1 + narwhals/_plan/aggregation.py | 2 ++ 
narwhals/_plan/functions.py | 11 ++++++++++- narwhals/_plan/options.py | 2 ++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index 0346e1d912..9a5ca74e61 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -35,6 +35,7 @@ - https://github.com/narwhals-dev/narwhals/issues/2534#issuecomment-2875676729 - https://github.com/narwhals-dev/narwhals/issues/2291 - https://github.com/narwhals-dev/narwhals/issues/2522 +- https://github.com/narwhals-dev/narwhals/pull/2555 """ from __future__ import annotations diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index d03034fb0a..4d461ed04b 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -63,6 +63,7 @@ class Std(Agg): __slots__ = (*Agg.__slots__, "ddof") ddof: int + """https://github.com/narwhals-dev/narwhals/pull/2555""" class Sum(Agg): ... @@ -72,6 +73,7 @@ class Var(Agg): __slots__ = (*Agg.__slots__, "ddof") ddof: int + """https://github.com/narwhals-dev/narwhals/pull/2555""" class OrderableAgg(Agg): ... diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index fc9b05b657..fd0dcf86e9 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -73,7 +73,10 @@ def function_options(self) -> FunctionOptions: class FillNullWithStrategy(Function): - """We don't support this variant in a lot of backends, so worth keeping it split out.""" + """We don't support this variant in a lot of backends, so worth keeping it split out. + + https://github.com/narwhals-dev/narwhals/pull/2555 + """ __slots__ = ("limit", "strategy") @@ -92,6 +95,11 @@ def function_options(self) -> FunctionOptions: class Shift(Function): + __slots__ = ("n",) + + n: int + """https://github.com/narwhals-dev/narwhals/pull/2555""" + @property def function_options(self) -> FunctionOptions: return FunctionOptions.length_preserving() @@ -135,6 +143,7 @@ class CumAgg(Function): __slots__ = ("reverse",) reverse: bool + """https://github.com/narwhals-dev/narwhals/pull/2555""" @property def function_options(self) -> FunctionOptions: diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index a5727efc6c..a4dac01244 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -124,6 +124,8 @@ class SortMultipleOptions(Immutable): class RankOptions(Immutable): + """https://github.com/narwhals-dev/narwhals/pull/2555.""" + __slots__ = ("descending", "method") method: RankMethod From cc0fe061e6ccb3cb9ae42d2a988f25289656d053 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 17 May 2025 15:04:18 +0100 Subject: [PATCH 053/368] feat: Completely redo rolling - Realised I was modelling [`Expr.rolling`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.rolling.html) - (#2544) is somewhat related to the direction I was heading in before --- narwhals/_plan/expr.py | 28 ++++++++++++++++++++++++---- narwhals/_plan/functions.py | 24 ++++++++++++++++++++++++ narwhals/_plan/options.py | 19 +++++++++++++++++++ narwhals/_plan/window.py | 24 +++++++----------------- 4 files changed, 74 insertions(+), 21 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 29e2b1f8db..ea81e2f8eb 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -10,6 +10,7 @@ from narwhals._plan.common import Function from narwhals._plan.common import Seq from narwhals._plan.functions import MapBatches + from 
narwhals._plan.functions import RollingWindow from narwhals._plan.literal import LiteralValue from narwhals._plan.operators import Operator from narwhals._plan.options import FunctionOptions @@ -18,6 +19,9 @@ from narwhals._plan.window import Window from narwhals.dtypes import DType +_FunctionT = t.TypeVar("_FunctionT", bound="Function") +_RollingT = t.TypeVar("_RollingT", bound="RollingWindow") + class Alias(ExprIR): __slots__ = ("expr", "name") @@ -136,7 +140,7 @@ def __repr__(self) -> str: return f"{self.expr!r}.sort_by(by={self.by!r}, options={self.options!r})" -class FunctionExpr(ExprIR): +class FunctionExpr(ExprIR, t.Generic[_FunctionT]): """**Representing `Expr::Function`**. https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 @@ -147,7 +151,7 @@ class FunctionExpr(ExprIR): __slots__ = ("function", "input", "options") input: Seq[ExprIR] - function: Function + function: _FunctionT """Enum type is named `FunctionExpr` in `polars`. Mirroring *exactly* doesn't make much sense in OOP. @@ -163,6 +167,9 @@ class FunctionExpr(ExprIR): """ +class RollingExpr(FunctionExpr[_RollingT]): ... + + class AnonymousFunctionExpr(ExprIR): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" @@ -192,12 +199,23 @@ def __repr__(self) -> str: class WindowExpr(ExprIR): - """https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136.""" + """A fully specified `.over()`, that occured after another expression. + + I think we want variants for partitioned, ordered, both. + + Related: + - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136 + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L835-L838 + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L840-L876 + """ __slots__ = ("expr", "options", "order_by", "partition_by") expr: ExprIR - """Renamed from `function`.""" + """Renamed from `function`. + + For lazy backends, this should be the only place we allow `rolling_*`, `cum_*`. 
+ """ partition_by: Seq[ExprIR] order_by: tuple[ExprIR, SortOptions] | None @@ -206,6 +224,8 @@ class WindowExpr(ExprIR): - We don't allow choosing `WindowMapping` kinds - Haven't ventured into rolling much yet + - Turns out this is for `Expr.rolling` (not `Expr.rolling_`) + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L879-L888 Expr::Window { options: WindowType::Over(WindowMapping) } Expr::Window { options: WindowType::Rolling(RollingGroupOptions) } diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index fd0dcf86e9..c2a8a69ab9 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -14,6 +14,7 @@ from narwhals._plan.common import Udf from narwhals._plan.options import EWMOptions from narwhals._plan.options import RankOptions + from narwhals._plan.options import RollingOptionsFixedWindow from narwhals.dtypes import DType from narwhals.typing import FillNullStrategy @@ -150,6 +151,17 @@ def function_options(self) -> FunctionOptions: return FunctionOptions.length_preserving() +class RollingWindow(Function): + __slots__ = ("options",) + + options: RollingOptionsFixedWindow + + @property + def function_options(self) -> FunctionOptions: + """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/mod.rs#L1276.""" + return FunctionOptions.length_preserving() + + class CumCount(CumAgg): ... @@ -162,6 +174,18 @@ class CumMax(CumAgg): ... class CumProd(CumAgg): ... +class RollingSum(RollingWindow): ... + + +class RollingMean(RollingWindow): ... + + +class RollingVar(RollingWindow): ... + + +class RollingStd(RollingWindow): ... + + class Diff(Function): @property def function_options(self) -> FunctionOptions: diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index a4dac01244..0517cf82bc 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -155,3 +155,22 @@ class EWMOptions(Immutable): adjust: bool min_samples: int ignore_nulls: bool + + +class RollingVarParams(Immutable): + __slots__ = ("ddof",) + + ddof: int + + +class RollingOptionsFixedWindow(Immutable): + """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-core/src/chunked_array/ops/rolling_window.rs#L10-L23.""" + + __slots__ = ("center", "fn_params", "min_samples", "window_size") + + window_size: int + min_samples: int + """Renamed from `min_periods`, re-uses `window_size` if null.""" + + center: bool + fn_params: RollingVarParams | None diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 172697182e..61097efe2d 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -1,4 +1,4 @@ -"""TODO: Attributes.""" +"""TODO: Figure out what `Over` should be holding or skip it and go straight to `WindowExpr`.""" from __future__ import annotations @@ -12,19 +12,9 @@ class Window(ExprIR): """ -class OverWindow(Window): ... - - -class RollingWindow(Window): ... - - -class RollingSum(RollingWindow): ... - - -class RollingMean(RollingWindow): ... - - -class RollingVar(RollingWindow): ... - - -class RollingStd(RollingWindow): ... +# TODO @dangotbanned: What are all the variants we have code paths for? +# - Over has *at least* (partition_by,), (order_by,), (partition_by, order_by), + options +# - `_plan.expr.WindowExpr` has: +# - expr (last node) +# - partition_by, optional order_by, `options` which is one of these classes? +class Over(Window): ... 
From 39bdaa9c934d3e3cb612512aaf89868cc3bad0ab Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 17 May 2025 18:44:18 +0100 Subject: [PATCH 054/368] feat: Mock up more of the entry api --- narwhals/_plan/common.py | 50 +++++++++++++++++++++ narwhals/_plan/demo.py | 94 +++++++++++++++++++++++++++++++++++++++ narwhals/_plan/literal.py | 3 +- 3 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 narwhals/_plan/demo.py diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index bab3d8c5e2..f213ececed 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -13,6 +13,8 @@ from typing_extensions import dataclass_transform from narwhals._plan.options import FunctionOptions + from narwhals.dtypes import DType + else: # NOTE: This isn't important to the proposal, just wanted IDE support # for the **temporary** constructors. @@ -135,6 +137,54 @@ def _from_ir(cls, ir: ExprIR, /) -> Self: obj._ir = ir return obj + def alias(self, name: str) -> Self: + from narwhals._plan.expr import Alias + + return self._from_ir(Alias(expr=self._ir, name=name)) + + def cast(self, dtype: DType | type[DType]) -> Self: + from narwhals._plan.expr import Cast + from narwhals.dtypes import DType + from narwhals.dtypes import Unknown + + dtype = dtype if isinstance(dtype, DType) else Unknown() + return self._from_ir(Cast(expr=self._ir, dtype=dtype)) + + def count(self) -> Self: + from narwhals._plan.aggregation import Count + + return self._from_ir(Count(expr=self._ir)) + + def max(self) -> Self: + from narwhals._plan.aggregation import Max + + return self._from_ir(Max(expr=self._ir)) + + def mean(self) -> Self: + from narwhals._plan.aggregation import Mean + + return self._from_ir(Mean(expr=self._ir)) + + def min(self) -> Self: + from narwhals._plan.aggregation import Min + + return self._from_ir(Min(expr=self._ir)) + + def median(self) -> Self: + from narwhals._plan.aggregation import Median + + return self._from_ir(Median(expr=self._ir)) + + def n_unique(self) -> Self: + from narwhals._plan.aggregation import NUnique + + return self._from_ir(NUnique(expr=self._ir)) + + def sum(self) -> Self: + from narwhals._plan.aggregation import Sum + + return self._from_ir(Sum(expr=self._ir)) + class DummyCompliantExpr: _ir: ExprIR diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py new file mode 100644 index 0000000000..dc784229c2 --- /dev/null +++ b/narwhals/_plan/demo.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import builtins as bltns +import typing as t + +from narwhals._plan.expr import All +from narwhals._plan.expr import Column +from narwhals._plan.expr import Columns +from narwhals._plan.expr import FunctionExpr +from narwhals._plan.expr import IndexColumns +from narwhals._plan.expr import Len +from narwhals._plan.expr import Literal +from narwhals._plan.expr import Nth +from narwhals._plan.functions import SumHorizontal +from narwhals._plan.literal import ScalarLiteral +from narwhals._plan.options import FunctionOptions + +if t.TYPE_CHECKING: + from narwhals._plan.common import DummyExpr + from narwhals.dtypes import DType + from narwhals.typing import NonNestedLiteral + + +def col(*names: str | t.Iterable[str]) -> DummyExpr: + from narwhals.utils import flatten + + flat_names = tuple(flatten(names)) + node = ( + Column(name=flat_names[0]) + if bltns.len(flat_names) == 1 + else Columns(names=flat_names) + ) + return node.to_narwhals() + + +def nth(*indices: int | t.Sequence[int]) -> DummyExpr: + from 
narwhals.utils import flatten + + flat_indices = tuple(flatten(indices)) + node = ( + Nth(index=flat_indices[0]) + if bltns.len(flat_indices) == 1 + else IndexColumns(indices=flat_indices) + ) + return node.to_narwhals() + + +def lit(value: NonNestedLiteral, dtype: DType | type[DType] | None = None) -> DummyExpr: + from narwhals.dtypes import DType + from narwhals.dtypes import Unknown + + if dtype is None or not isinstance(dtype, DType): + dtype = Unknown() + return Literal(value=ScalarLiteral(value=value, dtype=dtype)).to_narwhals() + + +def len() -> DummyExpr: + return Len().to_narwhals() + + +def all() -> DummyExpr: + return All().to_narwhals() + + +def max(*columns: str) -> DummyExpr: + return col(columns).max() + + +def mean(*columns: str) -> DummyExpr: + return col(columns).mean() + + +def min(*columns: str) -> DummyExpr: + return col(columns).min() + + +def median(*columns: str) -> DummyExpr: + return col(columns).median() + + +def sum(*columns: str) -> DummyExpr: + return col(columns).sum() + + +def sum_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: + from narwhals.utils import flatten + + flat_exprs = tuple(flatten(exprs)) + # NOTE: Still need to figure out how these should be generated + # Feel like it should be the union of `input` & `function` + PLACEHOLDER = FunctionOptions.default() # noqa: N806 + return FunctionExpr( + input=flat_exprs, function=SumHorizontal(), options=PLACEHOLDER + ).to_narwhals() diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 08d03d7cb6..56e5dbfbd0 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -15,9 +15,10 @@ class LiteralValue(ExprIR): class ScalarLiteral(LiteralValue): - __slots__ = ("value",) + __slots__ = ("dtype", "value") value: NonNestedLiteral + dtype: DType @property def is_scalar(self) -> bool: From f2ee6889aa254907ea762b1e600ae6a14511657c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 17 May 2025 19:44:22 +0100 Subject: [PATCH 055/368] chore: Tweak reprs --- narwhals/_plan/common.py | 3 +++ narwhals/_plan/expr.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index f213ececed..61a3237d30 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -131,6 +131,9 @@ def is_scalar(self) -> bool: class DummyExpr: _ir: ExprIR + def __repr__(self) -> str: + return f"Narwhals DummyExpr:\n{self._ir!r}" + @classmethod def _from_ir(cls, ir: ExprIR, /) -> Self: obj = cls.__new__(cls) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index ea81e2f8eb..75aad29b80 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -52,7 +52,7 @@ class Columns(ExprIR): names: Seq[str] def __repr__(self) -> str: - return f"cols({self.names!r})" + return f"cols({list(self.names)!r})" class Literal(ExprIR): From db19822200463d200a0492d46f1ab0b6421ddd63 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 12:00:30 +0100 Subject: [PATCH 056/368] feat: Build out `Series`/`Literal` support some more Gets us closer to the python side of `polars` --- narwhals/_plan/common.py | 46 ++++++++++++++++++++++++++++++++++++++- narwhals/_plan/demo.py | 12 +++++++--- narwhals/_plan/expr.py | 4 ++++ narwhals/_plan/literal.py | 14 ++++++++++++ 4 files changed, 72 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 61a3237d30..5f4ef3344f 100644 
--- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -14,6 +14,7 @@ from narwhals._plan.options import FunctionOptions from narwhals.dtypes import DType + from narwhals.typing import NativeSeries else: # NOTE: This isn't important to the proposal, just wanted IDE support @@ -199,7 +200,50 @@ def _from_ir(cls, ir: ExprIR, /) -> Self: return obj -class DummySeries: ... +class DummySeries: + _compliant: DummyCompliantSeries + + @property + def dtype(self) -> DType: + return self._compliant.dtype + + @property + def name(self) -> str: + return self._compliant.name + + @classmethod + def from_native(cls, native: NativeSeries, /) -> Self: + obj = cls.__new__(cls) + obj._compliant = DummyCompliantSeries.from_native(native) + return obj + + +class DummyCompliantSeries: + _native: NativeSeries + _name: str + + @property + def dtype(self) -> DType: + from narwhals.dtypes import Float64 + + return Float64() + + @property + def name(self) -> str: + return self._name + + @classmethod + def from_native(cls, native: NativeSeries, /) -> Self: + from narwhals.utils import _hasattr_static + + name: str = "" + + if _hasattr_static(native, "name"): + name = getattr(native, "name", name) + obj = cls.__new__(cls) + obj._native = native + obj._name = name + return obj class Function(ExprIR): diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index dc784229c2..3a6d778b75 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,16 +3,17 @@ import builtins as bltns import typing as t +from narwhals._plan.common import DummySeries from narwhals._plan.expr import All from narwhals._plan.expr import Column from narwhals._plan.expr import Columns from narwhals._plan.expr import FunctionExpr from narwhals._plan.expr import IndexColumns from narwhals._plan.expr import Len -from narwhals._plan.expr import Literal from narwhals._plan.expr import Nth from narwhals._plan.functions import SumHorizontal from narwhals._plan.literal import ScalarLiteral +from narwhals._plan.literal import SeriesLiteral from narwhals._plan.options import FunctionOptions if t.TYPE_CHECKING: @@ -45,13 +46,18 @@ def nth(*indices: int | t.Sequence[int]) -> DummyExpr: return node.to_narwhals() -def lit(value: NonNestedLiteral, dtype: DType | type[DType] | None = None) -> DummyExpr: +def lit( + value: NonNestedLiteral | DummySeries, dtype: DType | type[DType] | None = None +) -> DummyExpr: from narwhals.dtypes import DType from narwhals.dtypes import Unknown + if isinstance(value, DummySeries): + return SeriesLiteral(value=value).to_literal().to_narwhals() + if dtype is None or not isinstance(dtype, DType): dtype = Unknown() - return Literal(value=ScalarLiteral(value=value, dtype=dtype)).to_narwhals() + return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() def len() -> DummyExpr: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 75aad29b80..e02dd3c129 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -66,6 +66,10 @@ class Literal(ExprIR): def is_scalar(self) -> bool: return self.value.is_scalar + @property + def dtype(self) -> DType: + return self.value.dtype + def __repr__(self) -> str: return f"lit({self.value!r})" diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 56e5dbfbd0..226887a2fb 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from narwhals._plan.common import DummySeries + from narwhals._plan.expr import Literal from narwhals.dtypes import DType from narwhals.typing import 
NonNestedLiteral @@ -13,6 +14,15 @@ class LiteralValue(ExprIR): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/lit.rs#L67-L73.""" + @property + def dtype(self) -> DType: + raise NotImplementedError + + def to_literal(self) -> Literal: + from narwhals._plan.expr import Literal + + return Literal(value=self) + class ScalarLiteral(LiteralValue): __slots__ = ("dtype", "value") @@ -40,6 +50,10 @@ class SeriesLiteral(LiteralValue): value: DummySeries + @property + def dtype(self) -> DType: + return self.value.dtype + def __repr__(self) -> str: return "Series" From 51184401173c3b37784b50d8e8f869b770ed7cc4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 12:01:28 +0100 Subject: [PATCH 057/368] revert: Don't alias `builtins` --- narwhals/_plan/demo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 3a6d778b75..1865a950e9 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -1,6 +1,6 @@ from __future__ import annotations -import builtins as bltns +import builtins import typing as t from narwhals._plan.common import DummySeries @@ -28,7 +28,7 @@ def col(*names: str | t.Iterable[str]) -> DummyExpr: flat_names = tuple(flatten(names)) node = ( Column(name=flat_names[0]) - if bltns.len(flat_names) == 1 + if builtins.len(flat_names) == 1 else Columns(names=flat_names) ) return node.to_narwhals() @@ -40,7 +40,7 @@ def nth(*indices: int | t.Sequence[int]) -> DummyExpr: flat_indices = tuple(flatten(indices)) node = ( Nth(index=flat_indices[0]) - if bltns.len(flat_indices) == 1 + if builtins.len(flat_indices) == 1 else IndexColumns(indices=flat_indices) ) return node.to_narwhals() From 19350e384106438d888355e6a3f25aea5e8f0ca7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 13:48:20 +0100 Subject: [PATCH 058/368] refactor: Organize for fewer inline imports `dummy` and `demo` are temporary anyway --- narwhals/_plan/common.py | 128 ++------------------------------------ narwhals/_plan/demo.py | 18 ++---- narwhals/_plan/dummy.py | 115 ++++++++++++++++++++++++++++++++++ narwhals/_plan/literal.py | 2 +- 4 files changed, 127 insertions(+), 136 deletions(-) create mode 100644 narwhals/_plan/dummy.py diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 5f4ef3344f..878e71b212 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -8,13 +8,12 @@ from typing import Callable from typing_extensions import Never - from typing_extensions import Self from typing_extensions import TypeAlias from typing_extensions import dataclass_transform + from narwhals._plan.dummy import DummyCompliantExpr + from narwhals._plan.dummy import DummyExpr from narwhals._plan.options import FunctionOptions - from narwhals.dtypes import DType - from narwhals.typing import NativeSeries else: # NOTE: This isn't important to the proposal, just wanted IDE support @@ -117,9 +116,13 @@ class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" def to_narwhals(self) -> DummyExpr: + from narwhals._plan.dummy import DummyExpr + return DummyExpr._from_ir(self) def to_compliant(self) -> DummyCompliantExpr: + from narwhals._plan.dummy import DummyCompliantExpr + return DummyCompliantExpr._from_ir(self) @property @@ -127,125 +130,6 @@ def is_scalar(self) -> bool: return False -# NOTE: Overly simplified 
placeholders for mocking typing -# Entirely ignoring namespace + function binding -class DummyExpr: - _ir: ExprIR - - def __repr__(self) -> str: - return f"Narwhals DummyExpr:\n{self._ir!r}" - - @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Self: - obj = cls.__new__(cls) - obj._ir = ir - return obj - - def alias(self, name: str) -> Self: - from narwhals._plan.expr import Alias - - return self._from_ir(Alias(expr=self._ir, name=name)) - - def cast(self, dtype: DType | type[DType]) -> Self: - from narwhals._plan.expr import Cast - from narwhals.dtypes import DType - from narwhals.dtypes import Unknown - - dtype = dtype if isinstance(dtype, DType) else Unknown() - return self._from_ir(Cast(expr=self._ir, dtype=dtype)) - - def count(self) -> Self: - from narwhals._plan.aggregation import Count - - return self._from_ir(Count(expr=self._ir)) - - def max(self) -> Self: - from narwhals._plan.aggregation import Max - - return self._from_ir(Max(expr=self._ir)) - - def mean(self) -> Self: - from narwhals._plan.aggregation import Mean - - return self._from_ir(Mean(expr=self._ir)) - - def min(self) -> Self: - from narwhals._plan.aggregation import Min - - return self._from_ir(Min(expr=self._ir)) - - def median(self) -> Self: - from narwhals._plan.aggregation import Median - - return self._from_ir(Median(expr=self._ir)) - - def n_unique(self) -> Self: - from narwhals._plan.aggregation import NUnique - - return self._from_ir(NUnique(expr=self._ir)) - - def sum(self) -> Self: - from narwhals._plan.aggregation import Sum - - return self._from_ir(Sum(expr=self._ir)) - - -class DummyCompliantExpr: - _ir: ExprIR - - @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Self: - obj = cls.__new__(cls) - obj._ir = ir - return obj - - -class DummySeries: - _compliant: DummyCompliantSeries - - @property - def dtype(self) -> DType: - return self._compliant.dtype - - @property - def name(self) -> str: - return self._compliant.name - - @classmethod - def from_native(cls, native: NativeSeries, /) -> Self: - obj = cls.__new__(cls) - obj._compliant = DummyCompliantSeries.from_native(native) - return obj - - -class DummyCompliantSeries: - _native: NativeSeries - _name: str - - @property - def dtype(self) -> DType: - from narwhals.dtypes import Float64 - - return Float64() - - @property - def name(self) -> str: - return self._name - - @classmethod - def from_native(cls, native: NativeSeries, /) -> Self: - from narwhals.utils import _hasattr_static - - name: str = "" - - if _hasattr_static(native, "name"): - name = getattr(native, "name", name) - obj = cls.__new__(cls) - obj._native = native - obj._name = name - return obj - - class Function(ExprIR): """Shared by expr functions and namespace functions. 
diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 1865a950e9..26adf2e53e 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,7 +3,7 @@ import builtins import typing as t -from narwhals._plan.common import DummySeries +from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import All from narwhals._plan.expr import Column from narwhals._plan.expr import Columns @@ -15,16 +15,16 @@ from narwhals._plan.literal import ScalarLiteral from narwhals._plan.literal import SeriesLiteral from narwhals._plan.options import FunctionOptions +from narwhals.dtypes import DType +from narwhals.dtypes import Unknown +from narwhals.utils import flatten if t.TYPE_CHECKING: - from narwhals._plan.common import DummyExpr - from narwhals.dtypes import DType + from narwhals._plan.dummy import DummyExpr from narwhals.typing import NonNestedLiteral def col(*names: str | t.Iterable[str]) -> DummyExpr: - from narwhals.utils import flatten - flat_names = tuple(flatten(names)) node = ( Column(name=flat_names[0]) @@ -35,8 +35,6 @@ def col(*names: str | t.Iterable[str]) -> DummyExpr: def nth(*indices: int | t.Sequence[int]) -> DummyExpr: - from narwhals.utils import flatten - flat_indices = tuple(flatten(indices)) node = ( Nth(index=flat_indices[0]) @@ -49,12 +47,8 @@ def nth(*indices: int | t.Sequence[int]) -> DummyExpr: def lit( value: NonNestedLiteral | DummySeries, dtype: DType | type[DType] | None = None ) -> DummyExpr: - from narwhals.dtypes import DType - from narwhals.dtypes import Unknown - if isinstance(value, DummySeries): return SeriesLiteral(value=value).to_literal().to_narwhals() - if dtype is None or not isinstance(dtype, DType): dtype = Unknown() return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() @@ -89,8 +83,6 @@ def sum(*columns: str) -> DummyExpr: def sum_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: - from narwhals.utils import flatten - flat_exprs = tuple(flatten(exprs)) # NOTE: Still need to figure out how these should be generated # Feel like it should be the union of `input` & `function` diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py new file mode 100644 index 0000000000..79a6ddcac6 --- /dev/null +++ b/narwhals/_plan/dummy.py @@ -0,0 +1,115 @@ +"""Mock version of current narwhals API.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan import aggregation as agg +from narwhals._plan import expr +from narwhals.dtypes import DType +from narwhals.dtypes import Unknown + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan.common import ExprIR + from narwhals.typing import NativeSeries + + +# NOTE: Overly simplified placeholders for mocking typing +# Entirely ignoring namespace + function binding +class DummyExpr: + _ir: ExprIR + + def __repr__(self) -> str: + return f"Narwhals DummyExpr:\n{self._ir!r}" + + @classmethod + def _from_ir(cls, ir: ExprIR, /) -> Self: + obj = cls.__new__(cls) + obj._ir = ir + return obj + + def alias(self, name: str) -> Self: + return self._from_ir(expr.Alias(expr=self._ir, name=name)) + + def cast(self, dtype: DType | type[DType]) -> Self: + dtype = dtype if isinstance(dtype, DType) else Unknown() + return self._from_ir(expr.Cast(expr=self._ir, dtype=dtype)) + + def count(self) -> Self: + return self._from_ir(agg.Count(expr=self._ir)) + + def max(self) -> Self: + return self._from_ir(agg.Max(expr=self._ir)) + + def mean(self) -> Self: + return self._from_ir(agg.Mean(expr=self._ir)) + + def 
min(self) -> Self: + return self._from_ir(agg.Min(expr=self._ir)) + + def median(self) -> Self: + return self._from_ir(agg.Median(expr=self._ir)) + + def n_unique(self) -> Self: + return self._from_ir(agg.NUnique(expr=self._ir)) + + def sum(self) -> Self: + return self._from_ir(agg.Sum(expr=self._ir)) + + +class DummyCompliantExpr: + _ir: ExprIR + + @classmethod + def _from_ir(cls, ir: ExprIR, /) -> Self: + obj = cls.__new__(cls) + obj._ir = ir + return obj + + +class DummySeries: + _compliant: DummyCompliantSeries + + @property + def dtype(self) -> DType: + return self._compliant.dtype + + @property + def name(self) -> str: + return self._compliant.name + + @classmethod + def from_native(cls, native: NativeSeries, /) -> Self: + obj = cls.__new__(cls) + obj._compliant = DummyCompliantSeries.from_native(native) + return obj + + +class DummyCompliantSeries: + _native: NativeSeries + _name: str + + @property + def dtype(self) -> DType: + from narwhals.dtypes import Float64 + + return Float64() + + @property + def name(self) -> str: + return self._name + + @classmethod + def from_native(cls, native: NativeSeries, /) -> Self: + from narwhals.utils import _hasattr_static + + name: str = "" + + if _hasattr_static(native, "name"): + name = getattr(native, "name", name) + obj = cls.__new__(cls) + obj._native = native + obj._name = name + return obj diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 226887a2fb..2a7ec58b1c 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -5,7 +5,7 @@ from narwhals._plan.common import ExprIR if TYPE_CHECKING: - from narwhals._plan.common import DummySeries + from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import Literal from narwhals.dtypes import DType from narwhals.typing import NonNestedLiteral From 812860ec6d661fb1071f372833b61e0d4822a2a2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 14:08:43 +0100 Subject: [PATCH 059/368] feat: Support binary ops --- narwhals/_plan/dummy.py | 53 +++++++++++++++++++++++++++++++++++++ narwhals/_plan/operators.py | 11 ++++++++ 2 files changed, 64 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 79a6ddcac6..570a742d7d 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -6,6 +6,7 @@ from narwhals._plan import aggregation as agg from narwhals._plan import expr +from narwhals._plan import operators as ops from narwhals.dtypes import DType from narwhals.dtypes import Unknown @@ -58,6 +59,58 @@ def n_unique(self) -> Self: def sum(self) -> Self: return self._from_ir(agg.Sum(expr=self._ir)) + def __ne__(self, other: DummyExpr) -> DummyExpr: + op = ops.NotEq() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __lt__(self, other: DummyExpr) -> DummyExpr: + op = ops.Lt() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __le__(self, other: DummyExpr) -> DummyExpr: + op = ops.LtEq() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __gt__(self, other: DummyExpr) -> DummyExpr: + op = ops.Gt() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __ge__(self, other: DummyExpr) -> DummyExpr: + op = ops.GtEq() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __add__(self, other: DummyExpr) -> DummyExpr: + op = ops.Add() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __sub__(self, other: DummyExpr) -> DummyExpr: + op = ops.Sub() + return 
op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __mul__(self, other: DummyExpr) -> DummyExpr: + op = ops.Multiply() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __truediv__(self, other: DummyExpr) -> DummyExpr: + op = ops.TrueDivide() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __floordiv__(self, other: DummyExpr) -> DummyExpr: + op = ops.FloorDivide() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __mod__(self, other: DummyExpr) -> DummyExpr: + op = ops.Modulus() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __and__(self, other: DummyExpr) -> DummyExpr: + op = ops.And() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __or__(self, other: DummyExpr) -> DummyExpr: + op = ops.Or() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + class DummyCompliantExpr: _ir: ExprIR diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index d1d000734a..fcdab07101 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -1,5 +1,11 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from narwhals._plan.expr import BinaryExpr + +from narwhals._plan.common import ExprIR from narwhals._plan.common import Immutable @@ -26,6 +32,11 @@ def __repr__(self) -> str: } return m[tp] + def to_binary_expr(self, left: ExprIR, right: ExprIR, /) -> BinaryExpr: + from narwhals._plan.expr import BinaryExpr + + return BinaryExpr(left=left, op=self, right=right) + class Eq(Operator): ... From de9739f98832963df6d92a5cd8db27993884572a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 14:18:52 +0100 Subject: [PATCH 060/368] feat: Highly sugar `FunctionExpr` Realised I need to use it for `__invert__`/ `Not`, which is a unary expr --- narwhals/_plan/common.py | 11 +++++++++++ narwhals/_plan/demo.py | 11 ++--------- narwhals/_plan/expr.py | 6 ++++++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 878e71b212..cf99c1c4e5 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -8,11 +8,13 @@ from typing import Callable from typing_extensions import Never + from typing_extensions import Self from typing_extensions import TypeAlias from typing_extensions import dataclass_transform from narwhals._plan.dummy import DummyCompliantExpr from narwhals._plan.dummy import DummyExpr + from narwhals._plan.expr import FunctionExpr from narwhals._plan.options import FunctionOptions else: @@ -145,3 +147,12 @@ def function_options(self) -> FunctionOptions: @property def is_scalar(self) -> bool: return self.function_options.returns_scalar() + + def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: + from narwhals._plan.expr import FunctionExpr + from narwhals._plan.options import FunctionOptions + + # NOTE: Still need to figure out how these should be generated + # Feel like it should be the union of `input` & `function` + PLACEHOLDER = FunctionOptions.default() # noqa: N806 + return FunctionExpr(input=inputs, function=self, options=PLACEHOLDER) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 26adf2e53e..cbe215359e 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -7,14 +7,12 @@ from narwhals._plan.expr import All from narwhals._plan.expr import Column from narwhals._plan.expr import Columns -from narwhals._plan.expr import 
FunctionExpr from narwhals._plan.expr import IndexColumns from narwhals._plan.expr import Len from narwhals._plan.expr import Nth from narwhals._plan.functions import SumHorizontal from narwhals._plan.literal import ScalarLiteral from narwhals._plan.literal import SeriesLiteral -from narwhals._plan.options import FunctionOptions from narwhals.dtypes import DType from narwhals.dtypes import Unknown from narwhals.utils import flatten @@ -83,10 +81,5 @@ def sum(*columns: str) -> DummyExpr: def sum_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: - flat_exprs = tuple(flatten(exprs)) - # NOTE: Still need to figure out how these should be generated - # Feel like it should be the union of `input` & `function` - PLACEHOLDER = FunctionOptions.default() # noqa: N806 - return FunctionExpr( - input=flat_exprs, function=SumHorizontal(), options=PLACEHOLDER - ).to_narwhals() + it = (expr._ir for expr in flatten(exprs)) + return SumHorizontal().to_function_expr(*it).to_narwhals() diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index e02dd3c129..e6c1bcfdc9 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -7,6 +7,8 @@ from narwhals._plan.common import ExprIR if t.TYPE_CHECKING: + from typing_extensions import Self + from narwhals._plan.common import Function from narwhals._plan.common import Seq from narwhals._plan.functions import MapBatches @@ -170,6 +172,10 @@ class FunctionExpr(ExprIR, t.Generic[_FunctionT]): 2. The union of (1) and any `FunctionOptions` in `inputs` """ + def with_options(self, options: FunctionOptions, /) -> Self: + options = self.options.with_flags(options.flags) + return type(self)(input=self.input, function=self.function, options=options) + class RollingExpr(FunctionExpr[_RollingT]): ... From 629ece24723b0f3a76af383042826cc1eb13e144 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 14:20:24 +0100 Subject: [PATCH 061/368] feat: Add missing `__eq__` --- narwhals/_plan/dummy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 570a742d7d..1879860c49 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -59,7 +59,11 @@ def n_unique(self) -> Self: def sum(self) -> Self: return self._from_ir(agg.Sum(expr=self._ir)) - def __ne__(self, other: DummyExpr) -> DummyExpr: + def __eq__(self, other: DummyExpr) -> DummyExpr: # type: ignore[override] + op = ops.Eq() + return op.to_binary_expr(self._ir, other._ir).to_narwhals() + + def __ne__(self, other: DummyExpr) -> DummyExpr: # type: ignore[override] op = ops.NotEq() return op.to_binary_expr(self._ir, other._ir).to_narwhals() From ece4f4d4b7cc666e3764bd127d2c87a5922e9e73 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 14:23:19 +0100 Subject: [PATCH 062/368] feat: Add `__invert__` --- narwhals/_plan/dummy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 1879860c49..bb06d0f31a 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from narwhals._plan import aggregation as agg +from narwhals._plan import boolean from narwhals._plan import expr from narwhals._plan import operators as ops from narwhals.dtypes import DType @@ -115,6 +116,9 @@ def __or__(self, other: DummyExpr) -> DummyExpr: op = ops.Or() return op.to_binary_expr(self._ir, 
other._ir).to_narwhals() + def __invert__(self) -> DummyExpr: + return boolean.Not().to_function_expr().to_narwhals() + class DummyCompliantExpr: _ir: ExprIR From bd1001067b88f56dbb9f7bcdbcbee441bfa80436 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:04:24 +0100 Subject: [PATCH 063/368] feat: `FunctionExpr` repr, fix `Not` --- narwhals/_plan/common.py | 6 ++++++ narwhals/_plan/dummy.py | 2 +- narwhals/_plan/expr.py | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index cf99c1c4e5..b28bc2ec8b 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -88,6 +88,12 @@ def __eq__(self, other: object) -> bool: getattr(self, name, empty) == getattr(other, name, empty) for name in slots ) + def __str__(self) -> str: + # NOTE: Debug repr, closer to constructor + slots: tuple[str, ...] = self.__slots__ + fields = ", ".join(f"{name}={getattr(self, name)}" for name in slots) + return f"{type(self).__name__}({fields})" + def __init__(self, **kwds: Any) -> None: # NOTE: DUMMY CONSTRUCTOR - don't use beyond prototyping! # Just need a quick way to demonstrate `ExprIR` and interactions diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index bb06d0f31a..0dfa6efdb1 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -117,7 +117,7 @@ def __or__(self, other: DummyExpr) -> DummyExpr: return op.to_binary_expr(self._ir, other._ir).to_narwhals() def __invert__(self) -> DummyExpr: - return boolean.Not().to_function_expr().to_narwhals() + return boolean.Not().to_function_expr(self._ir).to_narwhals() class DummyCompliantExpr: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index e6c1bcfdc9..632a447a43 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -176,6 +176,15 @@ def with_options(self, options: FunctionOptions, /) -> Self: options = self.options.with_flags(options.flags) return type(self)(input=self.input, function=self.function, options=options) + def __repr__(self) -> str: + if self.input: + first = self.input[0] + if len(self.input) >= 2: + return f"{first!r}.{self.function!r}({list(self.input[1:])!r})" + return f"{first!r}.{self.function!r}()" + else: + return f"{self.function!r}()" + class RollingExpr(FunctionExpr[_RollingT]): ... 
From b0a04f5dfd38654904cd97ef93f6a8c7bb64a08a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:14:10 +0100 Subject: [PATCH 064/368] feat: Add the other horizontals --- narwhals/_plan/demo.py | 30 ++++++++++++++++++++++++++++-- narwhals/_plan/functions.py | 15 +++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index cbe215359e..476439020a 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,6 +3,8 @@ import builtins import typing as t +from narwhals._plan import boolean +from narwhals._plan import functions as F # noqa: N812 from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import All from narwhals._plan.expr import Column @@ -10,7 +12,6 @@ from narwhals._plan.expr import IndexColumns from narwhals._plan.expr import Len from narwhals._plan.expr import Nth -from narwhals._plan.functions import SumHorizontal from narwhals._plan.literal import ScalarLiteral from narwhals._plan.literal import SeriesLiteral from narwhals.dtypes import DType @@ -80,6 +81,31 @@ def sum(*columns: str) -> DummyExpr: return col(columns).sum() +def all_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: + it = (expr._ir for expr in flatten(exprs)) + return boolean.AllHorizontal().to_function_expr(*it).to_narwhals() + + +def any_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: + it = (expr._ir for expr in flatten(exprs)) + return boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() + + def sum_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: it = (expr._ir for expr in flatten(exprs)) - return SumHorizontal().to_function_expr(*it).to_narwhals() + return F.SumHorizontal().to_function_expr(*it).to_narwhals() + + +def min_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: + it = (expr._ir for expr in flatten(exprs)) + return F.MinHorizontal().to_function_expr(*it).to_narwhals() + + +def max_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: + it = (expr._ir for expr in flatten(exprs)) + return F.MaxHorizontal().to_function_expr(*it).to_narwhals() + + +def mean_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: + it = (expr._ir for expr in flatten(exprs)) + return F.MeanHorizontal().to_function_expr(*it).to_narwhals() diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index c2a8a69ab9..24df8f0f6a 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -19,6 +19,9 @@ from narwhals.typing import FillNullStrategy +# TODO @dangotbanned: repr + + class Abs(Function): @property def function_options(self) -> FunctionOptions: @@ -215,6 +218,9 @@ def function_options(self) -> FunctionOptions: FunctionFlags.INPUT_WILDCARD_EXPANSION ) + def __repr__(self) -> str: + return "sum_horizontal" + class MinHorizontal(Function): @property @@ -223,6 +229,9 @@ def function_options(self) -> FunctionOptions: FunctionFlags.INPUT_WILDCARD_EXPANSION ) + def __repr__(self) -> str: + return "min_horizontal" + class MaxHorizontal(Function): @property @@ -231,6 +240,9 @@ def function_options(self) -> FunctionOptions: FunctionFlags.INPUT_WILDCARD_EXPANSION ) + def __repr__(self) -> str: + return "max_horizontal" + class MeanHorizontal(Function): @property @@ -239,6 +251,9 @@ def function_options(self) -> FunctionOptions: FunctionFlags.INPUT_WILDCARD_EXPANSION ) + def __repr__(self) -> str: + return "mean_horizontal" + 
class EwmMean(Function): __slots__ = ("options",) From 5860917082c7edc8babd70df8ee355b238251110 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:17:24 +0100 Subject: [PATCH 065/368] docs: Add note on `nw.when` --- narwhals/_plan/expr.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 632a447a43..8a03c832cd 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -303,3 +303,10 @@ def __repr__(self) -> str: # NOTE: by_dtype, matches, numeric, boolean, string, categorical, datetime, all class Selector(ExprIR): ... + + +class Ternary(ExprIR): + """When-Then-Otherwise. + + Deferring this for now. + """ From 600db3c613ef91348d66ce239e68c07b8983cdf0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:25:18 +0100 Subject: [PATCH 066/368] feat: Add `concat_str` --- narwhals/_plan/demo.py | 15 +++++++++++ narwhals/_plan/strings.py | 53 +++++++++++++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 476439020a..743c80e0f7 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -14,6 +14,7 @@ from narwhals._plan.expr import Nth from narwhals._plan.literal import ScalarLiteral from narwhals._plan.literal import SeriesLiteral +from narwhals._plan.strings import ConcatHorizontal from narwhals.dtypes import DType from narwhals.dtypes import Unknown from narwhals.utils import flatten @@ -109,3 +110,17 @@ def max_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: def mean_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: it = (expr._ir for expr in flatten(exprs)) return F.MeanHorizontal().to_function_expr(*it).to_narwhals() + + +def concat_str( + exprs: DummyExpr | t.Iterable[DummyExpr], + *more_exprs: DummyExpr, + separator: str = "", + ignore_nulls: bool = False, +) -> DummyExpr: + it = (expr._ir for expr in flatten([*flatten([exprs]), *more_exprs])) + return ( + ConcatHorizontal(separator=separator, ignore_nulls=ignore_nulls) + .to_function_expr(*it) + .to_narwhals() + ) diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 4ac3531d4b..8b64875cac 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -5,11 +5,15 @@ from narwhals._plan.options import FunctionOptions +# TODO @dangotbanned: repr class StringFunction(Function): @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + return "StringFunction" + class ConcatHorizontal(StringFunction): """`nw.functions.concat_str`.""" @@ -23,17 +27,27 @@ class ConcatHorizontal(StringFunction): def function_options(self) -> FunctionOptions: return super().function_options.with_flags(FunctionFlags.INPUT_WILDCARD_EXPANSION) + def __repr__(self) -> str: + return "str.concat_horizontal" + class Contains(StringFunction): __slots__ = ("literal",) literal: bool + def __repr__(self) -> str: + return "str.contains" -class EndsWith(StringFunction): ... +class EndsWith(StringFunction): + def __repr__(self) -> str: + return "str.ends_with" -class LenChars(StringFunction): ... 
+ +class LenChars(StringFunction): + def __repr__(self) -> str: + return "str.len_chars" class Replace(StringFunction): @@ -41,6 +55,9 @@ class Replace(StringFunction): literal: bool + def __repr__(self) -> str: + return "str.replace" + class ReplaceAll(StringFunction): """`polars` uses a single node for this and `Replace`. @@ -52,6 +69,9 @@ class ReplaceAll(StringFunction): literal: bool + def __repr__(self) -> str: + return "str.replace_all" + class Slice(StringFunction): """We're using for `Head`, `Tail` as well. @@ -66,26 +86,41 @@ class Slice(StringFunction): offset: int length: int | None + def __repr__(self) -> str: + return "str.slice" + class Head(StringFunction): __slots__ = ("n",) n: int + def __repr__(self) -> str: + return "str.head" + class Tail(StringFunction): __slots__ = ("n",) n: int + def __repr__(self) -> str: + return "str.tail" + -class Split(StringFunction): ... +class Split(StringFunction): + def __repr__(self) -> str: + return "str.split" -class StartsWith(StringFunction): ... +class StartsWith(StringFunction): + def __repr__(self) -> str: + return "str.startswith" -class StripChars(StringFunction): ... +class StripChars(StringFunction): + def __repr__(self) -> str: + return "str.strip_chars" class ToDatetime(StringFunction): @@ -101,7 +136,11 @@ class ToDatetime(StringFunction): format: str | None -class ToLowercase(StringFunction): ... +class ToLowercase(StringFunction): + def __repr__(self) -> str: + return "str.to_lowercase" -class ToUppercase(StringFunction): ... +class ToUppercase(StringFunction): + def __repr__(self) -> str: + return "str.to_uppercase" From 20cdf4afb0a41a23e544fe71720f903ea8604742 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:29:36 +0100 Subject: [PATCH 067/368] feat: Add `.name.*` reprs --- narwhals/_plan/name.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 3d4e6dd7e0..e65df5c4bd 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -10,7 +10,7 @@ class NameFunction(Function): - """`polars` version doesn't represent in the same way here. + """`polars` version doesn't represent these as `FunctionExpr`. https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/dsl/name.rs """ @@ -19,6 +19,20 @@ class NameFunction(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + tp = type(self) + if tp is NameFunction: + return tp.__name__ + m: dict[type[NameFunction], str] = { + Keep: "keep", + Map: "map", + Suffix: "suffix", + Prefix: "prefix", + ToLowercase: "to_lowercase", + ToUppercase: "to_uppercase", + } + return f"name.{m[tp]}" + class Keep(NameFunction): ... 
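To get a feel for how the repr work in the last few patches reads end-to-end, here is a rough usage of the temporary `demo` constructors. This is only illustrative of the prototype as it stands in this series, not a stable API, and the printed output is whatever the `__repr__` methods above produce:

    from narwhals._plan import demo

    total = demo.sum_horizontal(demo.col("a"), demo.col("b")).alias("total")
    full_name = demo.concat_str(demo.col("first"), demo.col("last"), separator=" ")
    flipped = ~demo.col("flag")  # `Not` routed through `to_function_expr`

    # Each of these is a `DummyExpr` whose repr walks the underlying `ExprIR` chain.
    print(total)
    print(full_name)
    print(flipped)
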
From 58633e3ebb4fb4ce80d00041560070635f470e96 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:42:56 +0100 Subject: [PATCH 068/368] feat: Add `.dt.*` reprs --- narwhals/_plan/temporal.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 3f50f25f46..9485baefb3 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import cast from narwhals._plan.common import Function from narwhals._plan.options import FunctionOptions @@ -14,6 +15,38 @@ class TemporalFunction(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + tp = type(self) + if tp is TemporalFunction: + return tp.__name__ + elif tp is Timestamp: + tu = cast("Timestamp", self).time_unit + return f"dt.timestamp[{tu!r}]" + m: dict[type[TemporalFunction], str] = { + Year: "year", + Month: "month", + WeekDay: "weekday", + Day: "day", + OrdinalDay: "ordinal_day", + Date: "date", + Hour: "hour", + Minute: "minute", + Second: "second", + Millisecond: "millisecond", + Microsecond: "microsecond", + Nanosecond: "nanosecond", + TotalMinutes: "total_minutes", + TotalSeconds: "total_seconds", + TotalMilliseconds: "total_milliseconds", + TotalMicroseconds: "total_microseconds", + TotalNanoseconds: "total_nanoseconds", + ToString: "to_string", + ConvertTimeZone: "convert_time_zone", + ReplaceTimeZone: "replace_time_zone", + Truncate: "truncate", + } + return f"dt.{m[tp]}" + class Date(TemporalFunction): ... From 2818385bb15e925d710f3deafea0e09ecb9a2ac3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:46:14 +0100 Subject: [PATCH 069/368] chore: Remove outdated notes --- narwhals/_plan/functions.py | 2 -- narwhals/_plan/strings.py | 1 - 2 files changed, 3 deletions(-) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 24df8f0f6a..86f8bf9643 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -1,5 +1,3 @@ -"""TODO: Attributes.""" - from __future__ import annotations from typing import TYPE_CHECKING diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 8b64875cac..697b0b122d 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -5,7 +5,6 @@ from narwhals._plan.options import FunctionOptions -# TODO @dangotbanned: repr class StringFunction(Function): @property def function_options(self) -> FunctionOptions: From 68f7b483d7e979589b1866dbea119e5f1f154ba4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:48:21 +0100 Subject: [PATCH 070/368] revert: Remove unreachable empty slots handling Added this before the constructors enforced everything to be set --- narwhals/_plan/common.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index b28bc2ec8b..da5baa6be9 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -73,20 +73,16 @@ def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: cls.__slots__ = () def __hash__(self) -> int: - empty = object() slots: tuple[str, ...] 
= self.__slots__ - return hash(tuple(getattr(self, name, empty) for name in slots)) + return hash(tuple(getattr(self, name) for name in slots)) def __eq__(self, other: object) -> bool: if self is other: return True elif type(self) is not type(other): return False - empty = object() slots: tuple[str, ...] = self.__slots__ - return all( - getattr(self, name, empty) == getattr(other, name, empty) for name in slots - ) + return all(getattr(self, name) == getattr(other, name) for name in slots) def __str__(self) -> str: # NOTE: Debug repr, closer to constructor From fc0bd0e4e29cfc41cd00bdeab70c25913decc1e7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:53:55 +0100 Subject: [PATCH 071/368] docs: Add some notes to `expr` --- narwhals/_plan/expr.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8a03c832cd..1ba99867e1 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -1,3 +1,11 @@ +"""Top-level `Expr` nodes. + +Todo: +- `Selector` +- `Ternary` +- `Window` (investigate variants) +""" + from __future__ import annotations # NOTE: Needed to avoid naming collisions @@ -301,8 +309,8 @@ def __repr__(self) -> str: return "*" -# NOTE: by_dtype, matches, numeric, boolean, string, categorical, datetime, all -class Selector(ExprIR): ... +class Selector(ExprIR): + """by_dtype, matches, numeric, boolean, string, categorical, datetime, all.""" class Ternary(ExprIR): From 12069fc5ebae958cff9d3be20c953cf1b5956200 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 15:56:25 +0100 Subject: [PATCH 072/368] docs: Note of `functions` --- narwhals/_plan/functions.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 86f8bf9643..86943156a5 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -1,3 +1,9 @@ +"""General functions that aren't namespaced. 
+ +Todo: +- repr +""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -17,9 +23,6 @@ from narwhals.typing import FillNullStrategy -# TODO @dangotbanned: repr - - class Abs(Function): @property def function_options(self) -> FunctionOptions: From 0c4dedb246bc34aa96faa8958fe55b151e04ac71 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 17:37:27 +0100 Subject: [PATCH 073/368] feat: Finish `.over()` --- narwhals/_plan/dummy.py | 23 +++++++++++++++++++++++ narwhals/_plan/expr.py | 26 +++++++++++++++++++++++++- narwhals/_plan/window.py | 23 ++++++++++++++++++++--- 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 0dfa6efdb1..4ea5e979bd 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -2,19 +2,24 @@ from __future__ import annotations +import typing as t from typing import TYPE_CHECKING from narwhals._plan import aggregation as agg from narwhals._plan import boolean from narwhals._plan import expr from narwhals._plan import operators as ops +from narwhals._plan.options import SortOptions +from narwhals._plan.window import Over from narwhals.dtypes import DType from narwhals.dtypes import Unknown +from narwhals.utils import flatten if TYPE_CHECKING: from typing_extensions import Self from narwhals._plan.common import ExprIR + from narwhals._plan.common import Seq from narwhals.typing import NativeSeries @@ -60,6 +65,24 @@ def n_unique(self) -> Self: def sum(self) -> Self: return self._from_ir(agg.Sum(expr=self._ir)) + def over( + self, + *partition_by: DummyExpr | t.Iterable[DummyExpr], + order_by: DummyExpr | t.Iterable[DummyExpr] | None = None, + descending: bool = False, + nulls_last: bool = False, + ) -> DummyExpr: + order: tuple[Seq[ExprIR], SortOptions] | None = None + partition = tuple(expr._ir for expr in flatten(partition_by)) + if not (partition) and order_by is None: + msg = "At least one of `partition_by` or `order_by` must be specified." + raise TypeError(msg) + if order_by is not None: + by = tuple(expr._ir for expr in flatten([order_by])) + options = SortOptions(descending=descending, nulls_last=nulls_last) + order = by, options + return Over().to_window_expr(self._ir, partition, order).to_narwhals() + def __eq__(self, other: DummyExpr) -> DummyExpr: # type: ignore[override] op = ops.Eq() return op.to_binary_expr(self._ir, other._ir).to_narwhals() diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 1ba99867e1..8ac5b7a00f 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -245,7 +245,12 @@ class WindowExpr(ExprIR): """ partition_by: Seq[ExprIR] - order_by: tuple[ExprIR, SortOptions] | None + order_by: tuple[Seq[ExprIR], SortOptions] | None + """Deviates from the `polars` version. + + - `order_by` starts the same as here, but `polars` reduces into a struct - becoming a single (nested) node. + """ + options: Window """Little confused on the nesting. 
@@ -258,6 +263,25 @@ class WindowExpr(ExprIR): Expr::Window { options: WindowType::Rolling(RollingGroupOptions) } """ + def __repr__(self) -> str: + if self.order_by is None: + return f"{self.expr!r}.over({list(self.partition_by)!r})" + order, _ = self.order_by + if not self.partition_by: + args = f"order_by={list(order)!r}" + else: + args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" + return f"{self.expr!r}.over({args})" + + def __str__(self) -> str: + if self.order_by is None: + order_by = "None" + else: + order, opts = self.order_by + order_by = f"({order}, {opts})" + args = f"expr={self.expr}, partition_by={self.partition_by}, order_by={order_by}, options={self.options}" + return f"{type(self).__name__}({args})" + class Len(ExprIR): @property diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 61097efe2d..9ee9a2c910 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -1,9 +1,14 @@ -"""TODO: Figure out what `Over` should be holding or skip it and go straight to `WindowExpr`.""" - from __future__ import annotations +from typing import TYPE_CHECKING + from narwhals._plan.common import ExprIR +if TYPE_CHECKING: + from narwhals._plan.common import Seq + from narwhals._plan.expr import WindowExpr + from narwhals._plan.options import SortOptions + class Window(ExprIR): """Renamed from `WindowType`. @@ -17,4 +22,16 @@ class Window(ExprIR): # - `_plan.expr.WindowExpr` has: # - expr (last node) # - partition_by, optional order_by, `options` which is one of these classes? -class Over(Window): ... +class Over(Window): + def to_window_expr( + self, + expr: ExprIR, + partition_by: Seq[ExprIR], + order_by: tuple[Seq[ExprIR], SortOptions] | None, + /, + ) -> WindowExpr: + from narwhals._plan.expr import WindowExpr + + return WindowExpr( + expr=expr, partition_by=partition_by, order_by=order_by, options=self + ) From 291afaa05a2e70f159d86855a74bde87db11f9f8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 18:11:35 +0100 Subject: [PATCH 074/368] feat: Add most of the remaining aggs --- narwhals/_plan/aggregation.py | 2 +- narwhals/_plan/dummy.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 4d461ed04b..2ad257aa79 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -55,7 +55,7 @@ class NUnique(Agg): ... 
class Quantile(Agg): __slots__ = (*Agg.__slots__, "interpolation", "quantile") - quantile: ExprIR + quantile: float interpolation: RollingInterpolationMethod diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 4ea5e979bd..7771addfb5 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -21,6 +21,7 @@ from narwhals._plan.common import ExprIR from narwhals._plan.common import Seq from narwhals.typing import NativeSeries + from narwhals.typing import RollingInterpolationMethod # NOTE: Overly simplified placeholders for mocking typing @@ -65,6 +66,25 @@ def n_unique(self) -> Self: def sum(self) -> Self: return self._from_ir(agg.Sum(expr=self._ir)) + def first(self) -> Self: + return self._from_ir(agg.First(expr=self._ir)) + + def last(self) -> Self: + return self._from_ir(agg.Last(expr=self._ir)) + + def var(self, *, ddof: int = 1) -> Self: + return self._from_ir(agg.Var(expr=self._ir, ddof=ddof)) + + def std(self, *, ddof: int = 1) -> Self: + return self._from_ir(agg.Std(expr=self._ir, ddof=ddof)) + + def quantile( + self, quantile: float, interpolation: RollingInterpolationMethod + ) -> Self: + return self._from_ir( + agg.Quantile(expr=self._ir, quantile=quantile, interpolation=interpolation) + ) + def over( self, *partition_by: DummyExpr | t.Iterable[DummyExpr], From f03ad55f2d5222f2083c26c37a215738b81e3252 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 18:27:19 +0100 Subject: [PATCH 075/368] feat: Add `sort_by` Related (#2547), (#2534) --- narwhals/_plan/dummy.py | 18 ++++++++++++++++++ narwhals/_plan/options.py | 10 ++++++++++ 2 files changed, 28 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 7771addfb5..2cf0676f6a 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -9,6 +9,7 @@ from narwhals._plan import boolean from narwhals._plan import expr from narwhals._plan import operators as ops +from narwhals._plan.options import SortMultipleOptions from narwhals._plan.options import SortOptions from narwhals._plan.window import Over from narwhals.dtypes import DType @@ -103,6 +104,23 @@ def over( order = by, options return Over().to_window_expr(self._ir, partition, order).to_narwhals() + def sort_by( + self, + by: DummyExpr | t.Iterable[DummyExpr], + *more_by: DummyExpr, + descending: bool | t.Iterable[bool] = False, + nulls_last: bool | t.Iterable[bool] = False, + ) -> Self: + if more_by: + by = (by, *more_by) if isinstance(by, DummyExpr) else (*by, *more_by) + else: + by = (by,) if isinstance(by, DummyExpr) else tuple(by) + sort_by = tuple(key._ir for key in by) + desc = (descending,) if isinstance(descending, bool) else tuple(descending) + nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) + options = SortMultipleOptions(descending=desc, nulls_last=nulls) + return self._from_ir(expr.SortBy(expr=self._ir, by=sort_by, options=options)) + def __eq__(self, other: DummyExpr) -> DummyExpr: # type: ignore[override] op = ops.Eq() return op.to_binary_expr(self._ir, other._ir).to_narwhals() diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 0517cf82bc..ee014c8b0d 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -115,6 +115,10 @@ class SortOptions(Immutable): descending: bool nulls_last: bool + def __repr__(self) -> str: + args = f"descending={self.descending!r}, nulls_last={self.nulls_last!r}" + return f"{type(self).__name__}({args})" + class SortMultipleOptions(Immutable): 
__slots__ = ("descending", "nulls_last") @@ -122,6 +126,12 @@ class SortMultipleOptions(Immutable): descending: Seq[bool] nulls_last: Seq[bool] + def __repr__(self) -> str: + args = ( + f"descending={list(self.descending)!r}, nulls_last={list(self.nulls_last)!r}" + ) + return f"{type(self).__name__}({args})" + class RankOptions(Immutable): """https://github.com/narwhals-dev/narwhals/pull/2555.""" From cae49098479ba6270adf3f54c43145579fa71b18 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 18:51:19 +0100 Subject: [PATCH 076/368] revert: Make `first`, `last` orderable --- narwhals/_plan/aggregation.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 2ad257aa79..434a00328e 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -29,14 +29,6 @@ def __repr__(self) -> str: class Count(Agg): ... -class First(Agg): - """https://github.com/narwhals-dev/narwhals/issues/2526.""" - - -class Last(Agg): - """https://github.com/narwhals-dev/narwhals/issues/2526.""" - - class Max(Agg): ... @@ -79,6 +71,14 @@ class Var(Agg): class OrderableAgg(Agg): ... +class First(OrderableAgg): + """https://github.com/narwhals-dev/narwhals/issues/2526.""" + + +class Last(OrderableAgg): + """https://github.com/narwhals-dev/narwhals/issues/2526.""" + + class ArgMin(OrderableAgg): ... From a482e8ebf1eb08d6c3e90662943fb380b7e809f9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 20:04:21 +0100 Subject: [PATCH 077/368] feat: Experiment with suggestions --- narwhals/_plan/demo.py | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 743c80e0f7..622cd70097 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,6 +3,7 @@ import builtins import typing as t +from narwhals._plan import aggregation as agg from narwhals._plan import boolean from narwhals._plan import functions as F # noqa: N812 from narwhals._plan.dummy import DummySeries @@ -17,10 +18,15 @@ from narwhals._plan.strings import ConcatHorizontal from narwhals.dtypes import DType from narwhals.dtypes import Unknown +from narwhals.exceptions import OrderDependentExprError from narwhals.utils import flatten if t.TYPE_CHECKING: + from typing_extensions import TypeIs + from narwhals._plan.dummy import DummyExpr + from narwhals._plan.expr import SortBy + from narwhals._plan.expr import WindowExpr from narwhals.typing import NonNestedLiteral @@ -124,3 +130,43 @@ def concat_str( .to_function_expr(*it) .to_narwhals() ) + + +def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]: + """In theory, we could add other nodes to this check.""" + from narwhals._plan.expr import SortBy + + allowed = (SortBy,) + return isinstance(obj, allowed) + + +def _is_order_enforcing_next(obj: t.Any) -> TypeIs[WindowExpr]: + """Not sure how this one would work.""" + from narwhals._plan.expr import WindowExpr + + return isinstance(obj, WindowExpr) and obj.order_by is not None + + +def _order_dependent_error(node: agg.OrderableAgg) -> OrderDependentExprError: + previous = node.expr + method = repr(node).removeprefix(f"{previous!r}.") + msg = ( + f"{method} is order-dependent and requires an ordering operation for lazy backends.\n" + f"Hint:\nInstead of:\n" + f" {node!r}\n\n" + "If you want to aggregate to a single value, 
try:\n" + f" {previous!r}.sort_by(...).{method}\n\n" + "Otherwise, try:\n" + f" {node!r}.over(order_by=...)" + ) + return OrderDependentExprError(msg) + + +def ensure_orderable_rules(*exprs: DummyExpr) -> tuple[DummyExpr, ...]: + for expr in exprs: + node = expr._ir + if isinstance(node, agg.OrderableAgg): + previous = node.expr + if not _is_order_enforcing_previous(previous): + raise _order_dependent_error(node) + return exprs From 047c815ce97c33d0dbf29fee965db15faf20cea4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 21:44:16 +0100 Subject: [PATCH 078/368] chore: fix typos https://results.pre-commit.ci/run/github/760058710/1747600769.DyFj2_uuQvGFXmK1NhtAaA --- narwhals/_plan/__init__.py | 2 +- narwhals/_plan/expr.py | 4 ++-- narwhals/_plan/options.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index 9a5ca74e61..bfc32b789c 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -1,4 +1,4 @@ -"""Brainstorming an `Expr` internal represention. +"""Brainstorming an `Expr` internal representation. Notes: - Each `Expr` method should be representable by a single node diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8ac5b7a00f..a464ce3ee9 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -226,7 +226,7 @@ def __repr__(self) -> str: class WindowExpr(ExprIR): - """A fully specified `.over()`, that occured after another expression. + """A fully specified `.over()`, that occurred after another expression. I think we want variants for partitioned, ordered, both. @@ -310,7 +310,7 @@ def __repr__(self) -> str: class IndexColumns(ExprIR): """Renamed from `IndexColumn`. - `Nth` provides the singlular variant. + `Nth` provides the singular variant. 
https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L80 """ diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index ee014c8b0d..485708725d 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -180,7 +180,7 @@ class RollingOptionsFixedWindow(Immutable): window_size: int min_samples: int - """Renamed from `min_periods`, re-uses `window_size` if null.""" + """Renamed from `min_periods`, reuses `window_size` if null.""" center: bool fn_params: RollingVarParams | None From bbba0fe1edfc3c0d68f4ffa56c2e33b3632d9750 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 May 2025 22:04:37 +0100 Subject: [PATCH 079/368] fix: Avoid dtypes imports https://results.pre-commit.ci/run/github/760058710/1747601128.e2dXEDsSTO2XfoU3gLZrRg --- narwhals/_plan/demo.py | 4 ++-- narwhals/_plan/dummy.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 622cd70097..0e5b1f8452 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -17,8 +17,8 @@ from narwhals._plan.literal import SeriesLiteral from narwhals._plan.strings import ConcatHorizontal from narwhals.dtypes import DType -from narwhals.dtypes import Unknown from narwhals.exceptions import OrderDependentExprError +from narwhals.utils import Version from narwhals.utils import flatten if t.TYPE_CHECKING: @@ -56,7 +56,7 @@ def lit( if isinstance(value, DummySeries): return SeriesLiteral(value=value).to_literal().to_narwhals() if dtype is None or not isinstance(dtype, DType): - dtype = Unknown() + dtype = Version.MAIN.dtypes.Unknown() return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 2cf0676f6a..807bbf5346 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -13,7 +13,7 @@ from narwhals._plan.options import SortOptions from narwhals._plan.window import Over from narwhals.dtypes import DType -from narwhals.dtypes import Unknown +from narwhals.utils import Version from narwhals.utils import flatten if TYPE_CHECKING: @@ -43,7 +43,7 @@ def alias(self, name: str) -> Self: return self._from_ir(expr.Alias(expr=self._ir, name=name)) def cast(self, dtype: DType | type[DType]) -> Self: - dtype = dtype if isinstance(dtype, DType) else Unknown() + dtype = dtype if isinstance(dtype, DType) else Version.MAIN.dtypes.Unknown() return self._from_ir(expr.Cast(expr=self._ir, dtype=dtype)) def count(self) -> Self: @@ -215,9 +215,7 @@ class DummyCompliantSeries: @property def dtype(self) -> DType: - from narwhals.dtypes import Float64 - - return Float64() + return Version.MAIN.dtypes.Float64() @property def name(self) -> str: From ebd05428b755f83212c9ca24d5ecbf3896bda88a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 19 May 2025 10:42:47 +0100 Subject: [PATCH 080/368] feat: Add the concept of `Version`-ing `Version` is never stored on an `IR` node and should only be needed when converting back to another representation (`to_*`) --- narwhals/_plan/common.py | 14 +++-- narwhals/_plan/dummy.py | 115 +++++++++++++++++++++++++-------------- 2 files changed, 83 insertions(+), 46 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index da5baa6be9..ee19f17927 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -3,6 +3,8 @@ from typing import 
TYPE_CHECKING from typing import TypeVar +from narwhals.utils import Version + if TYPE_CHECKING: from typing import Any from typing import Callable @@ -119,15 +121,17 @@ def __init__(self, **kwds: Any) -> None: class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" - def to_narwhals(self) -> DummyExpr: - from narwhals._plan.dummy import DummyExpr + def to_narwhals(self, version: Version = Version.MAIN) -> DummyExpr: + from narwhals._plan import dummy - return DummyExpr._from_ir(self) + if version is Version.MAIN: + return dummy.DummyExpr._from_ir(self) + return dummy.DummyExprV1._from_ir(self) - def to_compliant(self) -> DummyCompliantExpr: + def to_compliant(self, version: Version = Version.MAIN) -> DummyCompliantExpr: from narwhals._plan.dummy import DummyCompliantExpr - return DummyCompliantExpr._from_ir(self) + return DummyCompliantExpr._from_ir(self, version) @property def is_scalar(self) -> bool: diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 807bbf5346..5385abf8da 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -14,6 +14,7 @@ from narwhals._plan.window import Over from narwhals.dtypes import DType from narwhals.utils import Version +from narwhals.utils import _hasattr_static from narwhals.utils import flatten if TYPE_CHECKING: @@ -29,9 +30,10 @@ # Entirely ignoring namespace + function binding class DummyExpr: _ir: ExprIR + _version: t.ClassVar[Version] = Version.MAIN def __repr__(self) -> str: - return f"Narwhals DummyExpr:\n{self._ir!r}" + return f"Narwhals DummyExpr ({self.version.name.lower()}):\n{self._ir!r}" @classmethod def _from_ir(cls, ir: ExprIR, /) -> Self: @@ -39,11 +41,15 @@ def _from_ir(cls, ir: ExprIR, /) -> Self: obj._ir = ir return obj + @property + def version(self) -> Version: + return self._version + def alias(self, name: str) -> Self: return self._from_ir(expr.Alias(expr=self._ir, name=name)) def cast(self, dtype: DType | type[DType]) -> Self: - dtype = dtype if isinstance(dtype, DType) else Version.MAIN.dtypes.Unknown() + dtype = dtype if isinstance(dtype, DType) else self.version.dtypes.Unknown() return self._from_ir(expr.Cast(expr=self._ir, dtype=dtype)) def count(self) -> Self: @@ -92,7 +98,7 @@ def over( order_by: DummyExpr | t.Iterable[DummyExpr] | None = None, descending: bool = False, nulls_last: bool = False, - ) -> DummyExpr: + ) -> Self: order: tuple[Seq[ExprIR], SortOptions] | None = None partition = tuple(expr._ir for expr in flatten(partition_by)) if not (partition) and order_by is None: @@ -102,7 +108,7 @@ def over( by = tuple(expr._ir for expr in flatten([order_by])) options = SortOptions(descending=descending, nulls_last=nulls_last) order = by, options - return Over().to_window_expr(self._ir, partition, order).to_narwhals() + return self._from_ir(Over().to_window_expr(self._ir, partition, order)) def sort_by( self, @@ -121,78 +127,98 @@ def sort_by( options = SortMultipleOptions(descending=desc, nulls_last=nulls) return self._from_ir(expr.SortBy(expr=self._ir, by=sort_by, options=options)) - def __eq__(self, other: DummyExpr) -> DummyExpr: # type: ignore[override] + def __eq__(self, other: DummyExpr) -> Self: # type: ignore[override] op = ops.Eq() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __ne__(self, other: DummyExpr) -> DummyExpr: # type: ignore[override] + def __ne__(self, other: DummyExpr) -> Self: # type: ignore[override] op = ops.NotEq() - return op.to_binary_expr(self._ir, 
other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __lt__(self, other: DummyExpr) -> DummyExpr: + def __lt__(self, other: DummyExpr) -> Self: op = ops.Lt() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __le__(self, other: DummyExpr) -> DummyExpr: + def __le__(self, other: DummyExpr) -> Self: op = ops.LtEq() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __gt__(self, other: DummyExpr) -> DummyExpr: + def __gt__(self, other: DummyExpr) -> Self: op = ops.Gt() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __ge__(self, other: DummyExpr) -> DummyExpr: + def __ge__(self, other: DummyExpr) -> Self: op = ops.GtEq() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __add__(self, other: DummyExpr) -> DummyExpr: + def __add__(self, other: DummyExpr) -> Self: op = ops.Add() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __sub__(self, other: DummyExpr) -> DummyExpr: + def __sub__(self, other: DummyExpr) -> Self: op = ops.Sub() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __mul__(self, other: DummyExpr) -> DummyExpr: + def __mul__(self, other: DummyExpr) -> Self: op = ops.Multiply() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __truediv__(self, other: DummyExpr) -> DummyExpr: + def __truediv__(self, other: DummyExpr) -> Self: op = ops.TrueDivide() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __floordiv__(self, other: DummyExpr) -> DummyExpr: + def __floordiv__(self, other: DummyExpr) -> Self: op = ops.FloorDivide() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __mod__(self, other: DummyExpr) -> DummyExpr: + def __mod__(self, other: DummyExpr) -> Self: op = ops.Modulus() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __and__(self, other: DummyExpr) -> DummyExpr: + def __and__(self, other: DummyExpr) -> Self: op = ops.And() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) - def __or__(self, other: DummyExpr) -> DummyExpr: + def __or__(self, other: DummyExpr) -> Self: op = ops.Or() - return op.to_binary_expr(self._ir, other._ir).to_narwhals() + return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + + def __invert__(self) -> Self: + return self._from_ir(boolean.Not().to_function_expr(self._ir)) + - def __invert__(self) -> DummyExpr: - return boolean.Not().to_function_expr(self._ir).to_narwhals() +class DummyExprV1(DummyExpr): + _version: t.ClassVar[Version] = Version.V1 class DummyCompliantExpr: _ir: ExprIR + _version: Version + + @property + def version(self) -> Version: + return self._version @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Self: + def _from_ir(cls, ir: ExprIR, /, version: Version) -> Self: obj = cls.__new__(cls) obj._ir = ir + 
obj._version = version return obj + def to_narwhals(self) -> DummyExpr: + if self.version is Version.MAIN: + return DummyExpr._from_ir(self._ir) + return DummyExprV1._from_ir(self._ir) + class DummySeries: _compliant: DummyCompliantSeries + _version: t.ClassVar[Version] = Version.MAIN + + @property + def version(self) -> Version: + return self._version @property def dtype(self) -> DType: @@ -205,31 +231,38 @@ def name(self) -> str: @classmethod def from_native(cls, native: NativeSeries, /) -> Self: obj = cls.__new__(cls) - obj._compliant = DummyCompliantSeries.from_native(native) + obj._compliant = DummyCompliantSeries.from_native(native, cls._version) return obj +class DummySeriesV1(DummySeries): + _version: t.ClassVar[Version] = Version.V1 + + class DummyCompliantSeries: _native: NativeSeries _name: str + _version: Version + + @property + def version(self) -> Version: + return self._version @property def dtype(self) -> DType: - return Version.MAIN.dtypes.Float64() + return self.version.dtypes.Float64() @property def name(self) -> str: return self._name @classmethod - def from_native(cls, native: NativeSeries, /) -> Self: - from narwhals.utils import _hasattr_static - + def from_native(cls, native: NativeSeries, /, version: Version) -> Self: name: str = "" - if _hasattr_static(native, "name"): name = getattr(native, "name", name) obj = cls.__new__(cls) obj._native = native obj._name = name + obj._version = version return obj From afa1c3222e68aab39a22e437e822d599deb4850c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 19 May 2025 15:53:24 +0100 Subject: [PATCH 081/368] feat: Start adding `IntoExpr` parsing - `lit` deals with the `Series` case - Very basic tests --- narwhals/_plan/common.py | 42 ++++++++ narwhals/_plan/demo.py | 13 +++ narwhals/_plan/expr_parsing.py | 180 ++++++++++++++++++++++++++++++++ tests/plan/__init__.py | 0 tests/plan/expr_parsing_test.py | 37 +++++++ 5 files changed, 272 insertions(+) create mode 100644 narwhals/_plan/expr_parsing.py create mode 100644 tests/plan/__init__.py create mode 100644 tests/plan/expr_parsing_test.py diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ee19f17927..702fc50892 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -1,5 +1,7 @@ from __future__ import annotations +import datetime as dt +from decimal import Decimal from typing import TYPE_CHECKING from typing import TypeVar @@ -12,12 +14,15 @@ from typing_extensions import Never from typing_extensions import Self from typing_extensions import TypeAlias + from typing_extensions import TypeIs from typing_extensions import dataclass_transform from narwhals._plan.dummy import DummyCompliantExpr from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import FunctionExpr from narwhals._plan.options import FunctionOptions + from narwhals.typing import NonNestedLiteral else: # NOTE: This isn't important to the proposal, just wanted IDE support @@ -58,6 +63,9 @@ def decorator(cls_or_fn: T) -> T: Udf: TypeAlias = "Callable[[Any], Any]" """Placeholder for `map_batches(function=...)`.""" +IntoExprColumn: TypeAlias = "DummyExpr | DummySeries | str" +IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" + @dataclass_transform(kw_only_default=True, frozen_default=True) class Immutable: @@ -162,3 +170,37 @@ def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: # Feel like it should be the union of `input` & `function` PLACEHOLDER 
= FunctionOptions.default() # noqa: N806 return FunctionExpr(input=inputs, function=self, options=PLACEHOLDER) + + +_NON_NESTED_LITERAL_TPS = ( + int, + float, + str, + dt.date, + dt.time, + dt.timedelta, + bytes, + Decimal, +) + + +def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: + return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) + + +def is_expr(obj: Any) -> TypeIs[DummyExpr]: + from narwhals._plan.dummy import DummyExpr + + return isinstance(obj, DummyExpr) + + +def is_series(obj: Any) -> TypeIs[DummySeries]: + from narwhals._plan.dummy import DummySeries + + return isinstance(obj, DummySeries) + + +def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | DummySeries]: + from narwhals._plan.dummy import DummySeries + + return isinstance(obj, (str, bytes, DummySeries)) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 0e5b1f8452..d1f17c8613 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -5,7 +5,11 @@ from narwhals._plan import aggregation as agg from narwhals._plan import boolean +from narwhals._plan import expr_parsing as parse from narwhals._plan import functions as F # noqa: N812 +from narwhals._plan.common import ExprIR +from narwhals._plan.common import IntoExpr +from narwhals._plan.common import is_non_nested_literal from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import All from narwhals._plan.expr import Column @@ -57,6 +61,9 @@ def lit( return SeriesLiteral(value=value).to_literal().to_narwhals() if dtype is None or not isinstance(dtype, DType): dtype = Version.MAIN.dtypes.Unknown() + if not is_non_nested_literal(value): + msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." + raise TypeError(msg) return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() @@ -170,3 +177,9 @@ def ensure_orderable_rules(*exprs: DummyExpr) -> tuple[DummyExpr, ...]: if not _is_order_enforcing_previous(previous): raise _order_dependent_error(node) return exprs + + +def select_context( + *exprs: IntoExpr | t.Iterable[IntoExpr], **named_exprs: IntoExpr +) -> tuple[ExprIR, ...]: + return parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs) diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py new file mode 100644 index 0000000000..c9000b36bf --- /dev/null +++ b/narwhals/_plan/expr_parsing.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +# ruff: noqa: A002 +from typing import TYPE_CHECKING +from typing import Iterable +from typing import Sequence +from typing import TypeVar + +from narwhals._plan.common import is_expr +from narwhals._plan.common import is_iterable_reject +from narwhals.dependencies import get_polars +from narwhals.dependencies import is_pandas_dataframe +from narwhals.dependencies import is_pandas_series +from narwhals.exceptions import InvalidIntoExprError + +if TYPE_CHECKING: + from typing import Any + from typing import Iterator + + from typing_extensions import TypeAlias + from typing_extensions import TypeIs + + from narwhals._plan.common import ExprIR + from narwhals._plan.common import IntoExpr + from narwhals._plan.common import Seq + from narwhals.dtypes import DType + +T = TypeVar("T") + +_RaisesInvalidIntoExprError: TypeAlias = "Any" +""" +Placeholder for multiple `Iterable[IntoExpr]`. 
+ +We only support cases `a`, `b`, but the typing for most contexts is more permissive: + +>>> import polars as pl +>>> df = pl.DataFrame({"one": ["A", "B", "A"], "two": [1, 2, 3], "three": [4, 5, 6]}) +>>> a = ("one", "two") +>>> b = (["one", "two"],) +>>> +>>> c = ("one", ["two"]) +>>> d = (["one"], "two") +>>> [df.select(*into) for into in (a, b, c, d)] +[shape: (3, 2) + ┌─────┬─────┐ + │ one ┆ two │ + │ --- ┆ --- │ + │ str ┆ i64 │ + ╞═════╪═════╡ + │ A ┆ 1 │ + │ B ┆ 2 │ + │ A ┆ 3 │ + └─────┴─────┘, + shape: (3, 2) + ┌─────┬─────┐ + │ one ┆ two │ + │ --- ┆ --- │ + │ str ┆ i64 │ + ╞═════╪═════╡ + │ A ┆ 1 │ + │ B ┆ 2 │ + │ A ┆ 3 │ + └─────┴─────┘, + shape: (3, 2) + ┌─────┬───────────┐ + │ one ┆ literal │ + │ --- ┆ --- │ + │ str ┆ list[str] │ + ╞═════╪═══════════╡ + │ A ┆ ["two"] │ + │ B ┆ ["two"] │ + │ A ┆ ["two"] │ + └─────┴───────────┘, + shape: (3, 2) + ┌───────────┬─────┐ + │ literal ┆ two │ + │ --- ┆ --- │ + │ list[str] ┆ i64 │ + ╞═══════════╪═════╡ + │ ["one"] ┆ 1 │ + │ ["one"] ┆ 2 │ + │ ["one"] ┆ 3 │ + └───────────┴─────┘] +""" + + +def parse_into_expr_ir( + input: IntoExpr, *, str_as_lit: bool = False, dtype: DType | None = None +) -> ExprIR: + """Parse a single input into an `ExprIR` node.""" + from narwhals._plan import demo as nwd + + if is_expr(input): + expr = input + elif isinstance(input, str) and not str_as_lit: + expr = nwd.col(input) + else: + expr = nwd.lit(input, dtype=dtype) + return expr._ir + + +def parse_into_seq_of_expr_ir( + first_input: IntoExpr | Iterable[IntoExpr] = (), + *more_inputs: IntoExpr | _RaisesInvalidIntoExprError, + **named_inputs: IntoExpr, +) -> Seq[ExprIR]: + """Parse variadic inputs into a flat sequence of `ExprIR` nodes.""" + return tuple(_parse_into_iter_expr_ir(first_input, *more_inputs, **named_inputs)) + + +def _parse_into_iter_expr_ir( + first_input: IntoExpr | Iterable[IntoExpr], + *more_inputs: IntoExpr, + **named_inputs: IntoExpr, +) -> Iterator[ExprIR]: + if not _is_empty_sequence(first_input): + # NOTE: These need to be separated to introduce an intersection type + # Otherwise, `str | bytes` always passes through typing + if _is_iterable(first_input) and not is_iterable_reject(first_input): + if more_inputs: + raise _invalid_into_expr_error(first_input, more_inputs, named_inputs) + else: + yield from _parse_positional_inputs(first_input) + else: + yield parse_into_expr_ir(first_input) + else: + # NOTE: Passthrough case for no inputs - but gets skipped when calling next + yield from () + if more_inputs: + yield from _parse_positional_inputs(more_inputs) + if named_inputs: + yield from _parse_named_inputs(named_inputs) + + +def _parse_positional_inputs(inputs: Iterable[IntoExpr], /) -> Iterator[ExprIR]: + for into in inputs: + yield parse_into_expr_ir(into) + + +def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR]: + from narwhals._plan.expr import Alias + + for name, input in named_inputs.items(): + yield Alias(expr=parse_into_expr_ir(input), name=name) + + +def _is_iterable(obj: Iterable[T] | Any) -> TypeIs[Iterable[T]]: + if is_pandas_dataframe(obj) or is_pandas_series(obj): + msg = f"Expected Narwhals class or scalar, got: {type(obj)}. Perhaps you forgot a `nw.from_native` somewhere?" + raise TypeError(msg) + if _is_polars(obj): + msg = ( + f"Expected Narwhals class or scalar, got: {type(obj)}.\n\n" + "Hint: Perhaps you\n" + "- forgot a `nw.from_native` somewhere?\n" + "- used `pl.col` instead of `nw.col`?" 
+ ) + raise TypeError(msg) + return isinstance(obj, Iterable) + + +def _is_empty_sequence(obj: Any) -> bool: + return isinstance(obj, Sequence) and not obj + + +def _is_polars(obj: Any) -> bool: + return (pl := get_polars()) is not None and isinstance( + obj, (pl.Series, pl.Expr, pl.DataFrame, pl.LazyFrame) + ) + + +def _invalid_into_expr_error( + first_input: Any, more_inputs: Any, named_inputs: Any +) -> InvalidIntoExprError: + msg = ( + f"Passing both iterable and positional inputs is not supported.\n" + f"Hint:\nInstead try collecting all arguments into a {type(first_input).__name__!r}\n" + f"{first_input!r}\n{more_inputs!r}\n{named_inputs!r}" + ) + return InvalidIntoExprError(msg) diff --git a/tests/plan/__init__.py b/tests/plan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py new file mode 100644 index 0000000000..f9d40ffaa4 --- /dev/null +++ b/tests/plan/expr_parsing_test.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Iterable + +import pytest + +import narwhals as nw +import narwhals._plan.demo as nwd +from narwhals._plan.common import ExprIR + +if TYPE_CHECKING: + from narwhals._plan.common import IntoExpr + from narwhals._plan.common import Seq + + +@pytest.mark.parametrize( + ("exprs", "named_exprs"), + [ + ([nwd.col("a")], {}), + (["a"], {}), + ([], {"a": "b"}), + ([], {"a": nwd.col("b")}), + (["a", "b", nwd.col("c", "d", "e")], {"g": nwd.lit(1)}), + ([["a", "b", "c"]], {"q": nwd.lit(5, nw.Int8())}), + ( + [[nwd.nth(1), nwd.nth(2, 3, 4)]], + {"n": nwd.col("p").count(), "other n": nwd.len()}, + ), + ], +) +def test_parsing( + exprs: Seq[IntoExpr | Iterable[IntoExpr]], named_exprs: dict[str, IntoExpr] +) -> None: + assert all( + isinstance(node, ExprIR) for node in nwd.select_context(*exprs, **named_exprs) + ) From 0bada4851ebd08380da1d000e1d2e553229e1689 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 19 May 2025 16:40:25 +0100 Subject: [PATCH 082/368] feat: Permissive parsing in `*_horiztonal` Threw in some tests for hashing as well --- narwhals/_plan/demo.py | 24 +++++++++--------- tests/plan/expr_parsing_test.py | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index d1f17c8613..a4a3c6e50a 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -95,33 +95,33 @@ def sum(*columns: str) -> DummyExpr: return col(columns).sum() -def all_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: - it = (expr._ir for expr in flatten(exprs)) +def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: + it = parse.parse_into_seq_of_expr_ir(*exprs) return boolean.AllHorizontal().to_function_expr(*it).to_narwhals() -def any_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: - it = (expr._ir for expr in flatten(exprs)) +def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: + it = parse.parse_into_seq_of_expr_ir(*exprs) return boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() -def sum_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: - it = (expr._ir for expr in flatten(exprs)) +def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: + it = parse.parse_into_seq_of_expr_ir(*exprs) return F.SumHorizontal().to_function_expr(*it).to_narwhals() -def 
min_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: - it = (expr._ir for expr in flatten(exprs)) +def min_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: + it = parse.parse_into_seq_of_expr_ir(*exprs) return F.MinHorizontal().to_function_expr(*it).to_narwhals() -def max_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: - it = (expr._ir for expr in flatten(exprs)) +def max_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: + it = parse.parse_into_seq_of_expr_ir(*exprs) return F.MaxHorizontal().to_function_expr(*it).to_narwhals() -def mean_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr: - it = (expr._ir for expr in flatten(exprs)) +def mean_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: + it = parse.parse_into_seq_of_expr_ir(*exprs) return F.MeanHorizontal().to_function_expr(*it).to_narwhals() diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index f9d40ffaa4..eb88879f26 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -1,13 +1,19 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Callable from typing import Iterable import pytest import narwhals as nw import narwhals._plan.demo as nwd +from narwhals._plan import boolean +from narwhals._plan import functions as F # noqa: N812 from narwhals._plan.common import ExprIR +from narwhals._plan.common import Function +from narwhals._plan.dummy import DummyExpr +from narwhals._plan.expr import FunctionExpr if TYPE_CHECKING: from narwhals._plan.common import IntoExpr @@ -35,3 +41,42 @@ def test_parsing( assert all( isinstance(node, ExprIR) for node in nwd.select_context(*exprs, **named_exprs) ) + + +@pytest.mark.parametrize( + ("function", "ir_node"), + [ + (nwd.all_horizontal, boolean.AllHorizontal), + (nwd.any_horizontal, boolean.AnyHorizontal), + (nwd.sum_horizontal, F.SumHorizontal), + (nwd.min_horizontal, F.MinHorizontal), + (nwd.max_horizontal, F.MaxHorizontal), + (nwd.mean_horizontal, F.MeanHorizontal), + ], +) +@pytest.mark.parametrize( + "args", + [ + ("a", "b", "c"), + (["a", "b", "c"]), + (nwd.col("d", "e", "f"), nwd.col("g"), "q", nwd.nth(9)), + ((nwd.lit(1),)), + ([nwd.lit(1), nwd.lit(2), nwd.lit(3)]), + ], +) +def test_function_expr_horizontal( + function: Callable[..., DummyExpr], + ir_node: type[Function], + args: Seq[IntoExpr | Iterable[IntoExpr]], +) -> None: + variadic = function(*args) + sequence = function(args) + assert isinstance(variadic, DummyExpr) + assert isinstance(sequence, DummyExpr) + variadic_node = variadic._ir + sequence_node = sequence._ir + unrelated_node = nwd.lit(1)._ir + assert isinstance(variadic_node, FunctionExpr) + assert isinstance(variadic_node.function, ir_node) + assert variadic_node == sequence_node + assert sequence_node != unrelated_node From 98292dcae05d651722b5a924dce6043c085e74f0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 20 May 2025 21:43:38 +0100 Subject: [PATCH 083/368] exploring `pl.Expr.meta` - Seems that the first step should be adding an iteration method on nodes for traversal - `polars` has some clearly defined rules over the order nodes should appear in --- narwhals/_plan/meta.py | 85 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 narwhals/_plan/meta.py diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py new file mode 100644 index 0000000000..44f6aee872 --- 
/dev/null +++ b/narwhals/_plan/meta.py @@ -0,0 +1,85 @@ +"""`pl.Expr.meta` namespace functionality. + +- It seems like there might be a need to distinguish the top-level nodes for iterating + - polars_plan::dsl::expr::Expr +- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/meta.rs#L11-L111 +- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L10-L105 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + + import polars as pl + + from narwhals._plan.common import ExprIR + + +class ExprIRMetaNamespace: + """Requires defining iterator behavior per node.""" + + def __init__(self, ir: ExprIR, /) -> None: + self._ir: ExprIR = ir + + def has_multiple_outputs(self) -> bool: + raise NotImplementedError + + def is_column(self) -> bool: + """Only one that doesn't require iter. + + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/meta.rs#L65-L71 + """ + from narwhals._plan.expr import Column + + return isinstance(self._ir, Column) + + def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: + raise NotImplementedError + + def is_literal(self, *, allow_aliasing: bool = False) -> bool: + raise NotImplementedError + + def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: + """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/utils.rs#L126-L127.""" + raise NotImplementedError + + # NOTE: Less important for us, but maybe nice to have + def pop(self) -> list[ExprIR]: + raise NotImplementedError + + def root_names(self) -> list[str]: + raise NotImplementedError + + def undo_aliases(self) -> ExprIR: + raise NotImplementedError + + # NOTE: We don't support `nw.col("*")` or other patterns in col + # Maybe not relevant at all + def is_regex_projection(self) -> bool: + raise NotImplementedError + + +def profile_polars_expr(expr: pl.Expr) -> dict[str, Any]: + """Gather all metadata for a native `Expr`. + + Eventual goal would be that a `nw.Expr` matches a `pl.Expr` in as much of this as possible. 
+ """ + return { + "has_multiple_outputs": expr.meta.has_multiple_outputs(), + "is_column": expr.meta.is_column(), + "is_regex_projection": expr.meta.is_regex_projection(), + "is_column_selection": expr.meta.is_column_selection(), + "is_column_selection(allow_aliasing=True)": expr.meta.is_column_selection( + allow_aliasing=True + ), + "is_literal": expr.meta.is_literal(), + "is_literal(allow_aliasing=True)": expr.meta.is_literal(allow_aliasing=True), + "output_name": expr.meta.output_name(raise_if_undetermined=False), + "root_names": expr.meta.root_names(), + "pop": expr.meta.pop(), + "undo_aliases": expr.meta.undo_aliases(), + "expr": expr, + } From be6ec3ef31be18c93cce90aebd9104dedfe2c151 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 13:40:54 +0100 Subject: [PATCH 084/368] feat: Add methods for all of `functions` Still need: - reprs - fix the hierarchy issue (https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2099719927) - Flag summing (https://github.com/narwhals-dev/narwhals/pull/2572#issuecomment-2891577685) --- narwhals/_plan/dummy.py | 250 ++++++++++++++++++++++++++++++++++++ narwhals/_plan/functions.py | 29 ++++- 2 files changed, 276 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 5385abf8da..bb64cc7da8 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -8,11 +8,18 @@ from narwhals._plan import aggregation as agg from narwhals._plan import boolean from narwhals._plan import expr +from narwhals._plan import expr_parsing as parse +from narwhals._plan import functions as F # noqa: N812 from narwhals._plan import operators as ops +from narwhals._plan.options import EWMOptions +from narwhals._plan.options import RankOptions +from narwhals._plan.options import RollingOptionsFixedWindow +from narwhals._plan.options import RollingVarParams from narwhals._plan.options import SortMultipleOptions from narwhals._plan.options import SortOptions from narwhals._plan.window import Over from narwhals.dtypes import DType +from narwhals.exceptions import ComputeError from narwhals.utils import Version from narwhals.utils import _hasattr_static from narwhals.utils import flatten @@ -21,9 +28,16 @@ from typing_extensions import Self from narwhals._plan.common import ExprIR + from narwhals._plan.common import IntoExpr + from narwhals._plan.common import IntoExprColumn from narwhals._plan.common import Seq + from narwhals._plan.common import Udf + from narwhals.typing import FillNullStrategy from narwhals.typing import NativeSeries + from narwhals.typing import NumericLiteral + from narwhals.typing import RankMethod from narwhals.typing import RollingInterpolationMethod + from narwhals.typing import TemporalLiteral # NOTE: Overly simplified placeholders for mocking typing @@ -127,6 +141,237 @@ def sort_by( options = SortMultipleOptions(descending=desc, nulls_last=nulls) return self._from_ir(expr.SortBy(expr=self._ir, by=sort_by, options=options)) + def abs(self) -> Self: + return self._from_ir(F.Abs().to_function_expr(self._ir)) + + def hist( + self, + bins: t.Sequence[float] | None = None, + *, + bin_count: int | None = None, + include_breakpoint: bool = True, + ) -> Self: + node: F.Hist + if bins is not None: + if bin_count is not None: + msg = "can only provide one of `bin_count` or `bins`" + raise ComputeError(msg) + node = F.HistBins(bins=tuple(bins), include_breakpoint=include_breakpoint) + elif bin_count is not None: + node = F.HistBinCount( + 
bin_count=bin_count, include_breakpoint=include_breakpoint + ) + else: + node = F.HistBinCount(include_breakpoint=include_breakpoint) + return self._from_ir(node.to_function_expr(self._ir)) + + def null_count(self) -> Self: + return self._from_ir(F.NullCount().to_function_expr(self._ir)) + + def fill_null( + self, + value: IntoExpr = None, + strategy: FillNullStrategy | None = None, + limit: int | None = None, + ) -> Self: + node: F.FillNullWithStrategy | F.FillNull + if strategy is not None: + node = F.FillNullWithStrategy(strategy=strategy, limit=limit) + else: + node = F.FillNull(value=parse.parse_into_expr_ir(value, str_as_lit=True)) + return self._from_ir(node.to_function_expr(self._ir)) + + def shift(self, n: int) -> Self: + return self._from_ir(F.Shift(n=n).to_function_expr(self._ir)) + + def drop_nulls(self) -> Self: + return self._from_ir(F.DropNulls().to_function_expr(self._ir)) + + def mode(self) -> Self: + return self._from_ir(F.Mode().to_function_expr(self._ir)) + + def skew(self) -> Self: + return self._from_ir(F.Skew().to_function_expr(self._ir)) + + def rank(self, method: RankMethod = "average", *, descending: bool = False) -> Self: + options = RankOptions(method=method, descending=descending) + return self._from_ir(F.Rank(options=options).to_function_expr(self._ir)) + + def clip( + self, + lower_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, + upper_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, + ) -> Self: + return self._from_ir( + F.Clip().to_function_expr( + self._ir, *parse.parse_into_seq_of_expr_ir(lower_bound, upper_bound) + ) + ) + + def cum_count(self, *, reverse: bool = False) -> Self: + return self._from_ir(F.CumCount(reverse=reverse).to_function_expr(self._ir)) + + def cum_min(self, *, reverse: bool = False) -> Self: + return self._from_ir(F.CumMin(reverse=reverse).to_function_expr(self._ir)) + + def cum_max(self, *, reverse: bool = False) -> Self: + return self._from_ir(F.CumMax(reverse=reverse).to_function_expr(self._ir)) + + def cum_prod(self, *, reverse: bool = False) -> Self: + return self._from_ir(F.CumProd(reverse=reverse).to_function_expr(self._ir)) + + def rolling_sum( + self, + window_size: int, + *, + min_samples: int | None = None, + center: bool = False, + ) -> Self: + min_samples = window_size if min_samples is None else min_samples + fn_params = None + options = RollingOptionsFixedWindow( + window_size=window_size, + min_samples=min_samples, + center=center, + fn_params=fn_params, + ) + function = F.RollingSum(options=options) + return self._from_ir(function.to_function_expr(self._ir)) + + def rolling_mean( + self, + window_size: int, + *, + min_samples: int | None = None, + center: bool = False, + ) -> Self: + min_samples = window_size if min_samples is None else min_samples + fn_params = None + options = RollingOptionsFixedWindow( + window_size=window_size, + min_samples=min_samples, + center=center, + fn_params=fn_params, + ) + function = F.RollingMean(options=options) + return self._from_ir(function.to_function_expr(self._ir)) + + def rolling_var( + self, + window_size: int, + *, + min_samples: int | None = None, + center: bool = False, + ddof: int = 1, + ) -> Self: + min_samples = window_size if min_samples is None else min_samples + fn_params = RollingVarParams(ddof=ddof) + options = RollingOptionsFixedWindow( + window_size=window_size, + min_samples=min_samples, + center=center, + fn_params=fn_params, + ) + function = F.RollingVar(options=options) + return 
self._from_ir(function.to_function_expr(self._ir)) + + def rolling_std( + self, + window_size: int, + *, + min_samples: int | None = None, + center: bool = False, + ddof: int = 1, + ) -> Self: + min_samples = window_size if min_samples is None else min_samples + fn_params = RollingVarParams(ddof=ddof) + options = RollingOptionsFixedWindow( + window_size=window_size, + min_samples=min_samples, + center=center, + fn_params=fn_params, + ) + function = F.RollingStd(options=options) + return self._from_ir(function.to_function_expr(self._ir)) + + def diff(self) -> Self: + return self._from_ir(F.Diff().to_function_expr(self._ir)) + + def unique(self) -> Self: + return self._from_ir(F.Unique().to_function_expr(self._ir)) + + def round(self, decimals: int = 0) -> Self: + return self._from_ir(F.Round(decimals=decimals).to_function_expr(self._ir)) + + def ewm_mean( + self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_samples: int = 1, + ignore_nulls: bool = False, + ) -> Self: + options = EWMOptions( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_samples=min_samples, + ignore_nulls=ignore_nulls, + ) + return self._from_ir(F.EwmMean(options=options).to_function_expr(self._ir)) + + def replace_strict( + self, + old: t.Sequence[t.Any] | t.Mapping[t.Any, t.Any], + new: t.Sequence[t.Any] | None = None, + *, + return_dtype: DType | type[DType] | None = None, + ) -> Self: + before: Seq[t.Any] + after: Seq[t.Any] + if new is None: + if not isinstance(old, t.Mapping): + msg = "`new` argument is required if `old` argument is not a Mapping type" + raise TypeError(msg) + before = tuple(old) + after = tuple(old.values()) + elif isinstance(old, t.Mapping): + # NOTE: polars raises later when this occurs + # TypeError: cannot create expression literal for value of type dict. + # Hint: Pass `allow_object=True` to accept any value and create a literal of type Object. 
+ msg = "`new` argument cannot be used if `old` argument is a Mapping type" + raise TypeError(msg) + else: + before = tuple(old) + after = tuple(new) + function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) + return self._from_ir(function.to_function_expr(self._ir)) + + def gather_every(self, n: int, offset: int = 0) -> Self: + return self._from_ir(F.GatherEvery(n=n, offset=offset).to_function_expr(self._ir)) + + def map_batches( + self, + function: Udf, + return_dtype: DType | None = None, + *, + is_elementwise: bool = False, + returns_scalar: bool = False, + ) -> Self: + return self._from_ir( + F.MapBatches( + function=function, + return_dtype=return_dtype, + is_elementwise=is_elementwise, + returns_scalar=returns_scalar, + ).to_function_expr(self._ir) + ) + def __eq__(self, other: DummyExpr) -> Self: # type: ignore[override] op = ops.Eq() return self._from_ir(op.to_binary_expr(self._ir, other._ir)) @@ -186,6 +431,11 @@ def __or__(self, other: DummyExpr) -> Self: def __invert__(self) -> Self: return self._from_ir(boolean.Not().to_function_expr(self._ir)) + def __pow__(self, other: IntoExpr) -> Self: + exponent = parse.parse_into_expr_ir(other, str_as_lit=True) + base = self._ir + return self._from_ir(F.Pow().to_function_expr(base, exponent)) + class DummyExprV1(DummyExpr): _version: t.ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 86943156a5..c2d6f3490a 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -12,8 +12,11 @@ from narwhals._plan.common import Function from narwhals._plan.options import FunctionFlags from narwhals._plan.options import FunctionOptions +from narwhals.exceptions import ComputeError if TYPE_CHECKING: + from typing import Any + from narwhals._plan.common import Seq from narwhals._plan.common import Udf from narwhals._plan.options import EWMOptions @@ -44,15 +47,28 @@ def function_options(self) -> FunctionOptions: class HistBins(Hist): """Subclasses for each variant.""" - __slots__ = (*Hist.__slots__, "bins") + __slots__ = ("bins", *Hist.__slots__) bins: Seq[float] + def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None: + for i in range(1, len(bins)): + if bins[i - 1] >= bins[i]: + msg = "bins must increase monotonically" + raise ComputeError(msg) + object.__setattr__(self, "bins", bins) + object.__setattr__(self, "include_breakpoint", include_breakpoint) + class HistBinCount(Hist): - __slots__ = (*Hist.__slots__, "bin_count") + __slots__ = ("bin_count", *Hist.__slots__) bin_count: int + """Polars (v1.20) sets `bin_count=10` if neither `bins` or `bin_count` are provided.""" + + def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> None: + object.__setattr__(self, "bin_count", bin_count) + object.__setattr__(self, "include_breakpoint", include_breakpoint) class NullCount(Function): @@ -267,8 +283,10 @@ def function_options(self) -> FunctionOptions: class ReplaceStrict(Function): - __slots__ = ("return_dtype",) + __slots__ = ("new", "old", "return_dtype") + old: Seq[Any] + new: Seq[Any] return_dtype: DType | type[DType] | None @property @@ -277,6 +295,11 @@ def function_options(self) -> FunctionOptions: class GatherEvery(Function): + __slots__ = ("n", "offset") + + n: int + offset: int + @property def function_options(self) -> FunctionOptions: return FunctionOptions.groupwise() From 0707dbd562e64c3b3f92f9b6ffefdd83d3a1f798 Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 13:44:34 +0100 Subject: [PATCH 085/368] style(ruff): shrinkage --- narwhals/_plan/dummy.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index bb64cc7da8..5b87dc15f2 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -221,11 +221,7 @@ def cum_prod(self, *, reverse: bool = False) -> Self: return self._from_ir(F.CumProd(reverse=reverse).to_function_expr(self._ir)) def rolling_sum( - self, - window_size: int, - *, - min_samples: int | None = None, - center: bool = False, + self, window_size: int, *, min_samples: int | None = None, center: bool = False ) -> Self: min_samples = window_size if min_samples is None else min_samples fn_params = None @@ -239,11 +235,7 @@ def rolling_sum( return self._from_ir(function.to_function_expr(self._ir)) def rolling_mean( - self, - window_size: int, - *, - min_samples: int | None = None, - center: bool = False, + self, window_size: int, *, min_samples: int | None = None, center: bool = False ) -> Self: min_samples = window_size if min_samples is None else min_samples fn_params = None From f5e8d08b2f9243aa387a22f38e84d8457725b68a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 14:05:16 +0100 Subject: [PATCH 086/368] feat: Add `polars_expr_to_dict` --- narwhals/_plan/meta.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 44f6aee872..9fe34d7168 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -62,7 +62,7 @@ def is_regex_projection(self) -> bool: raise NotImplementedError -def profile_polars_expr(expr: pl.Expr) -> dict[str, Any]: +def polars_expr_metadata(expr: pl.Expr) -> dict[str, Any]: """Gather all metadata for a native `Expr`. Eventual goal would be that a `nw.Expr` matches a `pl.Expr` in as much of this as possible. @@ -83,3 +83,15 @@ def profile_polars_expr(expr: pl.Expr) -> dict[str, Any]: "undo_aliases": expr.meta.undo_aliases(), "expr": expr, } + + +def polars_expr_to_dict(expr: pl.Expr) -> dict[str, Any]: + """Serialize a native `Expr`, roundtrip back to `dict`. + + Using to inspect [`FunctionOptions`] and ensure we combine them in a similar way. + + [`FunctionOptions`]: https://github.com/narwhals-dev/narwhals/pull/2572#issuecomment-2891577685 + """ + import json + + return json.loads(expr.meta.serialize(format="json")) # type: ignore[no-any-return] From ae8cc5c6fcac746d783e98e7eb81af8ef9c15e5f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 14:21:51 +0100 Subject: [PATCH 087/368] docs: Add more notes on `name.py` Stumbled into this when trying to inspect `pl.col("a").name.prefix("-1")` --- narwhals/_plan/name.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index e65df5c4bd..0930e54481 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -10,9 +10,12 @@ class NameFunction(Function): - """`polars` version doesn't represent these as `FunctionExpr`. + """`polars` version [doesn't represent as `FunctionExpr`]. - https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/dsl/name.rs + Also [doesn't support serialization]. 
+ + [doesn't represent as `FunctionExpr`]: https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/dsl/name.rs + [doesn't support serialization]: https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr_dyn_fn.rs#L145-L151 """ @property @@ -34,16 +37,24 @@ def __repr__(self) -> str: return f"name.{m[tp]}" -class Keep(NameFunction): ... +class Keep(NameFunction): + """Returns ``Expr::KeepName``.""" class Map(NameFunction): + """Returns ``Expr::RenameAlias``. + + https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/dsl/name.rs#L28-L38 + """ + __slots__ = ("function",) function: AliasName class Prefix(NameFunction): + """Each of these depend on `Map`.""" + __slots__ = ("prefix",) prefix: str From e6ca72b97df070e345f99503f53a146b353052b2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 14:37:19 +0100 Subject: [PATCH 088/368] fix: Remove `ExprIR` from `Function` bases Resolves (https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2099719927) --- narwhals/_plan/common.py | 4 +++- narwhals/_plan/expr.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 702fc50892..9a62b61826 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -146,9 +146,11 @@ def is_scalar(self) -> bool: return False -class Function(ExprIR): +class Function(Immutable): """Shared by expr functions and namespace functions. + Only valid in `FunctionExpr.function` + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 """ diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index a464ce3ee9..98ce9ed542 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -180,6 +180,10 @@ class FunctionExpr(ExprIR, t.Generic[_FunctionT]): 2. The union of (1) and any `FunctionOptions` in `inputs` """ + @property + def is_scalar(self) -> bool: + return self.function.is_scalar + def with_options(self, options: FunctionOptions, /) -> Self: options = self.options.with_flags(options.flags) return type(self)(input=self.input, function=self.function, options=options) From b7bbbef2155905f011af8fac8b7b81fe94126e6f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 14:51:59 +0100 Subject: [PATCH 089/368] refactor: Rebase and rename `AnonymousFunctionExpr` -> `AnonymousExpr` --- narwhals/_plan/expr.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 98ce9ed542..e896662629 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -19,7 +19,7 @@ from narwhals._plan.common import Function from narwhals._plan.common import Seq - from narwhals._plan.functions import MapBatches + from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.functions import RollingWindow from narwhals._plan.literal import LiteralValue from narwhals._plan.operators import Operator @@ -201,19 +201,9 @@ def __repr__(self) -> str: class RollingExpr(FunctionExpr[_RollingT]): ... 
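# A short, illustrative sketch of what the rename below buys (an assumption
# about intent, not an API guarantee): with the payload pinned via
# `FunctionExpr["MapBatches"]`, `AnonymousExpr` inherits `input`/`function`/
# `options` instead of re-declaring them, and `node.function` narrows to
# `MapBatches`, e.g. (with `udf` being any callable):
#     fn = F.MapBatches(function=udf, return_dtype=None,
#                       is_elementwise=False, returns_scalar=True)
#     node = fn.to_function_expr(nwd.col("a")._ir)
#     node.function.returns_scalar  # -> True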
-class AnonymousFunctionExpr(ExprIR): +class AnonymousExpr(FunctionExpr["MapBatches"]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" - __slots__ = ("function", "input", "options") - - input: Seq[ExprIR] - function: MapBatches - options: FunctionOptions - - @property - def is_scalar(self) -> bool: - return self.function.function_options.returns_scalar() - class Filter(ExprIR): __slots__ = ("by", "expr") From ecb3767edc55e359f300662bbe7c9c97254833bc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 15:04:11 +0100 Subject: [PATCH 090/368] feat: Fill out `functions` reprs --- narwhals/_plan/functions.py | 81 +++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index c2d6f3490a..927f2f3e0b 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -31,6 +31,9 @@ class Abs(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + return "abs" + class Hist(Function): """Only supported for `Series` so far.""" @@ -43,6 +46,9 @@ class Hist(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.groupwise() + def __repr__(self) -> str: + return "hist" + class HistBins(Hist): """Subclasses for each variant.""" @@ -76,12 +82,18 @@ class NullCount(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.aggregation() + def __repr__(self) -> str: + return "null_count" + class Pow(Function): @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + return "pow" + class FillNull(Function): __slots__ = ("value",) @@ -92,6 +104,9 @@ class FillNull(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + return "fill_null" + class FillNullWithStrategy(Function): """We don't support this variant in a lot of backends, so worth keeping it split out. 
@@ -114,6 +129,9 @@ def function_options(self) -> FunctionOptions: else FunctionOptions.groupwise() ) + def __repr__(self) -> str: + return "fill_null_with_strategy" + class Shift(Function): __slots__ = ("n",) @@ -125,24 +143,36 @@ class Shift(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.length_preserving() + def __repr__(self) -> str: + return "shift" + class DropNulls(Function): @property def function_options(self) -> FunctionOptions: return FunctionOptions.row_separable() + def __repr__(self) -> str: + return "drop_nulls" + class Mode(Function): @property def function_options(self) -> FunctionOptions: return FunctionOptions.groupwise() + def __repr__(self) -> str: + return "mode" + class Skew(Function): @property def function_options(self) -> FunctionOptions: return FunctionOptions.aggregation() + def __repr__(self) -> str: + return "skew" + class Rank(Function): __slots__ = ("options",) @@ -153,12 +183,18 @@ class Rank(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.groupwise() + def __repr__(self) -> str: + return "rank" + class Clip(Function): @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + return "clip" + class CumAgg(Function): __slots__ = ("reverse",) @@ -170,6 +206,18 @@ class CumAgg(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.length_preserving() + def __repr__(self) -> str: + tp = type(self) + if tp is CumAgg: + return tp.__name__ + m: dict[type[CumAgg], str] = { + CumCount: "count", + CumMin: "min", + CumMax: "max", + CumProd: "prod", + } + return f"cum_{m[tp]}" + class RollingWindow(Function): __slots__ = ("options",) @@ -181,6 +229,18 @@ def function_options(self) -> FunctionOptions: """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/mod.rs#L1276.""" return FunctionOptions.length_preserving() + def __repr__(self) -> str: + tp = type(self) + if tp is RollingWindow: + return tp.__name__ + m: dict[type[RollingWindow], str] = { + RollingSum: "sum", + RollingMean: "mean", + RollingVar: "var", + RollingStd: "std", + } + return f"rolling_{m[tp]}" + class CumCount(CumAgg): ... 
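# A quick sanity-check sketch for the repr dispatch added above, assuming the
# `Immutable` constructor keeps accepting slot names as keyword arguments (as
# it does elsewhere in this plan):
from narwhals._plan import functions as F
from narwhals._plan.options import RollingOptionsFixedWindow

opts = RollingOptionsFixedWindow(window_size=3, min_samples=3, center=False, fn_params=None)
assert repr(F.CumMax(reverse=False)) == "cum_max"
assert repr(F.RollingSum(options=opts)) == "rolling_sum"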
@@ -211,12 +271,18 @@ class Diff(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.length_preserving() + def __repr__(self) -> str: + return "diff" + class Unique(Function): @property def function_options(self) -> FunctionOptions: return FunctionOptions.groupwise() + def __repr__(self) -> str: + return "unique" + class Round(Function): __slots__ = ("decimals",) @@ -227,6 +293,9 @@ class Round(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + return "round" + class SumHorizontal(Function): @property @@ -281,6 +350,9 @@ class EwmMean(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.length_preserving() + def __repr__(self) -> str: + return "ewm_mean" + class ReplaceStrict(Function): __slots__ = ("new", "old", "return_dtype") @@ -293,6 +365,9 @@ class ReplaceStrict(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def __repr__(self) -> str: + return "replace_strict" + class GatherEvery(Function): __slots__ = ("n", "offset") @@ -304,6 +379,9 @@ class GatherEvery(Function): def function_options(self) -> FunctionOptions: return FunctionOptions.groupwise() + def __repr__(self) -> str: + return "gather_every" + class MapBatches(Function): __slots__ = ("function", "is_elementwise", "return_dtype", "returns_scalar") @@ -322,3 +400,6 @@ def function_options(self) -> FunctionOptions: if self.returns_scalar: options = options.with_flags(FunctionFlags.RETURNS_SCALAR) return options + + def __repr__(self) -> str: + return "map_batches" From bd6c0eb89f5ab07532cf58c7c4f65eb7c00239b8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 15:58:32 +0100 Subject: [PATCH 091/368] fix: Align `fill_null` w/ `polars` --- narwhals/_plan/dummy.py | 15 +++++++++------ narwhals/_plan/functions.py | 5 ----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 5b87dc15f2..4a660ca738 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -124,6 +124,10 @@ def over( order = by, options return self._from_ir(Over().to_window_expr(self._ir, partition, order)) + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: + options = SortOptions(descending=descending, nulls_last=nulls_last) + return self._from_ir(expr.Sort(expr=self._ir, options=options)) + def sort_by( self, by: DummyExpr | t.Iterable[DummyExpr], @@ -174,12 +178,11 @@ def fill_null( strategy: FillNullStrategy | None = None, limit: int | None = None, ) -> Self: - node: F.FillNullWithStrategy | F.FillNull - if strategy is not None: - node = F.FillNullWithStrategy(strategy=strategy, limit=limit) - else: - node = F.FillNull(value=parse.parse_into_expr_ir(value, str_as_lit=True)) - return self._from_ir(node.to_function_expr(self._ir)) + if strategy is None: + ir = parse.parse_into_expr_ir(value, str_as_lit=True) + return self._from_ir(F.FillNull().to_function_expr(self._ir, ir)) + fill = F.FillNullWithStrategy(strategy=strategy, limit=limit) + return self._from_ir(fill.to_function_expr(self._ir)) def shift(self, n: int) -> Self: return self._from_ir(F.Shift(n=n).to_function_expr(self._ir)) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 927f2f3e0b..5c16fd6db3 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING -from 
narwhals._plan.common import ExprIR from narwhals._plan.common import Function from narwhals._plan.options import FunctionFlags from narwhals._plan.options import FunctionOptions @@ -96,10 +95,6 @@ def __repr__(self) -> str: class FillNull(Function): - __slots__ = ("value",) - - value: ExprIR - @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() From a0f7c9b9cd65a1d2a1ea54658a810d2621bbd157 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 16:27:32 +0100 Subject: [PATCH 092/368] fix: Expand tuples in `__str__` --- narwhals/_plan/common.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 9a62b61826..3199484ccc 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -97,7 +97,7 @@ def __eq__(self, other: object) -> bool: def __str__(self) -> str: # NOTE: Debug repr, closer to constructor slots: tuple[str, ...] = self.__slots__ - fields = ", ".join(f"{name}={getattr(self, name)}" for name in slots) + fields = ", ".join(f"{_field_str(name, getattr(self, name))}" for name in slots) return f"{type(self).__name__}({fields})" def __init__(self, **kwds: Any) -> None: @@ -126,6 +126,13 @@ def __init__(self, **kwds: Any) -> None: raise TypeError(msg) +def _field_str(name: str, value: Any) -> str: + if isinstance(value, tuple): + inner = ", ".join(f"{v}" for v in value) + return f"{name}=[{inner}]" + return f"{name}={value}" + + class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" From 0982b3a1016a618a0f2a8cdfaafac37b8f1b274b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 16:59:13 +0100 Subject: [PATCH 093/368] feat: Partially resolve `FunctionExpr.options` - 1 step closer to the understanding for (https://github.com/narwhals-dev/narwhals/pull/2572#issuecomment-2891577685) - There's still some magic going on when `polars` serializes - Need to track down where `'collect_groups': 'ElementWise'` and `'collect_groups': 'GroupWise'` first appear - Seems like the flags get reduced --- narwhals/_plan/common.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 3199484ccc..8bfaf3f41f 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -173,12 +173,11 @@ def is_scalar(self) -> bool: def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: from narwhals._plan.expr import FunctionExpr - from narwhals._plan.options import FunctionOptions - # NOTE: Still need to figure out how these should be generated - # Feel like it should be the union of `input` & `function` - PLACEHOLDER = FunctionOptions.default() # noqa: N806 - return FunctionExpr(input=inputs, function=self, options=PLACEHOLDER) + # NOTE: Still need to figure out if using a closure is needed + options = self.function_options + # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L442-L450. 
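# A possible follow-up, sketched only (not implemented in this patch): fold the
# flags of any `FunctionExpr` inputs into `options` as well, i.e. the
# "union of `input` & `function`" idea from the note removed above, e.g.
#     for e in inputs:
#         if isinstance(e, FunctionExpr):
#             options = options.with_flags(e.options.flags)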
+ return FunctionExpr(input=inputs, function=self, options=options) _NON_NESTED_LITERAL_TPS = ( From ef10243f86d6b56965603613fc5d521e42002803 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 17:56:25 +0100 Subject: [PATCH 094/368] fix: Rebase `Window`, `LiteralValue` onto `Immutable` --- narwhals/_plan/literal.py | 8 ++++++-- narwhals/_plan/window.py | 5 +++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 2a7ec58b1c..be76c415e7 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIR +from narwhals._plan.common import Immutable if TYPE_CHECKING: from narwhals._plan.dummy import DummySeries @@ -11,13 +11,17 @@ from narwhals.typing import NonNestedLiteral -class LiteralValue(ExprIR): +class LiteralValue(Immutable): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/lit.rs#L67-L73.""" @property def dtype(self) -> DType: raise NotImplementedError + @property + def is_scalar(self) -> bool: + return False + def to_literal(self) -> Literal: from narwhals._plan.expr import Literal diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 9ee9a2c910..1e34a699a5 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -2,15 +2,16 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIR +from narwhals._plan.common import Immutable if TYPE_CHECKING: + from narwhals._plan.common import ExprIR from narwhals._plan.common import Seq from narwhals._plan.expr import WindowExpr from narwhals._plan.options import SortOptions -class Window(ExprIR): +class Window(Immutable): """Renamed from `WindowType`. https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/options/mod.rs#L139 From 394627dcbe3e3aa26d412e0286a3821b6cef70ea Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 20:16:28 +0100 Subject: [PATCH 095/368] feat: Add Expr graph iteration methods Huge step towards `ExprMetaNamespace` --- narwhals/_plan/aggregation.py | 10 +++++ narwhals/_plan/common.py | 54 +++++++++++++++++++++++ narwhals/_plan/expr.py | 80 +++++++++++++++++++++++++++++++++++ 3 files changed, 144 insertions(+) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 434a00328e..9fb4fa64f7 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -5,6 +5,8 @@ from narwhals._plan.common import ExprIR if TYPE_CHECKING: + from typing import Iterator + from narwhals.typing import RollingInterpolationMethod @@ -25,6 +27,14 @@ def __repr__(self) -> str: name = m.get(tp, tp.__name__.lower()) return f"{self.expr!r}.{name}()" + def iter_left(self) -> Iterator[ExprIR]: + yield from self.expr.iter_left() + yield self + + def iter_right(self) -> Iterator[ExprIR]: + yield self + yield from self.expr.iter_right() + class Count(Agg): ... 
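# A minimal usage sketch for the new iterators, assuming the `demo` helpers
# (`nwd.col`, `.alias`, `.min`) used in this plan's doctests: `iter_right()`
# walks leaf->root, so the root `Column` nodes arrive last, which makes
# collecting root column names straightforward.
from narwhals._plan import demo as nwd
from narwhals._plan.expr import Column

ir = nwd.col("a").alias("b").min()._ir
roots = [node.name for node in ir.iter_right() if isinstance(node, Column)]
assert roots == ["a"]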
diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 8bfaf3f41f..b249ed58e8 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from typing import Any from typing import Callable + from typing import Iterator from typing_extensions import Never from typing_extensions import Self @@ -152,6 +153,59 @@ def to_compliant(self, version: Version = Version.MAIN) -> DummyCompliantExpr: def is_scalar(self) -> bool: return False + def iter_left(self) -> Iterator[ExprIR]: + """Yield nodes root->leaf. + + Examples: + >>> from narwhals._plan import demo as nwd + >>> + >>> a = nwd.col("a") + >>> b = a.alias("b") + >>> c = b.min().alias("c") + >>> d = c.over(nwd.col("e"), nwd.col("f")) + >>> + >>> list(a._ir.iter_left()) + [col('a')] + >>> + >>> list(b._ir.iter_left()) + [col('a'), col('a').alias('b')] + >>> + >>> list(c._ir.iter_left()) + [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c')] + >>> + >>> list(d._ir.iter_left()) + [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c'), col('e'), col('f'), col('a').alias('b').min().alias('c').over([col('e'), col('f')])] + """ + yield self + + def iter_right(self) -> Iterator[ExprIR]: + """Yield nodes leaf->root. + + Note: + Identical to `iter_left` for root nodes. + + Examples: + >>> from narwhals._plan import demo as nwd + >>> + >>> a = nwd.col("a") + >>> b = a.alias("b") + >>> c = b.min().alias("c") + >>> d = c.over(nwd.col("e"), nwd.col("f")) + >>> + >>> list(a._ir.iter_right()) + [col('a')] + >>> + >>> list(b._ir.iter_right()) + [col('a').alias('b'), col('a')] + >>> + >>> list(c._ir.iter_right()) + [col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] + >>> + >>> list(d._ir.iter_right()) + [col('a').alias('b').min().alias('c').over([col('e'), col('f')]), col('f'), col('e'), col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] + """ + yield self + class Function(Immutable): """Shared by expr functions and namespace functions. 
diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index e896662629..526696ec70 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -46,6 +46,14 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.alias({self.name!r})" + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + yield from self.expr.iter_right() + class Column(ExprIR): __slots__ = ("name",) @@ -107,6 +115,16 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"[({self.left!r}) {self.op!r} ({self.right!r})]" + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.left.iter_left() + yield from self.right.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + yield from self.right.iter_right() + yield from self.left.iter_right() + class Cast(ExprIR): __slots__ = ("dtype", "expr") @@ -121,6 +139,14 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.cast({self.dtype!r})" + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + yield from self.expr.iter_right() + class Sort(ExprIR): __slots__ = ("expr", "options") @@ -136,6 +162,14 @@ def __repr__(self) -> str: direction = "desc" if self.options.descending else "asc" return f"{self.expr!r}.sort({direction})" + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + yield from self.expr.iter_right() + class SortBy(ExprIR): """https://github.com/narwhals-dev/narwhals/issues/2534.""" @@ -153,6 +187,18 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.sort_by(by={self.by!r}, options={self.options!r})" + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_left() + for e in self.by: + yield from e.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + for e in reversed(self.by): + yield from e.iter_right() + yield from self.expr.iter_right() + class FunctionExpr(ExprIR, t.Generic[_FunctionT]): """**Representing `Expr::Function`**. @@ -197,6 +243,16 @@ def __repr__(self) -> str: else: return f"{self.function!r}()" + def iter_left(self) -> t.Iterator[ExprIR]: + for e in self.input: + yield from e.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + for e in reversed(self.input): + yield from e.iter_right() + class RollingExpr(FunctionExpr[_RollingT]): ... @@ -218,6 +274,16 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.filter({self.by!r})" + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_left() + yield from self.by.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + yield from self.by.iter_right() + yield from self.expr.iter_right() + class WindowExpr(ExprIR): """A fully specified `.over()`, that occurred after another expression. 
@@ -276,6 +342,20 @@ def __str__(self) -> str: args = f"expr={self.expr}, partition_by={self.partition_by}, order_by={order_by}, options={self.options}" return f"{type(self).__name__}({args})" + def iter_left(self) -> t.Iterator[ExprIR]: + # NOTE: `order_by` is never considered in `polars` + # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 + yield from self.expr.iter_left() + for e in self.partition_by: + yield from e.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + for e in reversed(self.partition_by): + yield from e.iter_right() + yield from self.expr.iter_right() + class Len(ExprIR): @property From 882eff073d05171fc809f2ded219742973ef117b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 22:06:31 +0100 Subject: [PATCH 096/368] =?UTF-8?q?feat:=20Implement=20most=20of=20`ExprIR?= =?UTF-8?q?MetaNamespace`=20=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- narwhals/_plan/expr.py | 8 ++++ narwhals/_plan/literal.py | 8 ++++ narwhals/_plan/meta.py | 91 +++++++++++++++++++++++++++++++++++---- 3 files changed, 98 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 526696ec70..b609265ae8 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -88,6 +88,10 @@ def is_scalar(self) -> bool: def dtype(self) -> DType: return self.value.dtype + @property + def name(self) -> str: + return self.value.name + def __repr__(self) -> str: return f"lit({self.value!r})" @@ -362,6 +366,10 @@ class Len(ExprIR): def is_scalar(self) -> bool: return True + @property + def name(self) -> str: + return "len" + def __repr__(self) -> str: return "len()" diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index be76c415e7..f9a70219d4 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -18,6 +18,10 @@ class LiteralValue(Immutable): def dtype(self) -> DType: raise NotImplementedError + @property + def name(self) -> str: + return "literal" + @property def is_scalar(self) -> bool: return False @@ -58,6 +62,10 @@ class SeriesLiteral(LiteralValue): def dtype(self) -> DType: return self.value.dtype + @property + def name(self) -> str: + return self.value.name + def __repr__(self) -> str: return "Series" diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 9fe34d7168..4220590401 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -10,6 +10,9 @@ from typing import TYPE_CHECKING +from narwhals.exceptions import ComputeError +from narwhals.utils import Version + if TYPE_CHECKING: from typing import Any @@ -25,26 +28,31 @@ def __init__(self, ir: ExprIR, /) -> None: self._ir: ExprIR = ir def has_multiple_outputs(self) -> bool: - raise NotImplementedError + return any(_has_multiple_outputs(e) for e in self._ir.iter_left()) def is_column(self) -> bool: - """Only one that doesn't require iter. 
- - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/meta.rs#L65-L71 - """ from narwhals._plan.expr import Column return isinstance(self._ir, Column) def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: - raise NotImplementedError + return all( + _is_column_selection(e, allow_aliasing=allow_aliasing) + for e in self._ir.iter_left() + ) def is_literal(self, *, allow_aliasing: bool = False) -> bool: - raise NotImplementedError + return all( + _is_literal(e, allow_aliasing=allow_aliasing) for e in self._ir.iter_left() + ) def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: - """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/utils.rs#L126-L127.""" - raise NotImplementedError + ok_or_err = _expr_output_name(self._ir) + if isinstance(ok_or_err, ComputeError): + if raise_if_undetermined: + raise ok_or_err + return None + return ok_or_err # NOTE: Less important for us, but maybe nice to have def pop(self) -> list[ExprIR]: @@ -62,6 +70,71 @@ def is_regex_projection(self) -> bool: raise NotImplementedError +def _expr_output_name(ir: ExprIR) -> str | ComputeError: + from narwhals._plan import expr + + for e in ir.iter_right(): + if isinstance(e, expr.WindowExpr): + return _expr_output_name(e.expr) + if isinstance(e, (expr.Column, expr.Alias, expr.Literal, expr.Len)): + return e.name + if isinstance(e, expr.All): + msg = "cannot determine output column without a context for this expression" + return ComputeError(msg) + if isinstance(e, (expr.Columns, expr.IndexColumns, expr.Nth)): + msg = "this expression may produce multiple output names" + return ComputeError(msg) + continue + msg = f"unable to find root column name for expr '{ir!r}' when calling 'output_name'" + return ComputeError(msg) + + +def _has_multiple_outputs(ir: ExprIR) -> bool: + from narwhals._plan import expr + + return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.Selector, expr.All)) + + +def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: + from narwhals._plan import expr + from narwhals._plan.literal import ScalarLiteral + + if isinstance(ir, expr.Literal): + return True + if isinstance(ir, expr.Alias): + return allow_aliasing + if isinstance(ir, expr.Cast): + return ( + isinstance(ir.expr, expr.Literal) + and isinstance(ir.expr, ScalarLiteral) + and isinstance(ir.expr.dtype, Version.MAIN.dtypes.Datetime) + ) + return False + + +def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: + from narwhals._plan import expr + + if isinstance( + ir, + ( + expr.Column, + expr.Columns, + expr.Exclude, + expr.Nth, + expr.IndexColumns, + expr.Selector, + expr.All, + ), + ): + return True + # TODO @dangotbanned: Add `KeepName`, `RenameAlias` here later (see `_plan.name`) + aliasing_types = (expr.Alias,) + if isinstance(ir, aliasing_types): + return allow_aliasing + return False + + def polars_expr_metadata(expr: pl.Expr) -> dict[str, Any]: """Gather all metadata for a native `Expr`. 
From 6d9b9e1078f6318bc079bf8903b56372f340006d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 22:10:52 +0100 Subject: [PATCH 097/368] feat: Add `meta` accessor --- narwhals/_plan/dummy.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 4a660ca738..99760450a4 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -32,6 +32,7 @@ from narwhals._plan.common import IntoExprColumn from narwhals._plan.common import Seq from narwhals._plan.common import Udf + from narwhals._plan.meta import ExprIRMetaNamespace from narwhals.typing import FillNullStrategy from narwhals.typing import NativeSeries from narwhals.typing import NumericLiteral @@ -431,6 +432,12 @@ def __pow__(self, other: IntoExpr) -> Self: base = self._ir return self._from_ir(F.Pow().to_function_expr(base, exponent)) + @property + def meta(self) -> ExprIRMetaNamespace: + from narwhals._plan.meta import ExprIRMetaNamespace + + return ExprIRMetaNamespace(self._ir) + class DummyExprV1(DummyExpr): _version: t.ClassVar[Version] = Version.V1 From 4f1ea19bde75a20e92950df1519e108ed6d2107d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 22:41:26 +0100 Subject: [PATCH 098/368] test: Extra doctests for `meta.output_name` --- narwhals/_plan/meta.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 4220590401..7c466468ff 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -47,6 +47,32 @@ def is_literal(self, *, allow_aliasing: bool = False) -> bool: ) def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: + """Get the output name of this expression. + + Examples: + >>> from narwhals._plan import demo as nwd + >>> + >>> a = nwd.col("a") + >>> b = a.alias("b") + >>> c = b.min().alias("c") + >>> c_over = c.over(nwd.col("e"), nwd.col("f")) + >>> c_over_sort = c_over.sort_by(nwd.nth(9), nwd.col("g", "h")) + >>> + >>> a.meta.output_name() + 'a' + >>> b.meta.output_name() + 'b' + >>> c.meta.output_name() + 'c' + >>> c_over.meta.output_name() + 'c' + >>> c_over_sort.meta.output_name() + 'c' + >>> nwd.lit(1).meta.output_name() + 'literal' + >>> nwd.len().meta.output_name() + 'len' + """ ok_or_err = _expr_output_name(self._ir) if isinstance(ok_or_err, ComputeError): if raise_if_undetermined: @@ -74,7 +100,8 @@ def _expr_output_name(ir: ExprIR) -> str | ComputeError: from narwhals._plan import expr for e in ir.iter_right(): - if isinstance(e, expr.WindowExpr): + if isinstance(e, (expr.WindowExpr, expr.SortBy)): + # Don't follow `over(partition_by=...)` or `sort_by(by=...) 
return _expr_output_name(e.expr) if isinstance(e, (expr.Column, expr.Alias, expr.Literal, expr.Len)): return e.name From 9569d99f74531f49d9498c0e57f66aec33048b3b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 21 May 2025 22:44:38 +0100 Subject: [PATCH 099/368] docs: Update namespace class doc --- narwhals/_plan/meta.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 7c466468ff..4aee876cac 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -22,7 +22,7 @@ class ExprIRMetaNamespace: - """Requires defining iterator behavior per node.""" + """Methods to modify and traverse existing expressions.""" def __init__(self, ir: ExprIR, /) -> None: self._ir: ExprIR = ir @@ -90,11 +90,6 @@ def root_names(self) -> list[str]: def undo_aliases(self) -> ExprIR: raise NotImplementedError - # NOTE: We don't support `nw.col("*")` or other patterns in col - # Maybe not relevant at all - def is_regex_projection(self) -> bool: - raise NotImplementedError - def _expr_output_name(ir: ExprIR) -> str | ComputeError: from narwhals._plan import expr From f1d58dbbc830244bbc5d8875c38b57cdd7a7a72c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 22 May 2025 11:11:52 +0100 Subject: [PATCH 100/368] chore: Align repr more closely w/ `polars` Mainly trying to make the comparison between complex expressions easier - Escape strings in `Immutable.__str__` - Now calling `str(ExprIR._ir)` gives you something that (when the top level quotes are removed) can be formatted by `ruff` - Represent `FunctionFlags` in the shorter version `polars` uses --- narwhals/_plan/common.py | 2 ++ narwhals/_plan/options.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index b249ed58e8..6b617f3069 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -131,6 +131,8 @@ def _field_str(name: str, value: Any) -> str: if isinstance(value, tuple): inner = ", ".join(f"{v}" for v in value) return f"{name}=[{inner}]" + elif isinstance(value, str): + return f"{name}={value!r}" return f"{name}={value}" diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 485708725d..7ddcd9f199 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -48,6 +48,10 @@ def is_length_preserving(self) -> bool: def default() -> FunctionFlags: return FunctionFlags.ALLOW_GROUP_AWARE + def __str__(self) -> str: + name = self.name or "" + return name.replace("|", " | ") + class FunctionOptions(Immutable): """ExprMetadata` but less god object. 
@@ -59,6 +63,9 @@ class FunctionOptions(Immutable): flags: FunctionFlags + def __str__(self) -> str: + return f"{type(self).__name__}(flags='{self.flags}')" + def is_elementwise(self) -> bool: return self.flags.is_elementwise() From 3afb8fc2aa337a80f947027e85a74cdc823b31e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 May 2025 12:40:26 +0000 Subject: [PATCH 101/368] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_plan/boolean.py | 3 +- narwhals/_plan/common.py | 17 +++-------- narwhals/_plan/demo.py | 30 +++++++------------ narwhals/_plan/dummy.py | 52 ++++++++++++++++----------------- narwhals/_plan/expr.py | 13 ++++----- narwhals/_plan/expr_parsing.py | 24 +++++---------- narwhals/_plan/functions.py | 10 ++----- narwhals/_plan/operators.py | 3 +- narwhals/_plan/strings.py | 3 +- narwhals/_plan/temporal.py | 3 +- narwhals/_plan/window.py | 3 +- tests/plan/expr_parsing_test.py | 16 +++++----- 12 files changed, 69 insertions(+), 108 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index c92d032754..48c2f7db76 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -5,8 +5,7 @@ import typing as t from narwhals._plan.common import Function -from narwhals._plan.options import FunctionFlags -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FunctionFlags, FunctionOptions if t.TYPE_CHECKING: from narwhals.typing import ClosedInterval diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 6b617f3069..f283f803f8 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -2,25 +2,16 @@ import datetime as dt from decimal import Decimal -from typing import TYPE_CHECKING -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar from narwhals.utils import Version if TYPE_CHECKING: - from typing import Any - from typing import Callable - from typing import Iterator + from typing import Any, Callable, Iterator - from typing_extensions import Never - from typing_extensions import Self - from typing_extensions import TypeAlias - from typing_extensions import TypeIs - from typing_extensions import dataclass_transform + from typing_extensions import Never, Self, TypeAlias, TypeIs, dataclass_transform - from narwhals._plan.dummy import DummyCompliantExpr - from narwhals._plan.dummy import DummyExpr - from narwhals._plan.dummy import DummySeries + from narwhals._plan.dummy import DummyCompliantExpr, DummyExpr, DummySeries from narwhals._plan.expr import FunctionExpr from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedLiteral diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index a4a3c6e50a..85e64da2ea 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,34 +3,26 @@ import builtins import typing as t -from narwhals._plan import aggregation as agg -from narwhals._plan import boolean -from narwhals._plan import expr_parsing as parse -from narwhals._plan import functions as F # noqa: N812 -from narwhals._plan.common import ExprIR -from narwhals._plan.common import IntoExpr -from narwhals._plan.common import is_non_nested_literal +from narwhals._plan import ( + aggregation as agg, + boolean, + expr_parsing as parse, + functions as F, # noqa: N812 +) +from narwhals._plan.common import ExprIR, IntoExpr, is_non_nested_literal from narwhals._plan.dummy import DummySeries -from 
narwhals._plan.expr import All -from narwhals._plan.expr import Column -from narwhals._plan.expr import Columns -from narwhals._plan.expr import IndexColumns -from narwhals._plan.expr import Len -from narwhals._plan.expr import Nth -from narwhals._plan.literal import ScalarLiteral -from narwhals._plan.literal import SeriesLiteral +from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth +from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.strings import ConcatHorizontal from narwhals.dtypes import DType from narwhals.exceptions import OrderDependentExprError -from narwhals.utils import Version -from narwhals.utils import flatten +from narwhals.utils import Version, flatten if t.TYPE_CHECKING: from typing_extensions import TypeIs from narwhals._plan.dummy import DummyExpr - from narwhals._plan.expr import SortBy - from narwhals._plan.expr import WindowExpr + from narwhals._plan.expr import SortBy, WindowExpr from narwhals.typing import NonNestedLiteral diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 99760450a4..245f5da9f8 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -5,40 +5,40 @@ import typing as t from typing import TYPE_CHECKING -from narwhals._plan import aggregation as agg -from narwhals._plan import boolean -from narwhals._plan import expr -from narwhals._plan import expr_parsing as parse -from narwhals._plan import functions as F # noqa: N812 -from narwhals._plan import operators as ops -from narwhals._plan.options import EWMOptions -from narwhals._plan.options import RankOptions -from narwhals._plan.options import RollingOptionsFixedWindow -from narwhals._plan.options import RollingVarParams -from narwhals._plan.options import SortMultipleOptions -from narwhals._plan.options import SortOptions +from narwhals._plan import ( + aggregation as agg, + boolean, + expr, + expr_parsing as parse, + functions as F, # noqa: N812 + operators as ops, +) +from narwhals._plan.options import ( + EWMOptions, + RankOptions, + RollingOptionsFixedWindow, + RollingVarParams, + SortMultipleOptions, + SortOptions, +) from narwhals._plan.window import Over from narwhals.dtypes import DType from narwhals.exceptions import ComputeError -from narwhals.utils import Version -from narwhals.utils import _hasattr_static -from narwhals.utils import flatten +from narwhals.utils import Version, _hasattr_static, flatten if TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.common import ExprIR - from narwhals._plan.common import IntoExpr - from narwhals._plan.common import IntoExprColumn - from narwhals._plan.common import Seq - from narwhals._plan.common import Udf + from narwhals._plan.common import ExprIR, IntoExpr, IntoExprColumn, Seq, Udf from narwhals._plan.meta import ExprIRMetaNamespace - from narwhals.typing import FillNullStrategy - from narwhals.typing import NativeSeries - from narwhals.typing import NumericLiteral - from narwhals.typing import RankMethod - from narwhals.typing import RollingInterpolationMethod - from narwhals.typing import TemporalLiteral + from narwhals.typing import ( + FillNullStrategy, + NativeSeries, + NumericLiteral, + RankMethod, + RollingInterpolationMethod, + TemporalLiteral, + ) # NOTE: Overly simplified placeholders for mocking typing diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index b609265ae8..558e491351 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -17,15 +17,14 @@ if t.TYPE_CHECKING: from typing_extensions import Self 
- from narwhals._plan.common import Function - from narwhals._plan.common import Seq - from narwhals._plan.functions import MapBatches # noqa: F401 - from narwhals._plan.functions import RollingWindow + from narwhals._plan.common import Function, Seq + from narwhals._plan.functions import ( + MapBatches, # noqa: F401 + RollingWindow, + ) from narwhals._plan.literal import LiteralValue from narwhals._plan.operators import Operator - from narwhals._plan.options import FunctionOptions - from narwhals._plan.options import SortMultipleOptions - from narwhals._plan.options import SortOptions + from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions from narwhals._plan.window import Window from narwhals.dtypes import DType diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index c9000b36bf..a9944fceac 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -1,28 +1,18 @@ from __future__ import annotations # ruff: noqa: A002 -from typing import TYPE_CHECKING -from typing import Iterable -from typing import Sequence -from typing import TypeVar - -from narwhals._plan.common import is_expr -from narwhals._plan.common import is_iterable_reject -from narwhals.dependencies import get_polars -from narwhals.dependencies import is_pandas_dataframe -from narwhals.dependencies import is_pandas_series +from typing import TYPE_CHECKING, Iterable, Sequence, TypeVar + +from narwhals._plan.common import is_expr, is_iterable_reject +from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series from narwhals.exceptions import InvalidIntoExprError if TYPE_CHECKING: - from typing import Any - from typing import Iterator + from typing import Any, Iterator - from typing_extensions import TypeAlias - from typing_extensions import TypeIs + from typing_extensions import TypeAlias, TypeIs - from narwhals._plan.common import ExprIR - from narwhals._plan.common import IntoExpr - from narwhals._plan.common import Seq + from narwhals._plan.common import ExprIR, IntoExpr, Seq from narwhals.dtypes import DType T = TypeVar("T") diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 5c16fd6db3..3fb7cfca14 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -9,18 +9,14 @@ from typing import TYPE_CHECKING from narwhals._plan.common import Function -from narwhals._plan.options import FunctionFlags -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FunctionFlags, FunctionOptions from narwhals.exceptions import ComputeError if TYPE_CHECKING: from typing import Any - from narwhals._plan.common import Seq - from narwhals._plan.common import Udf - from narwhals._plan.options import EWMOptions - from narwhals._plan.options import RankOptions - from narwhals._plan.options import RollingOptionsFixedWindow + from narwhals._plan.common import Seq, Udf + from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow from narwhals.dtypes import DType from narwhals.typing import FillNullStrategy diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index fcdab07101..e5523f1eb7 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -5,8 +5,7 @@ if TYPE_CHECKING: from narwhals._plan.expr import BinaryExpr -from narwhals._plan.common import ExprIR -from narwhals._plan.common import Immutable +from narwhals._plan.common import ExprIR, Immutable class Operator(Immutable): diff --git 
a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 697b0b122d..ef64a3b79a 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -1,8 +1,7 @@ from __future__ import annotations from narwhals._plan.common import Function -from narwhals._plan.options import FunctionFlags -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FunctionFlags, FunctionOptions class StringFunction(Function): diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 9485baefb3..a40ad2f63e 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING -from typing import cast +from typing import TYPE_CHECKING, cast from narwhals._plan.common import Function from narwhals._plan.options import FunctionOptions diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 1e34a699a5..861e7baff8 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -5,8 +5,7 @@ from narwhals._plan.common import Immutable if TYPE_CHECKING: - from narwhals._plan.common import ExprIR - from narwhals._plan.common import Seq + from narwhals._plan.common import ExprIR, Seq from narwhals._plan.expr import WindowExpr from narwhals._plan.options import SortOptions diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index eb88879f26..84487a6e47 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -1,23 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING -from typing import Callable -from typing import Iterable +from typing import TYPE_CHECKING, Callable, Iterable import pytest import narwhals as nw import narwhals._plan.demo as nwd -from narwhals._plan import boolean -from narwhals._plan import functions as F # noqa: N812 -from narwhals._plan.common import ExprIR -from narwhals._plan.common import Function +from narwhals._plan import ( + boolean, + functions as F, # noqa: N812 +) +from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import FunctionExpr if TYPE_CHECKING: - from narwhals._plan.common import IntoExpr - from narwhals._plan.common import Seq + from narwhals._plan.common import IntoExpr, Seq @pytest.mark.parametrize( From a9c3b3e28d1be3a4b961d4539aaf7bc3dd19759d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 22 May 2025 13:41:19 +0100 Subject: [PATCH 102/368] style(ruff): Apply new formatting --- narwhals/_plan/boolean.py | 3 +- narwhals/_plan/common.py | 17 +++-------- narwhals/_plan/demo.py | 30 +++++++------------ narwhals/_plan/dummy.py | 52 ++++++++++++++++----------------- narwhals/_plan/expr.py | 13 ++++----- narwhals/_plan/expr_parsing.py | 24 +++++---------- narwhals/_plan/functions.py | 10 ++----- narwhals/_plan/operators.py | 3 +- narwhals/_plan/strings.py | 3 +- narwhals/_plan/temporal.py | 3 +- narwhals/_plan/window.py | 3 +- tests/plan/expr_parsing_test.py | 16 +++++----- 12 files changed, 69 insertions(+), 108 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index c92d032754..48c2f7db76 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -5,8 +5,7 @@ import typing as t from narwhals._plan.common import Function -from narwhals._plan.options import FunctionFlags -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FunctionFlags, 
FunctionOptions if t.TYPE_CHECKING: from narwhals.typing import ClosedInterval diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 6b617f3069..f283f803f8 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -2,25 +2,16 @@ import datetime as dt from decimal import Decimal -from typing import TYPE_CHECKING -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar from narwhals.utils import Version if TYPE_CHECKING: - from typing import Any - from typing import Callable - from typing import Iterator + from typing import Any, Callable, Iterator - from typing_extensions import Never - from typing_extensions import Self - from typing_extensions import TypeAlias - from typing_extensions import TypeIs - from typing_extensions import dataclass_transform + from typing_extensions import Never, Self, TypeAlias, TypeIs, dataclass_transform - from narwhals._plan.dummy import DummyCompliantExpr - from narwhals._plan.dummy import DummyExpr - from narwhals._plan.dummy import DummySeries + from narwhals._plan.dummy import DummyCompliantExpr, DummyExpr, DummySeries from narwhals._plan.expr import FunctionExpr from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedLiteral diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index a4a3c6e50a..85e64da2ea 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,34 +3,26 @@ import builtins import typing as t -from narwhals._plan import aggregation as agg -from narwhals._plan import boolean -from narwhals._plan import expr_parsing as parse -from narwhals._plan import functions as F # noqa: N812 -from narwhals._plan.common import ExprIR -from narwhals._plan.common import IntoExpr -from narwhals._plan.common import is_non_nested_literal +from narwhals._plan import ( + aggregation as agg, + boolean, + expr_parsing as parse, + functions as F, # noqa: N812 +) +from narwhals._plan.common import ExprIR, IntoExpr, is_non_nested_literal from narwhals._plan.dummy import DummySeries -from narwhals._plan.expr import All -from narwhals._plan.expr import Column -from narwhals._plan.expr import Columns -from narwhals._plan.expr import IndexColumns -from narwhals._plan.expr import Len -from narwhals._plan.expr import Nth -from narwhals._plan.literal import ScalarLiteral -from narwhals._plan.literal import SeriesLiteral +from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth +from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.strings import ConcatHorizontal from narwhals.dtypes import DType from narwhals.exceptions import OrderDependentExprError -from narwhals.utils import Version -from narwhals.utils import flatten +from narwhals.utils import Version, flatten if t.TYPE_CHECKING: from typing_extensions import TypeIs from narwhals._plan.dummy import DummyExpr - from narwhals._plan.expr import SortBy - from narwhals._plan.expr import WindowExpr + from narwhals._plan.expr import SortBy, WindowExpr from narwhals.typing import NonNestedLiteral diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 99760450a4..245f5da9f8 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -5,40 +5,40 @@ import typing as t from typing import TYPE_CHECKING -from narwhals._plan import aggregation as agg -from narwhals._plan import boolean -from narwhals._plan import expr -from narwhals._plan import expr_parsing as parse -from narwhals._plan import functions as F # noqa: N812 -from narwhals._plan import operators as 
ops -from narwhals._plan.options import EWMOptions -from narwhals._plan.options import RankOptions -from narwhals._plan.options import RollingOptionsFixedWindow -from narwhals._plan.options import RollingVarParams -from narwhals._plan.options import SortMultipleOptions -from narwhals._plan.options import SortOptions +from narwhals._plan import ( + aggregation as agg, + boolean, + expr, + expr_parsing as parse, + functions as F, # noqa: N812 + operators as ops, +) +from narwhals._plan.options import ( + EWMOptions, + RankOptions, + RollingOptionsFixedWindow, + RollingVarParams, + SortMultipleOptions, + SortOptions, +) from narwhals._plan.window import Over from narwhals.dtypes import DType from narwhals.exceptions import ComputeError -from narwhals.utils import Version -from narwhals.utils import _hasattr_static -from narwhals.utils import flatten +from narwhals.utils import Version, _hasattr_static, flatten if TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.common import ExprIR - from narwhals._plan.common import IntoExpr - from narwhals._plan.common import IntoExprColumn - from narwhals._plan.common import Seq - from narwhals._plan.common import Udf + from narwhals._plan.common import ExprIR, IntoExpr, IntoExprColumn, Seq, Udf from narwhals._plan.meta import ExprIRMetaNamespace - from narwhals.typing import FillNullStrategy - from narwhals.typing import NativeSeries - from narwhals.typing import NumericLiteral - from narwhals.typing import RankMethod - from narwhals.typing import RollingInterpolationMethod - from narwhals.typing import TemporalLiteral + from narwhals.typing import ( + FillNullStrategy, + NativeSeries, + NumericLiteral, + RankMethod, + RollingInterpolationMethod, + TemporalLiteral, + ) # NOTE: Overly simplified placeholders for mocking typing diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index b609265ae8..558e491351 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -17,15 +17,14 @@ if t.TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.common import Function - from narwhals._plan.common import Seq - from narwhals._plan.functions import MapBatches # noqa: F401 - from narwhals._plan.functions import RollingWindow + from narwhals._plan.common import Function, Seq + from narwhals._plan.functions import ( + MapBatches, # noqa: F401 + RollingWindow, + ) from narwhals._plan.literal import LiteralValue from narwhals._plan.operators import Operator - from narwhals._plan.options import FunctionOptions - from narwhals._plan.options import SortMultipleOptions - from narwhals._plan.options import SortOptions + from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions from narwhals._plan.window import Window from narwhals.dtypes import DType diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index c9000b36bf..a9944fceac 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -1,28 +1,18 @@ from __future__ import annotations # ruff: noqa: A002 -from typing import TYPE_CHECKING -from typing import Iterable -from typing import Sequence -from typing import TypeVar - -from narwhals._plan.common import is_expr -from narwhals._plan.common import is_iterable_reject -from narwhals.dependencies import get_polars -from narwhals.dependencies import is_pandas_dataframe -from narwhals.dependencies import is_pandas_series +from typing import TYPE_CHECKING, Iterable, Sequence, TypeVar + +from narwhals._plan.common import is_expr, 
is_iterable_reject +from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series from narwhals.exceptions import InvalidIntoExprError if TYPE_CHECKING: - from typing import Any - from typing import Iterator + from typing import Any, Iterator - from typing_extensions import TypeAlias - from typing_extensions import TypeIs + from typing_extensions import TypeAlias, TypeIs - from narwhals._plan.common import ExprIR - from narwhals._plan.common import IntoExpr - from narwhals._plan.common import Seq + from narwhals._plan.common import ExprIR, IntoExpr, Seq from narwhals.dtypes import DType T = TypeVar("T") diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 5c16fd6db3..3fb7cfca14 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -9,18 +9,14 @@ from typing import TYPE_CHECKING from narwhals._plan.common import Function -from narwhals._plan.options import FunctionFlags -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FunctionFlags, FunctionOptions from narwhals.exceptions import ComputeError if TYPE_CHECKING: from typing import Any - from narwhals._plan.common import Seq - from narwhals._plan.common import Udf - from narwhals._plan.options import EWMOptions - from narwhals._plan.options import RankOptions - from narwhals._plan.options import RollingOptionsFixedWindow + from narwhals._plan.common import Seq, Udf + from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow from narwhals.dtypes import DType from narwhals.typing import FillNullStrategy diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index fcdab07101..e5523f1eb7 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -5,8 +5,7 @@ if TYPE_CHECKING: from narwhals._plan.expr import BinaryExpr -from narwhals._plan.common import ExprIR -from narwhals._plan.common import Immutable +from narwhals._plan.common import ExprIR, Immutable class Operator(Immutable): diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 697b0b122d..ef64a3b79a 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -1,8 +1,7 @@ from __future__ import annotations from narwhals._plan.common import Function -from narwhals._plan.options import FunctionFlags -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FunctionFlags, FunctionOptions class StringFunction(Function): diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 9485baefb3..a40ad2f63e 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING -from typing import cast +from typing import TYPE_CHECKING, cast from narwhals._plan.common import Function from narwhals._plan.options import FunctionOptions diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 1e34a699a5..861e7baff8 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -5,8 +5,7 @@ from narwhals._plan.common import Immutable if TYPE_CHECKING: - from narwhals._plan.common import ExprIR - from narwhals._plan.common import Seq + from narwhals._plan.common import ExprIR, Seq from narwhals._plan.expr import WindowExpr from narwhals._plan.options import SortOptions diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index eb88879f26..84487a6e47 100644 --- a/tests/plan/expr_parsing_test.py +++ 
b/tests/plan/expr_parsing_test.py @@ -1,23 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING -from typing import Callable -from typing import Iterable +from typing import TYPE_CHECKING, Callable, Iterable import pytest import narwhals as nw import narwhals._plan.demo as nwd -from narwhals._plan import boolean -from narwhals._plan import functions as F # noqa: N812 -from narwhals._plan.common import ExprIR -from narwhals._plan.common import Function +from narwhals._plan import ( + boolean, + functions as F, # noqa: N812 +) +from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import FunctionExpr if TYPE_CHECKING: - from narwhals._plan.common import IntoExpr - from narwhals._plan.common import Seq + from narwhals._plan.common import IntoExpr, Seq @pytest.mark.parametrize( From 309db1ff245a96b85ab9803f5f18d5f32522472d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 22 May 2025 16:44:44 +0100 Subject: [PATCH 103/368] feat(typing): Generic `BinaryExpr`, default `TypeVar`s - Needed to add the `ruff` config in the end (#2578) - Prevented forward refs in `TypeVar` from showing as unused imports --- narwhals/_plan/expr.py | 26 ++++++++++---------------- narwhals/_plan/operators.py | 13 +++++++++++-- narwhals/_plan/typing.py | 19 +++++++++++++++++++ pyproject.toml | 1 + 4 files changed, 41 insertions(+), 18 deletions(-) create mode 100644 narwhals/_plan/typing.py diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 558e491351..216b056529 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -13,24 +13,18 @@ import typing as t from narwhals._plan.common import ExprIR +from narwhals._plan.typing import FunctionT, LeftT, OperatorT, RightT, RollingT if t.TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.common import Function, Seq - from narwhals._plan.functions import ( - MapBatches, # noqa: F401 - RollingWindow, - ) + from narwhals._plan.common import Seq + from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue - from narwhals._plan.operators import Operator from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions from narwhals._plan.window import Window from narwhals.dtypes import DType -_FunctionT = t.TypeVar("_FunctionT", bound="Function") -_RollingT = t.TypeVar("_RollingT", bound="RollingWindow") - class Alias(ExprIR): __slots__ = ("expr", "name") @@ -95,7 +89,7 @@ def __repr__(self) -> str: return f"lit({self.value!r})" -class BinaryExpr(ExprIR): +class BinaryExpr(ExprIR, t.Generic[LeftT, OperatorT, RightT]): """Application of two exprs via an `Operator`. This ✅ @@ -107,9 +101,9 @@ class BinaryExpr(ExprIR): __slots__ = ("left", "op", "right") - left: ExprIR - op: Operator - right: ExprIR + left: LeftT + op: OperatorT + right: RightT @property def is_scalar(self) -> bool: @@ -203,7 +197,7 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_right() -class FunctionExpr(ExprIR, t.Generic[_FunctionT]): +class FunctionExpr(ExprIR, t.Generic[FunctionT]): """**Representing `Expr::Function`**. 
https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 @@ -214,7 +208,7 @@ class FunctionExpr(ExprIR, t.Generic[_FunctionT]): __slots__ = ("function", "input", "options") input: Seq[ExprIR] - function: _FunctionT + function: FunctionT """Enum type is named `FunctionExpr` in `polars`. Mirroring *exactly* doesn't make much sense in OOP. @@ -257,7 +251,7 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from e.iter_right() -class RollingExpr(FunctionExpr[_RollingT]): ... +class RollingExpr(FunctionExpr[RollingT]): ... class AnonymousExpr(FunctionExpr["MapBatches"]): diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index e5523f1eb7..bf1c2289be 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -3,9 +3,12 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._plan.expr import BinaryExpr + from narwhals._plan.typing import LeftT, RightT -from narwhals._plan.common import ExprIR, Immutable +from narwhals._plan.common import Immutable class Operator(Immutable): @@ -28,10 +31,13 @@ def __repr__(self) -> str: Modulus: "%", And: "&", Or: "|", + ExclusiveOr: "^", } return m[tp] - def to_binary_expr(self, left: ExprIR, right: ExprIR, /) -> BinaryExpr: + def to_binary_expr( + self, left: LeftT, right: RightT, / + ) -> BinaryExpr[LeftT, Self, RightT]: from narwhals._plan.expr import BinaryExpr return BinaryExpr(left=left, op=self, right=right) @@ -77,3 +83,6 @@ class And(Operator): ... class Or(Operator): ... + + +class ExclusiveOr(Operator): ... diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py new file mode 100644 index 0000000000..7dba26a2c0 --- /dev/null +++ b/narwhals/_plan/typing.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import typing as t + +from narwhals._typing_compat import TypeVar + +if t.TYPE_CHECKING: + from narwhals._plan import operators as ops + from narwhals._plan.common import ExprIR, Function + from narwhals._plan.functions import RollingWindow + +__all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT"] + + +FunctionT = TypeVar("FunctionT", bound="Function") +RollingT = TypeVar("RollingT", bound="RollingWindow") +LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") +OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") +RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") diff --git a/pyproject.toml b/pyproject.toml index 7ffa4dceb5..dfec792e70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ extend-exclude = ["**/this.py"] [tool.ruff.lint] preview = true explicit-preview-rules = true +typing-modules = ["narwhals._typing_compat"] extend-safe-fixes = [ "C419", # unnecessary-comprehension-in-call From 87b8402ba8ace8019f0f1cc58ce8b9981d57a709 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 22 May 2025 20:50:06 +0100 Subject: [PATCH 104/368] very wip `selectors` None of this is functional yet - Most of the upstream stuff is written in `python` - The rust enum isn't very helpful for us - (`polars_plan::dsl::selector::Selector`) - It is opaque to the kind of selection --- narwhals/_plan/dummy.py | 69 ++++++++++++++++- narwhals/_plan/expr.py | 37 ++++++++- narwhals/_plan/meta.py | 4 +- narwhals/_plan/operators.py | 30 ++++++-- narwhals/_plan/selectors.py | 144 ++++++++++++++++++++++++++++++++++++ narwhals/_plan/typing.py | 6 +- 6 files changed, 277 insertions(+), 
13 deletions(-) create mode 100644 narwhals/_plan/selectors.py diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 245f5da9f8..745bb8f849 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -27,7 +27,7 @@ from narwhals.utils import Version, _hasattr_static, flatten if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Never, Self from narwhals._plan.common import ExprIR, IntoExpr, IntoExprColumn, Seq, Udf from narwhals._plan.meta import ExprIRMetaNamespace @@ -439,10 +439,77 @@ def meta(self) -> ExprIRMetaNamespace: return ExprIRMetaNamespace(self._ir) +class DummySelector(DummyExpr): + _ir: expr.SelectorIR + + @classmethod + def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] + obj = cls.__new__(cls) + obj._ir = ir + return obj + + def _to_expr(self) -> DummyExpr: + return self._ir.to_narwhals(self.version) + + # TODO @dangotbanned: Make a decision on selector root, binary op + # Current typing warnings are accurate, this isn't valid yet + def __or__(self, other: t.Any) -> Self | t.Any: + if isinstance(other, type(self)): + op = ops.Or() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type] + return self._to_expr() | other + + def __and__(self, other: t.Any) -> Self | t.Any: + if isinstance(other, type(self)): + op = ops.And() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type] + return self._to_expr() & other + + def __sub__(self, other: t.Any) -> Self | t.Any: + if isinstance(other, type(self)): + op = ops.Sub() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type] + return self._to_expr() - other + + def __xor__(self, other: t.Any) -> Self | t.Any: + if isinstance(other, type(self)): + op = ops.ExclusiveOr() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type] + return self._to_expr() ^ other + + def __invert__(self) -> Never: + raise NotImplementedError + + def __add__(self, other: t.Any) -> DummyExpr: # type: ignore[override] + if isinstance(other, type(self)): + msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" + raise TypeError(msg) + return self._to_expr() + other # type: ignore[no-any-return] + + def __rsub__(self, other: t.Any) -> Never: + raise NotImplementedError + + def __rand__(self, other: t.Any) -> Never: + raise NotImplementedError + + def __ror__(self, other: t.Any) -> Never: + raise NotImplementedError + + def __rxor__(self, other: t.Any) -> Never: + raise NotImplementedError + + def __radd__(self, other: t.Any) -> Never: + raise NotImplementedError + + class DummyExprV1(DummyExpr): _version: t.ClassVar[Version] = Version.V1 +class DummySelectorV1(DummySelector): + _version: t.ClassVar[Version] = Version.V1 + + class DummyCompliantExpr: _ir: ExprIR _version: Version diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 216b056529..f406aa46f5 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -13,15 +13,25 @@ import typing as t from narwhals._plan.common import ExprIR -from narwhals._plan.typing import FunctionT, LeftT, OperatorT, RightT, RollingT +from narwhals._plan.typing import ( + FunctionT, + LeftT, + OperatorT, + RightT, + RollingT, + SelectorOperatorT, +) +from narwhals.utils import Version if t.TYPE_CHECKING: from typing_extensions import Self from narwhals._plan.common import Seq + from narwhals._plan.dummy import DummySelector from narwhals._plan.functions import 
MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions + from narwhals._plan.selectors import Selector from narwhals._plan.window import Window from narwhals.dtypes import DType @@ -408,9 +418,32 @@ def __repr__(self) -> str: return "*" -class Selector(ExprIR): +class SelectorIR(ExprIR): + """Not sure on this separation. + + - Need a cleaner way of including `BinarySelector`. + - Like that there's easy access to operands + - Dislike that it inherits node iteration, since upstream doesn't use it for selectors + """ + + __slots__ = ("selector",) + + selector: Selector """by_dtype, matches, numeric, boolean, string, categorical, datetime, all.""" + def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: + from narwhals._plan import dummy + + if version is Version.MAIN: + return dummy.DummySelector._from_ir(self) + return dummy.DummySelectorV1._from_ir(self) + + +class BinarySelector( + BinaryExpr["SelectorIR", SelectorOperatorT, "SelectorIR"], + t.Generic[SelectorOperatorT], +): ... + class Ternary(ExprIR): """When-Then-Otherwise. diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 4aee876cac..3507d6545f 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -114,7 +114,7 @@ def _expr_output_name(ir: ExprIR) -> str | ComputeError: def _has_multiple_outputs(ir: ExprIR) -> bool: from narwhals._plan import expr - return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.Selector, expr.All)) + return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.SelectorIR, expr.All)) def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: @@ -145,7 +145,7 @@ def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: expr.Exclude, expr.Nth, expr.IndexColumns, - expr.Selector, + expr.SelectorIR, expr.All, ), ): diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index bf1c2289be..aacafa8bf9 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.expr import BinaryExpr + from narwhals._plan.expr import BinaryExpr, BinarySelector, SelectorIR from narwhals._plan.typing import LeftT, RightT from narwhals._plan.common import Immutable @@ -14,8 +14,8 @@ class Operator(Immutable): def __repr__(self) -> str: tp = type(self) - if tp is Operator: - return "Operator" + if tp in {Operator, SelectorOperator}: + return tp.__name__ m = { Eq: "==", NotEq: "!=", @@ -43,6 +43,22 @@ def to_binary_expr( return BinaryExpr(left=left, op=self, right=right) +class SelectorOperator(Operator): + """Operators that can *also* be used in selectors. + + Remember that `Or` is named [`meta._selector_add`]! + + [`meta._selector_add`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L113-L124 + """ + + def to_binary_selector( + self, left: SelectorIR, right: SelectorIR, / + ) -> BinarySelector[Self]: + from narwhals._plan.expr import BinarySelector + + return BinarySelector(left=left, op=self, right=right) + + class Eq(Operator): ... @@ -64,7 +80,7 @@ class GtEq(Operator): ... class Add(Operator): ... -class Sub(Operator): ... +class Sub(SelectorOperator): ... class Multiply(Operator): ... @@ -79,10 +95,10 @@ class FloorDivide(Operator): ... class Modulus(Operator): ... -class And(Operator): ... +class And(SelectorOperator): ... -class Or(Operator): ... 
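# A rough sketch of the intended split above, assuming the `SelectorIR` root
# node from this patch's `expr.py` hunk and the `selectors` module it creates:
# a plain `Operator` builds a `BinaryExpr` via `to_binary_expr`, while a
# `SelectorOperator` can *also* build a `BinarySelector` over two selector
# roots. The commit message notes none of this is functional yet, so the
# lines below are illustrative only:
#
#   left = String().to_selector()    # SelectorIR(selector=String())
#   right = Numeric().to_selector()  # SelectorIR(selector=Numeric())
#   node = And().to_binary_selector(left, right)
#   node.left, node.op, node.right   # operands stay introspectable, as in BinaryExpr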
+class Or(SelectorOperator): ... -class ExclusiveOr(Operator): ... +class ExclusiveOr(SelectorOperator): ... diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py new file mode 100644 index 0000000000..580f63d05e --- /dev/null +++ b/narwhals/_plan/selectors.py @@ -0,0 +1,144 @@ +"""Deviations from `polars`. + +- A `Selector` corresponds to a `nw.selectors` function +- Binary ops are represented as a subtype of `BinaryExpr` +""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Iterable + +from narwhals._plan.common import Immutable, is_iterable_reject +from narwhals.utils import _parse_time_unit_and_time_zone + +if TYPE_CHECKING: + from datetime import timezone + from typing import Iterator, TypeVar + + from narwhals._plan.dummy import DummySelector + from narwhals._plan.expr import SelectorIR + from narwhals.dtypes import DType + from narwhals.typing import TimeUnit + + T = TypeVar("T") + + +class Selector(Immutable): + def to_selector(self) -> SelectorIR: + from narwhals._plan.expr import SelectorIR + + return SelectorIR(selector=self) + + +class All(Selector): ... + + +class ByDType(Selector): + __slots__ = ("dtypes",) + + dtypes: frozenset[DType | type[DType]] + + @staticmethod + def from_dtypes( + *dtypes: DType | type[DType] | Iterable[DType | type[DType]], + ) -> ByDType: + return ByDType(dtypes=frozenset(_flatten_hash_safe(dtypes))) + + +class Boolean(Selector): ... + + +class Categorical(Selector): ... + + +class Datetime(Selector): + """Should swallow the [`utils` functions]. + + Just re-wrapping them for now, since `CompliantSelectorNamespace` is still using them. + + [`utils` functions]: https://github.com/narwhals-dev/narwhals/blob/6d524ba04fca6fe2d6d25bdd69f75fabf1d79039/narwhals/utils.py#L1565-L1596 + """ + + __slots__ = ("time_units", "time_zones") + + time_units: frozenset[TimeUnit] + time_zones: frozenset[str | None] + + @staticmethod + def from_time_unit_and_time_zone( + time_unit: TimeUnit | Iterable[TimeUnit] | None, + time_zone: str | timezone | Iterable[str | timezone | None] | None, + /, + ) -> Datetime: + units, zones = _parse_time_unit_and_time_zone(time_unit, time_zone) + return Datetime(time_units=frozenset(units), time_zones=frozenset(zones)) + + +class Matches(Selector): + __slots__ = ("pattern",) + + pattern: re.Pattern[str] + + @staticmethod + def from_string(pattern: str, /) -> Matches: + return Matches(pattern=re.compile(pattern)) + + +class Numeric(Selector): ... + + +class String(Selector): ... 
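# A minimal sketch of the constructor pattern used in this new module, assuming
# the classes defined above and the module path added by this patch: each helper
# below builds a hashable node first, then wraps it via `to_selector()`.
# `ByDType.from_dtypes`, for example, flattens any nesting into a frozenset so
# the node itself stays hashable:
from narwhals._plan.selectors import ByDType
from narwhals.dtypes import Float32, Float64, Int64

node = ByDType.from_dtypes(Int64, [Float32, Float64])
assert node.dtypes == frozenset({Int64, Float32, Float64})
ir = node.to_selector()  # a `SelectorIR` wrapping this node, as of this patch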
+ + +def all() -> DummySelector: + return All().to_selector().to_narwhals() + + +def by_dtype( + *dtypes: DType | type[DType] | Iterable[DType | type[DType]], +) -> DummySelector: + return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals() + + +def boolean() -> DummySelector: + return Boolean().to_selector().to_narwhals() + + +def categorical() -> DummySelector: + return Categorical().to_selector().to_narwhals() + + +def datetime( + time_unit: TimeUnit | Iterable[TimeUnit] | None = None, + time_zone: str | timezone | Iterable[str | timezone | None] | None = ("*", None), +) -> DummySelector: + return ( + Datetime.from_time_unit_and_time_zone(time_unit, time_zone) + .to_selector() + .to_narwhals() + ) + + +def matches(pattern: str) -> DummySelector: + return Matches.from_string(pattern).to_selector().to_narwhals() + + +def numeric() -> DummySelector: + return Numeric().to_selector().to_narwhals() + + +def string() -> DummySelector: + return String().to_selector().to_narwhals() + + +def _flatten_hash_safe(iterable: Iterable[T | Iterable[T]], /) -> Iterator[T]: + """Fully unwrap all levels of nesting. + + Aiming to reduce the chances of passing an unhashable argument. + """ + for element in iterable: + if isinstance(element, Iterable) and not is_iterable_reject(element): + yield from _flatten_hash_safe(element) + else: + yield element # type: ignore[misc] diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 7dba26a2c0..337bd48ab8 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -9,7 +9,7 @@ from narwhals._plan.common import ExprIR, Function from narwhals._plan.functions import RollingWindow -__all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT"] +__all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT", "SelectorOperatorT"] FunctionT = TypeVar("FunctionT", bound="Function") @@ -17,3 +17,7 @@ LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") + +SelectorOperatorT = TypeVar( + "SelectorOperatorT", bound="ops.SelectorOperator", default="ops.SelectorOperator" +) From af274d314b6e9002bb577c4850d7fd4fa6ec5655 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 22 May 2025 21:05:37 +0100 Subject: [PATCH 105/368] feat: Allow `IntoExpr` in binary ops First step towards expressifying everywhere --- narwhals/_plan/dummy.py | 70 ++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 28 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 745bb8f849..7c2e142730 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -368,61 +368,75 @@ def map_batches( ).to_function_expr(self._ir) ) - def __eq__(self, other: DummyExpr) -> Self: # type: ignore[override] + def __eq__(self, other: IntoExpr) -> Self: # type: ignore[override] op = ops.Eq() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __ne__(self, other: DummyExpr) -> Self: # type: ignore[override] + def __ne__(self, other: IntoExpr) -> Self: # type: ignore[override] op = ops.NotEq() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __lt__(self, other: DummyExpr) -> 
Self: + def __lt__(self, other: IntoExpr) -> Self: op = ops.Lt() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __le__(self, other: DummyExpr) -> Self: + def __le__(self, other: IntoExpr) -> Self: op = ops.LtEq() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __gt__(self, other: DummyExpr) -> Self: + def __gt__(self, other: IntoExpr) -> Self: op = ops.Gt() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __ge__(self, other: DummyExpr) -> Self: + def __ge__(self, other: IntoExpr) -> Self: op = ops.GtEq() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __add__(self, other: DummyExpr) -> Self: + def __add__(self, other: IntoExpr) -> Self: op = ops.Add() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __sub__(self, other: DummyExpr) -> Self: + def __sub__(self, other: IntoExpr) -> Self: op = ops.Sub() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __mul__(self, other: DummyExpr) -> Self: + def __mul__(self, other: IntoExpr) -> Self: op = ops.Multiply() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __truediv__(self, other: DummyExpr) -> Self: + def __truediv__(self, other: IntoExpr) -> Self: op = ops.TrueDivide() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __floordiv__(self, other: DummyExpr) -> Self: + def __floordiv__(self, other: IntoExpr) -> Self: op = ops.FloorDivide() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __mod__(self, other: DummyExpr) -> Self: + def __mod__(self, other: IntoExpr) -> Self: op = ops.Modulus() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __and__(self, other: DummyExpr) -> Self: + def __and__(self, other: IntoExpr) -> Self: op = ops.And() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __or__(self, other: DummyExpr) -> Self: + def __or__(self, other: IntoExpr) -> Self: op = ops.Or() - return self._from_ir(op.to_binary_expr(self._ir, other._ir)) + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) def __invert__(self) -> Self: return self._from_ir(boolean.Not().to_function_expr(self._ir)) From fb303f627dc1a6084be7357a2fd5fa1bef020fe4 Mon 
Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 10:32:08 +0100 Subject: [PATCH 106/368] ci: Maybe `codespell` ignore `FirstT` Can't tell if this means `FirstT` will match the entry `firstt`, but preserve the `firstt` fix (https://github.com/codespell-project/codespell#ignoring-words) (https://github.com/narwhals-dev/narwhals/pull/2572#issuecomment-2902400535) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9e7a692997..81ee0692a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: codespell files: \.(py|rst|md)$ - args: [--ignore-words-list=ser] + args: [--ignore-words-list=ser,FirstT] exclude: ^docs/api-completeness.md$ - repo: https://github.com/pycqa/flake8 rev: '7.2.0' # todo: remove once https://github.com/astral-sh/ruff/issues/458 is addressed From c58c01a3cd1f0d7d98ef59ebf5e892181ceed01c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 10:33:41 +0100 Subject: [PATCH 107/368] ci: Always lowercase? Related (https://github.com/narwhals-dev/narwhals/pull/2572/commits/fb303f627dc1a6084be7357a2fd5fa1bef020fe4) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81ee0692a2..610f401772 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: codespell files: \.(py|rst|md)$ - args: [--ignore-words-list=ser,FirstT] + args: [--ignore-words-list=ser,firstt] exclude: ^docs/api-completeness.md$ - repo: https://github.com/pycqa/flake8 rev: '7.2.0' # todo: remove once https://github.com/astral-sh/ruff/issues/458 is addressed From d04da78acfc90c63fe793b07fe333d26aa4d9d4b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 10:34:51 +0100 Subject: [PATCH 108/368] Why was I writing first? 
--- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 610f401772..ba57d8d0d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: codespell files: \.(py|rst|md)$ - args: [--ignore-words-list=ser,firstt] + args: [--ignore-words-list=ser,RightT] exclude: ^docs/api-completeness.md$ - repo: https://github.com/pycqa/flake8 rev: '7.2.0' # todo: remove once https://github.com/astral-sh/ruff/issues/458 is addressed From a221039bfe6a94b48d34004baa6ce0e882a8ab2a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 10:36:12 +0100 Subject: [PATCH 109/368] =?UTF-8?q?ci:=20codespell=20=F0=9F=99=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ba57d8d0d7..0e51bfa2fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: codespell files: \.(py|rst|md)$ - args: [--ignore-words-list=ser,RightT] + args: [--ignore-words-list=ser,rightt] exclude: ^docs/api-completeness.md$ - repo: https://github.com/pycqa/flake8 rev: '7.2.0' # todo: remove once https://github.com/astral-sh/ruff/issues/458 is addressed From 89c45879ba9d23b6b3a4d593595fe64e57541704 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 10:46:46 +0100 Subject: [PATCH 110/368] ci: fix `codespell` https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2104219677 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0e51bfa2fb..6be4e91ceb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: codespell files: \.(py|rst|md)$ - args: [--ignore-words-list=ser,rightt] + args: [--ignore-words-list=ser, --ignore-words-list=RightT] exclude: ^docs/api-completeness.md$ - repo: https://github.com/pycqa/flake8 rev: '7.2.0' # todo: remove once https://github.com/astral-sh/ruff/issues/458 is addressed From f1db275ddcb5d67d546062a011191ac11090451f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 11:54:11 +0100 Subject: [PATCH 111/368] refactor: Add `ExprIRNamespace` Reducing boilerplate for the upcoming changes in `name.py` --- narwhals/_plan/common.py | 6 ++++++ narwhals/_plan/dummy.py | 2 +- narwhals/_plan/meta.py | 16 +++++++--------- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index f283f803f8..07db2e2da3 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -200,6 +200,12 @@ def iter_right(self) -> Iterator[ExprIR]: yield self +class ExprIRNamespace(Immutable): + __slots__ = ("ir",) + + ir: ExprIR + + class Function(Immutable): """Shared by expr functions and namespace functions. 
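# A minimal sketch of what the new `ExprIRNamespace` base above buys us,
# assuming the patched `ExprIRMetaNamespace` further below: a namespace no
# longer needs its own `__init__`, it just reads the shared immutable `ir`
# slot and is built by keyword, e.g. `ExprIRMetaNamespace(ir=some_expr._ir)`.
# The subclass name here is hypothetical, for illustration only:
from narwhals._plan.common import ExprIRNamespace

class ExprIRDebugNamespace(ExprIRNamespace):
    def describe(self) -> str:
        # `self.ir` is the wrapped ExprIR node
        return repr(self.ir)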
diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 7c2e142730..72978dd468 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -450,7 +450,7 @@ def __pow__(self, other: IntoExpr) -> Self: def meta(self) -> ExprIRMetaNamespace: from narwhals._plan.meta import ExprIRMetaNamespace - return ExprIRMetaNamespace(self._ir) + return ExprIRMetaNamespace(ir=self._ir) class DummySelector(DummyExpr): diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 3507d6545f..7ea7338c7a 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING +from narwhals._plan.common import ExprIRNamespace from narwhals.exceptions import ComputeError from narwhals.utils import Version @@ -21,29 +22,26 @@ from narwhals._plan.common import ExprIR -class ExprIRMetaNamespace: +class ExprIRMetaNamespace(ExprIRNamespace): """Methods to modify and traverse existing expressions.""" - def __init__(self, ir: ExprIR, /) -> None: - self._ir: ExprIR = ir - def has_multiple_outputs(self) -> bool: - return any(_has_multiple_outputs(e) for e in self._ir.iter_left()) + return any(_has_multiple_outputs(e) for e in self.ir.iter_left()) def is_column(self) -> bool: from narwhals._plan.expr import Column - return isinstance(self._ir, Column) + return isinstance(self.ir, Column) def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: return all( _is_column_selection(e, allow_aliasing=allow_aliasing) - for e in self._ir.iter_left() + for e in self.ir.iter_left() ) def is_literal(self, *, allow_aliasing: bool = False) -> bool: return all( - _is_literal(e, allow_aliasing=allow_aliasing) for e in self._ir.iter_left() + _is_literal(e, allow_aliasing=allow_aliasing) for e in self.ir.iter_left() ) def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: @@ -73,7 +71,7 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: >>> nwd.len().meta.output_name() 'len' """ - ok_or_err = _expr_output_name(self._ir) + ok_or_err = _expr_output_name(self.ir) if isinstance(ok_or_err, ComputeError): if raise_if_undetermined: raise ok_or_err From 4a8c2d5442ccd2ff0f64d8bc04fd9bd34ab1a3ce Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 12:27:38 +0100 Subject: [PATCH 112/368] feat: `name.py` rewrite + integrate - Now we've got the same semantics as polars -`Prefix` and `Suffix` are there to be cached nodes - A `lambda` would hash but would preduce a different one on each `prefix()` --- narwhals/_plan/common.py | 8 +++- narwhals/_plan/dummy.py | 9 ++++- narwhals/_plan/expr.py | 30 +++++++++++++++ narwhals/_plan/meta.py | 33 +++++++++------- narwhals/_plan/name.py | 82 ++++++++++++++++++++-------------------- 5 files changed, 106 insertions(+), 56 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 07db2e2da3..53c784985e 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -201,9 +201,13 @@ def iter_right(self) -> Iterator[ExprIR]: class ExprIRNamespace(Immutable): - __slots__ = ("ir",) + __slots__ = ("_ir",) - ir: ExprIR + _ir: ExprIR + + @classmethod + def from_expr(cls, expr: DummyExpr, /) -> Self: + return cls(_ir=expr._ir) class Function(Immutable): diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 72978dd468..fd11810b5c 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -31,6 +31,7 @@ from narwhals._plan.common import ExprIR, IntoExpr, IntoExprColumn, 
Seq, Udf from narwhals._plan.meta import ExprIRMetaNamespace + from narwhals._plan.name import ExprIRNameNamespace from narwhals.typing import ( FillNullStrategy, NativeSeries, @@ -450,7 +451,13 @@ def __pow__(self, other: IntoExpr) -> Self: def meta(self) -> ExprIRMetaNamespace: from narwhals._plan.meta import ExprIRMetaNamespace - return ExprIRMetaNamespace(ir=self._ir) + return ExprIRMetaNamespace.from_expr(self) + + @property + def name(self) -> ExprIRNameNamespace: + from narwhals._plan.name import ExprIRNameNamespace + + return ExprIRNameNamespace.from_expr(self) class DummySelector(DummyExpr): diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index f406aa46f5..7fc6e61564 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -12,7 +12,9 @@ # - Literal import typing as t +from narwhals._plan.aggregation import Agg, OrderableAgg from narwhals._plan.common import ExprIR +from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( FunctionT, LeftT, @@ -35,6 +37,34 @@ from narwhals._plan.window import Window from narwhals.dtypes import DType +__all__ = [ + "Agg", + "Alias", + "All", + "AnonymousExpr", + "BinaryExpr", + "BinarySelector", + "Cast", + "Column", + "Columns", + "Exclude", + "Filter", + "FunctionExpr", + "IndexColumns", + "KeepName", + "Len", + "Literal", + "Nth", + "OrderableAgg", + "RenameAlias", + "RollingExpr", + "SelectorIR", + "Sort", + "SortBy", + "Ternary", + "WindowExpr", +] + class Alias(ExprIR): __slots__ = ("expr", "name") diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 7ea7338c7a..4027969c87 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -26,22 +26,22 @@ class ExprIRMetaNamespace(ExprIRNamespace): """Methods to modify and traverse existing expressions.""" def has_multiple_outputs(self) -> bool: - return any(_has_multiple_outputs(e) for e in self.ir.iter_left()) + return any(_has_multiple_outputs(e) for e in self._ir.iter_left()) def is_column(self) -> bool: from narwhals._plan.expr import Column - return isinstance(self.ir, Column) + return isinstance(self._ir, Column) def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: return all( _is_column_selection(e, allow_aliasing=allow_aliasing) - for e in self.ir.iter_left() + for e in self._ir.iter_left() ) def is_literal(self, *, allow_aliasing: bool = False) -> bool: return all( - _is_literal(e, allow_aliasing=allow_aliasing) for e in self.ir.iter_left() + _is_literal(e, allow_aliasing=allow_aliasing) for e in self._ir.iter_left() ) def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: @@ -71,21 +71,30 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: >>> nwd.len().meta.output_name() 'len' """ - ok_or_err = _expr_output_name(self.ir) + ok_or_err = _expr_output_name(self._ir) if isinstance(ok_or_err, ComputeError): if raise_if_undetermined: raise ok_or_err return None return ok_or_err - # NOTE: Less important for us, but maybe nice to have - def pop(self) -> list[ExprIR]: - raise NotImplementedError - def root_names(self) -> list[str]: + """After a lot of indirection, [root_names] resolves [here]. 
+ + [root_names]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L27-L30 + + [here]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/utils.rs#L171-L195 + """ raise NotImplementedError + # NOTE: Needs to know about `KeepName`, `RenameAlias` def undo_aliases(self) -> ExprIR: + """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L45-L53.""" + raise NotImplementedError + + # NOTE: Less important for us, but maybe nice to have + def pop(self) -> list[ExprIR]: + """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L14-L25.""" raise NotImplementedError @@ -98,7 +107,7 @@ def _expr_output_name(ir: ExprIR) -> str | ComputeError: return _expr_output_name(e.expr) if isinstance(e, (expr.Column, expr.Alias, expr.Literal, expr.Len)): return e.name - if isinstance(e, expr.All): + if isinstance(e, (expr.All, expr.KeepName, expr.RenameAlias)): msg = "cannot determine output column without a context for this expression" return ComputeError(msg) if isinstance(e, (expr.Columns, expr.IndexColumns, expr.Nth)): @@ -148,9 +157,7 @@ def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: ), ): return True - # TODO @dangotbanned: Add `KeepName`, `RenameAlias` here later (see `_plan.name`) - aliasing_types = (expr.Alias,) - if isinstance(ir, aliasing_types): + if isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)): return allow_aliasing return False diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 0930e54481..6fad09034e 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -2,71 +2,73 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Function -from narwhals._plan.options import FunctionOptions +from narwhals._plan.common import ExprIR, ExprIRNamespace, Immutable if TYPE_CHECKING: from narwhals._compliant.typing import AliasName -class NameFunction(Function): - """`polars` version [doesn't represent as `FunctionExpr`]. +class KeepName(ExprIR): + """Keep the original root name.""" - Also [doesn't support serialization]. + __slots__ = ("expr",) - [doesn't represent as `FunctionExpr`]: https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/dsl/name.rs - [doesn't support serialization]: https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr_dyn_fn.rs#L145-L151 - """ - - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() + expr: ExprIR def __repr__(self) -> str: - tp = type(self) - if tp is NameFunction: - return tp.__name__ - m: dict[type[NameFunction], str] = { - Keep: "keep", - Map: "map", - Suffix: "suffix", - Prefix: "prefix", - ToLowercase: "to_lowercase", - ToUppercase: "to_uppercase", - } - return f"name.{m[tp]}" - - -class Keep(NameFunction): - """Returns ``Expr::KeepName``.""" + return f"{self.expr!r}.name.keep()" -class Map(NameFunction): - """Returns ``Expr::RenameAlias``. 
- - https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/dsl/name.rs#L28-L38 - """ - - __slots__ = ("function",) +class RenameAlias(ExprIR): + __slots__ = ("expr", "function") + expr: ExprIR function: AliasName + def __repr__(self) -> str: + return f".rename_alias({self.expr!r})" -class Prefix(NameFunction): - """Each of these depend on `Map`.""" +class Prefix(Immutable): __slots__ = ("prefix",) prefix: str + def __call__(self, name: str, /) -> str: + return f"{self.prefix}{name}" -class Suffix(NameFunction): + +class Suffix(Immutable): __slots__ = ("suffix",) suffix: str + def __call__(self, name: str, /) -> str: + return f"{name}{self.suffix}" + + +class ExprIRNameNamespace(ExprIRNamespace): + """Specialized expressions for modifying the name of existing expressions.""" + + def keep(self) -> KeepName: + return KeepName(expr=self._ir) + + def map(self, function: AliasName) -> RenameAlias: + """Define an alias by mapping a function over the original root column name.""" + return RenameAlias(expr=self._ir, function=function) + + def prefix(self, prefix: str) -> RenameAlias: + """Add a prefix to the root column name.""" + return self.map(Prefix(prefix=prefix)) -class ToLowercase(NameFunction): ... + def suffix(self, suffix: str) -> RenameAlias: + """Add a suffix to the root column name.""" + return self.map(Suffix(suffix=suffix)) + def to_lowercase(self) -> RenameAlias: + """Update the root column name to use lowercase characters.""" + return self.map(str.lower) -class ToUppercase(NameFunction): ... + def to_uppercase(self) -> RenameAlias: + """Update the root column name to use uppercase characters.""" + return self.map(str.upper) From b9eb25e56be22407f8c3a68860efced6fe2419a7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 13:48:43 +0100 Subject: [PATCH 113/368] feat: Rewrite selectors, factor out `_BinaryOp` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🥳🥳🥳 - Has all the properties I'm wanting - Keeps the `Root` name from `polars` - Fixes the typing issues - Treats all `SelectorIR` objects the same for iteration (like in `rust`) --- narwhals/_plan/common.py | 16 ++++++++++- narwhals/_plan/dummy.py | 10 +++---- narwhals/_plan/expr.py | 54 +++++++++++++++++-------------------- narwhals/_plan/operators.py | 10 ++++--- narwhals/_plan/selectors.py | 8 +++--- narwhals/_plan/typing.py | 4 ++- 6 files changed, 56 insertions(+), 46 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 53c784985e..0bb27e588f 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -11,7 +11,12 @@ from typing_extensions import Never, Self, TypeAlias, TypeIs, dataclass_transform - from narwhals._plan.dummy import DummyCompliantExpr, DummyExpr, DummySeries + from narwhals._plan.dummy import ( + DummyCompliantExpr, + DummyExpr, + DummySelector, + DummySeries, + ) from narwhals._plan.expr import FunctionExpr from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedLiteral @@ -200,6 +205,15 @@ def iter_right(self) -> Iterator[ExprIR]: yield self +class SelectorIR(ExprIR): + def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: + from narwhals._plan import dummy + + if version is Version.MAIN: + return dummy.DummySelector._from_ir(self) + return dummy.DummySelectorV1._from_ir(self) + + class ExprIRNamespace(Immutable): __slots__ = ("_ir",) diff --git 
a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index fd11810b5c..87c308f672 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -472,30 +472,28 @@ def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] def _to_expr(self) -> DummyExpr: return self._ir.to_narwhals(self.version) - # TODO @dangotbanned: Make a decision on selector root, binary op - # Current typing warnings are accurate, this isn't valid yet def __or__(self, other: t.Any) -> Self | t.Any: if isinstance(other, type(self)): op = ops.Or() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type] + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() | other def __and__(self, other: t.Any) -> Self | t.Any: if isinstance(other, type(self)): op = ops.And() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type] + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() & other def __sub__(self, other: t.Any) -> Self | t.Any: if isinstance(other, type(self)): op = ops.Sub() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type] + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() - other def __xor__(self, other: t.Any) -> Self | t.Any: if isinstance(other, type(self)): op = ops.ExclusiveOr() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type] + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() ^ other def __invert__(self) -> Never: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 7fc6e61564..28994d0a33 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -13,23 +13,23 @@ import typing as t from narwhals._plan.aggregation import Agg, OrderableAgg -from narwhals._plan.common import ExprIR +from narwhals._plan.common import ExprIR, SelectorIR from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( FunctionT, + LeftSelectorT, LeftT, OperatorT, + RightSelectorT, RightT, RollingT, SelectorOperatorT, ) -from narwhals.utils import Version if t.TYPE_CHECKING: from typing_extensions import Self from narwhals._plan.common import Seq - from narwhals._plan.dummy import DummySelector from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions @@ -58,6 +58,7 @@ "OrderableAgg", "RenameAlias", "RollingExpr", + "RootSelector", "SelectorIR", "Sort", "SortBy", @@ -129,16 +130,7 @@ def __repr__(self) -> str: return f"lit({self.value!r})" -class BinaryExpr(ExprIR, t.Generic[LeftT, OperatorT, RightT]): - """Application of two exprs via an `Operator`. 
- - This ✅ - - https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/plans/aexpr/mod.rs#L152-L155 - - Not this ❌ - - https://github.com/pola-rs/polars/blob/da27decd9a1adabe0498b786585287eb730d1d91/crates/polars-plan/src/dsl/function_expr/mod.rs#L127 - """ - +class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): __slots__ = ("left", "op", "right") left: LeftT @@ -149,6 +141,12 @@ class BinaryExpr(ExprIR, t.Generic[LeftT, OperatorT, RightT]): def is_scalar(self) -> bool: return self.left.is_scalar and self.right.is_scalar + +class BinaryExpr( + _BinaryOp[LeftT, OperatorT, RightT], t.Generic[LeftT, OperatorT, RightT] +): + """Application of two exprs via an `Operator`.""" + def __repr__(self) -> str: return f"[({self.left!r}) {self.op!r} ({self.right!r})]" @@ -448,31 +446,27 @@ def __repr__(self) -> str: return "*" -class SelectorIR(ExprIR): - """Not sure on this separation. - - - Need a cleaner way of including `BinarySelector`. - - Like that there's easy access to operands - - Dislike that it inherits node iteration, since upstream doesn't use it for selectors - """ +# TODO @dangotbanned: reprs +class RootSelector(SelectorIR): + """A single selector expression.""" __slots__ = ("selector",) selector: Selector """by_dtype, matches, numeric, boolean, string, categorical, datetime, all.""" - def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: - from narwhals._plan import dummy - - if version is Version.MAIN: - return dummy.DummySelector._from_ir(self) - return dummy.DummySelectorV1._from_ir(self) - +# TODO @dangotbanned: reprs class BinarySelector( - BinaryExpr["SelectorIR", SelectorOperatorT, "SelectorIR"], - t.Generic[SelectorOperatorT], -): ... + _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], + SelectorIR, + t.Generic[LeftSelectorT, SelectorOperatorT, RightSelectorT], +): + """Application of two selector exprs via a set operator. + + Note: + `left` and `right` may also nest other `BinarySelector`s. 
+ """ class Ternary(ExprIR): diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index aacafa8bf9..5d43c4b514 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -2,11 +2,13 @@ from typing import TYPE_CHECKING +from narwhals._plan.expr import BinarySelector + if TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.expr import BinaryExpr, BinarySelector, SelectorIR - from narwhals._plan.typing import LeftT, RightT + from narwhals._plan.expr import BinaryExpr, BinarySelector + from narwhals._plan.typing import LeftSelectorT, LeftT, RightSelectorT, RightT from narwhals._plan.common import Immutable @@ -52,8 +54,8 @@ class SelectorOperator(Operator): """ def to_binary_selector( - self, left: SelectorIR, right: SelectorIR, / - ) -> BinarySelector[Self]: + self, left: LeftSelectorT, right: RightSelectorT, / + ) -> BinarySelector[LeftSelectorT, Self, RightSelectorT]: from narwhals._plan.expr import BinarySelector return BinarySelector(left=left, op=self, right=right) diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 580f63d05e..b55c1731f5 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -17,7 +17,7 @@ from typing import Iterator, TypeVar from narwhals._plan.dummy import DummySelector - from narwhals._plan.expr import SelectorIR + from narwhals._plan.expr import RootSelector from narwhals.dtypes import DType from narwhals.typing import TimeUnit @@ -25,10 +25,10 @@ class Selector(Immutable): - def to_selector(self) -> SelectorIR: - from narwhals._plan.expr import SelectorIR + def to_selector(self) -> RootSelector: + from narwhals._plan.expr import RootSelector - return SelectorIR(selector=self) + return RootSelector(selector=self) class All(Selector): ... diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 337bd48ab8..1878c94268 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -6,7 +6,7 @@ if t.TYPE_CHECKING: from narwhals._plan import operators as ops - from narwhals._plan.common import ExprIR, Function + from narwhals._plan.common import ExprIR, Function, SelectorIR from narwhals._plan.functions import RollingWindow __all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT", "SelectorOperatorT"] @@ -18,6 +18,8 @@ OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") +LeftSelectorT = TypeVar("LeftSelectorT", bound="SelectorIR", default="SelectorIR") +RightSelectorT = TypeVar("RightSelectorT", bound="SelectorIR", default="SelectorIR") SelectorOperatorT = TypeVar( "SelectorOperatorT", bound="ops.SelectorOperator", default="ops.SelectorOperator" ) From 233c67087a7420890e3de19562d35142833966ad Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 14:58:58 +0100 Subject: [PATCH 114/368] feat: Add selectors reprs --- narwhals/_plan/dummy.py | 13 +++++++++++++ narwhals/_plan/expr.py | 11 ++++++----- narwhals/_plan/selectors.py | 32 +++++++++++++++++++++++++++----- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 87c308f672..6d5990bb41 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -461,8 +461,21 @@ def name(self) -> ExprIRNameNamespace: class DummySelector(DummyExpr): + """Selectors placeholder. 
+ + Examples: + >>> from narwhals._plan import selectors as ncs + >>> + >>> (ncs.matches("[^z]a") & ncs.string()) | ncs.datetime("us", None) + Narwhals DummySelector (main): + [([(ncs.matches(pattern='[^z]a')) & (ncs.string())]) | (ncs.datetime(time_unit=['us'], time_zone=[None]))] + """ + _ir: expr.SelectorIR + def __repr__(self) -> str: + return f"Narwhals DummySelector ({self.version.name.lower()}):\n{self._ir!r}" + @classmethod def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] obj = cls.__new__(cls) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 28994d0a33..fb0e179b63 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -141,15 +141,15 @@ class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): def is_scalar(self) -> bool: return self.left.is_scalar and self.right.is_scalar + def __repr__(self) -> str: + return f"[({self.left!r}) {self.op!r} ({self.right!r})]" + class BinaryExpr( _BinaryOp[LeftT, OperatorT, RightT], t.Generic[LeftT, OperatorT, RightT] ): """Application of two exprs via an `Operator`.""" - def __repr__(self) -> str: - return f"[({self.left!r}) {self.op!r} ({self.right!r})]" - def iter_left(self) -> t.Iterator[ExprIR]: yield from self.left.iter_left() yield from self.right.iter_left() @@ -446,7 +446,6 @@ def __repr__(self) -> str: return "*" -# TODO @dangotbanned: reprs class RootSelector(SelectorIR): """A single selector expression.""" @@ -455,8 +454,10 @@ class RootSelector(SelectorIR): selector: Selector """by_dtype, matches, numeric, boolean, string, categorical, datetime, all.""" + def __repr__(self) -> str: + return f"{self.selector!r}" + -# TODO @dangotbanned: reprs class BinarySelector( _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], SelectorIR, diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index b55c1731f5..f9ef444b60 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -31,7 +31,9 @@ def to_selector(self) -> RootSelector: return RootSelector(selector=self) -class All(Selector): ... +class All(Selector): + def __repr__(self) -> str: + return "ncs.all()" class ByDType(Selector): @@ -45,11 +47,21 @@ def from_dtypes( ) -> ByDType: return ByDType(dtypes=frozenset(_flatten_hash_safe(dtypes))) + def __repr__(self) -> str: + els = ", ".join( + tp.__name__ if isinstance(tp, type) else repr(tp) for tp in self.dtypes + ) + return f"ncs.by_dtype(dtypes=[{els}])" -class Boolean(Selector): ... +class Boolean(Selector): + def __repr__(self) -> str: + return "ncs.boolean()" -class Categorical(Selector): ... + +class Categorical(Selector): + def __repr__(self) -> str: + return "ncs.categorical()" class Datetime(Selector): @@ -74,6 +86,9 @@ def from_time_unit_and_time_zone( units, zones = _parse_time_unit_and_time_zone(time_unit, time_zone) return Datetime(time_units=frozenset(units), time_zones=frozenset(zones)) + def __repr__(self) -> str: + return f"ncs.datetime(time_unit={list(self.time_units)}, time_zone={list(self.time_zones)})" + class Matches(Selector): __slots__ = ("pattern",) @@ -84,11 +99,18 @@ class Matches(Selector): def from_string(pattern: str, /) -> Matches: return Matches(pattern=re.compile(pattern)) + def __repr__(self) -> str: + return f"ncs.matches(pattern={self.pattern.pattern!r})" + -class Numeric(Selector): ... +class Numeric(Selector): + def __repr__(self) -> str: + return "ncs.numeric()" -class String(Selector): ... 
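# A short illustration of what these __repr__s compose into at the expression
# level, assuming the module alias used in the DummySelector docstring above
# (`from narwhals._plan import selectors as ncs`); the rendered string below
# follows that docstring example:
from narwhals._plan import selectors as ncs

sel = (ncs.matches("[^z]a") & ncs.string()) | ncs.datetime("us", None)
# repr(sel) shows the plan, roughly:
#   [([(ncs.matches(pattern='[^z]a')) & (ncs.string())]) | (ncs.datetime(time_unit=['us'], time_zone=[None]))]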
+class String(Selector): + def __repr__(self) -> str: + return "ncs.string()" def all() -> DummySelector: From 756480225c8e6cc0e98fed75f0254dd2761db37a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 15:39:03 +0100 Subject: [PATCH 115/368] feat: Expressify `over`, `sort_by` --- narwhals/_plan/dummy.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 6d5990bb41..c638a76b44 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -24,7 +24,7 @@ from narwhals._plan.window import Over from narwhals.dtypes import DType from narwhals.exceptions import ComputeError -from narwhals.utils import Version, _hasattr_static, flatten +from narwhals.utils import Version, _hasattr_static if TYPE_CHECKING: from typing_extensions import Never, Self @@ -110,18 +110,20 @@ def quantile( def over( self, - *partition_by: DummyExpr | t.Iterable[DummyExpr], - order_by: DummyExpr | t.Iterable[DummyExpr] | None = None, + *partition_by: IntoExpr | t.Iterable[IntoExpr], + order_by: IntoExpr | t.Iterable[IntoExpr] = None, descending: bool = False, nulls_last: bool = False, ) -> Self: + partition: Seq[ExprIR] = () order: tuple[Seq[ExprIR], SortOptions] | None = None - partition = tuple(expr._ir for expr in flatten(partition_by)) - if not (partition) and order_by is None: + if not (partition_by) and order_by is None: msg = "At least one of `partition_by` or `order_by` must be specified." raise TypeError(msg) + if partition_by: + partition = parse.parse_into_seq_of_expr_ir(*partition_by) if order_by is not None: - by = tuple(expr._ir for expr in flatten([order_by])) + by = parse.parse_into_seq_of_expr_ir(order_by) options = SortOptions(descending=descending, nulls_last=nulls_last) order = by, options return self._from_ir(Over().to_window_expr(self._ir, partition, order)) @@ -132,16 +134,12 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: def sort_by( self, - by: DummyExpr | t.Iterable[DummyExpr], - *more_by: DummyExpr, + by: IntoExpr | t.Iterable[IntoExpr], + *more_by: IntoExpr, descending: bool | t.Iterable[bool] = False, nulls_last: bool | t.Iterable[bool] = False, ) -> Self: - if more_by: - by = (by, *more_by) if isinstance(by, DummyExpr) else (*by, *more_by) - else: - by = (by,) if isinstance(by, DummyExpr) else tuple(by) - sort_by = tuple(key._ir for key in by) + sort_by = parse.parse_into_seq_of_expr_ir(by, *more_by) desc = (descending,) if isinstance(descending, bool) else tuple(descending) nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) options = SortMultipleOptions(descending=desc, nulls_last=nulls) From 4808f2da12c87d7cf5bc21c03838d674f5227854 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 15:47:03 +0100 Subject: [PATCH 116/368] feat: Expressify `concat_str` --- narwhals/_plan/demo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 85e64da2ea..afc54b2a3c 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -118,12 +118,12 @@ def mean_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: def concat_str( - exprs: DummyExpr | t.Iterable[DummyExpr], - *more_exprs: DummyExpr, + exprs: IntoExpr | t.Iterable[IntoExpr], + *more_exprs: IntoExpr, separator: str = "", ignore_nulls: bool = False, ) -> DummyExpr: - 
it = (expr._ir for expr in flatten([*flatten([exprs]), *more_exprs])) + it = parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) return ( ConcatHorizontal(separator=separator, ignore_nulls=ignore_nulls) .to_function_expr(*it) From ffb0e39f67ab18ab1170fea3999e91a4403270eb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 19:33:32 +0100 Subject: [PATCH 117/368] feat: Implement `Expr.meta.root_names` - I think I'm gonna skip the other two, due to complexity - Locally I'm getting `root_names` to match polars --- narwhals/_plan/common.py | 7 ++++ narwhals/_plan/meta.py | 72 +++++++++++++++++++++++++++++++++++----- tests/plan/meta_test.py | 49 +++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 9 deletions(-) create mode 100644 tests/plan/meta_test.py diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 0bb27e588f..3319be31a0 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -18,6 +18,7 @@ DummySeries, ) from narwhals._plan.expr import FunctionExpr + from narwhals._plan.meta import ExprIRMetaNamespace from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedLiteral @@ -204,6 +205,12 @@ def iter_right(self) -> Iterator[ExprIR]: """ yield self + @property + def meta(self) -> ExprIRMetaNamespace: + from narwhals._plan.meta import ExprIRMetaNamespace + + return ExprIRMetaNamespace(_ir=self) + class SelectorIR(ExprIR): def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 4027969c87..216c4c919e 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -15,7 +15,7 @@ from narwhals.utils import Version if TYPE_CHECKING: - from typing import Any + from typing import Any, Iterator import polars as pl @@ -79,17 +79,27 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: return ok_or_err def root_names(self) -> list[str]: - """After a lot of indirection, [root_names] resolves [here]. + """Get the root column names.""" + return _expr_to_leaf_column_names(self._ir) - [root_names]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L27-L30 + # NOTE: Seems too complex to do whilst keeping things immutable + def undo_aliases(self) -> ExprIR: + """Investigate components. 
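+
+        A rough sketch of the rewrite involved, assuming a hypothetical
+        ``rewrite_children`` helper (nothing below is implemented here):
+
+            def undo_aliases(ir: ExprIR) -> ExprIR:
+                # Peel off any aliasing wrapper and keep rewriting its child.
+                while isinstance(ir, (Alias, KeepName, RenameAlias)):
+                    ir = ir.expr
+                # Hypothetical: rebuild the node from recursively rewritten children.
+                return rewrite_children(ir, undo_aliases)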
- [here]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/utils.rs#L171-L195 - """ - raise NotImplementedError + Seems like it unnests each of these: + - `Alias.expr` + - `KeepName.expr` + - `RenameAlias.expr` - # NOTE: Needs to know about `KeepName`, `RenameAlias` - def undo_aliases(self) -> ExprIR: - """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L45-L53.""" + Notes: + - [`meta.undo_aliases`] + - [`Expr.map_expr`] + - [`TreeWalker.rewrite`] + + [`Expr.map_expr`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/plans/iterator.rs#L146-L149 + [`meta.undo_aliases`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L45-L53 + [`TreeWalker.rewrite`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/plans/visitor/visitors.rs#L46-L68 + """ raise NotImplementedError # NOTE: Less important for us, but maybe nice to have @@ -98,6 +108,50 @@ def pop(self) -> list[ExprIR]: raise NotImplementedError +def _expr_to_leaf_column_names(ir: ExprIR) -> list[str]: + """After a lot of indirection, [root_names] resolves [here]. + + [root_names]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L27-L30 + [here]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/utils.rs#L171-L195 + """ + return list(_expr_to_leaf_column_names_iter(ir)) + + +def _expr_to_leaf_column_names_iter(ir: ExprIR) -> Iterator[str]: + for e in _expr_to_leaf_column_exprs_iter(ir): + result = _expr_to_leaf_column_name(e) + if isinstance(result, str): + yield result + + +def _expr_to_leaf_column_exprs_iter(ir: ExprIR) -> Iterator[ExprIR]: + from narwhals._plan import expr + + for outer in ir.iter_left(): + if isinstance(outer, (expr.Column, expr.All)): + yield outer + + +def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: + leaves = list(_expr_to_leaf_column_exprs_iter(ir)) + if not len(leaves) <= 1: + msg = "found more than one root column name" + return ComputeError(msg) + if not leaves: + msg = "no root column name found" + return ComputeError(msg) + leaf = leaves[0] + from narwhals._plan import expr + + if isinstance(leaf, expr.Column): + return leaf.name + if isinstance(leaf, expr.All): + msg = "wildcard has no root column name" + return ComputeError(msg) + msg = f"Expected unreachable, got {type(leaf).__name__!r}\n\n{leaf}" + return ComputeError(msg) + + def _expr_output_name(ir: ExprIR) -> str | ComputeError: from narwhals._plan import expr diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py new file mode 100644 index 0000000000..7b5a19e0ae --- /dev/null +++ b/tests/plan/meta_test.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan.demo as nwd + +if TYPE_CHECKING: + from narwhals._plan.dummy import DummyExpr + +pytest.importorskip("polars") +import polars as pl + + +@pytest.mark.parametrize( + ("nw_expr", "pl_expr", "expected"), + [ + ( + nwd.col("a").alias("b").min().alias("c").alias("d"), + pl.col("a").alias("b").min().alias("c").alias("d"), + ["a"], + ), + ( + (nwd.col("a") + (nwd.col("a") - nwd.col("b"))).alias("c"), + (pl.col("a") + (pl.col("a") - pl.col("b"))).alias("c"), + ["a", "a", "b"], + ), + ( + 
nwd.col("a").last().over("b", order_by="c"), + pl.col("a").last().over("b", order_by="c"), + ["a", "b"], + ), + ( + (nwd.col("a", "b", "c").sort().abs() * 20).max(), + (pl.col("a", "b", "c").sort().abs() * 20).max(), + [], + ), + (nwd.all().mean(), pl.all().mean(), []), + (nwd.all().mean().sort_by("d"), pl.all().mean().sort_by("d"), ["d"]), + ], +) +def test_meta_root_names( + nw_expr: DummyExpr, pl_expr: pl.Expr, expected: list[str] +) -> None: + pl_result = pl_expr.meta.root_names() + nw_result = nw_expr.meta.root_names() + assert nw_result == expected + assert nw_result == pl_result From f295d67ccb33ac2f87554000ee3b0f2000c1ac7c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 20:03:02 +0100 Subject: [PATCH 118/368] test: polars backcompat https://github.com/narwhals-dev/narwhals/actions/runs/15216876843/job/42804317065?pr=2572 --- tests/plan/meta_test.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 7b5a19e0ae..69042598fe 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -5,6 +5,7 @@ import pytest import narwhals._plan.demo as nwd +from tests.utils import POLARS_VERSION if TYPE_CHECKING: from narwhals._plan.dummy import DummyExpr @@ -12,6 +13,16 @@ pytest.importorskip("polars") import polars as pl +if POLARS_VERSION >= (1, 0): + # https://github.com/pola-rs/polars/pull/16743 + OVER_CASE = ( + nwd.col("a").last().over("b", order_by="c"), + pl.col("a").last().over("b", order_by="c"), + ["a", "b"], + ) +else: + OVER_CASE = (nwd.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) + @pytest.mark.parametrize( ("nw_expr", "pl_expr", "expected"), @@ -26,11 +37,7 @@ (pl.col("a") + (pl.col("a") - pl.col("b"))).alias("c"), ["a", "a", "b"], ), - ( - nwd.col("a").last().over("b", order_by="c"), - pl.col("a").last().over("b", order_by="c"), - ["a", "b"], - ), + OVER_CASE, ( (nwd.col("a", "b", "c").sort().abs() * 20).max(), (pl.col("a", "b", "c").sort().abs() * 20).max(), From 3e4449305e05dd1d3dc2d0d26d946bfe97eebae5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 23 May 2025 20:09:52 +0100 Subject: [PATCH 119/368] cov https://github.com/narwhals-dev/narwhals/actions/runs/15217344754/job/42805778043?pr=2572 --- tests/plan/meta_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 69042598fe..e4e0a41be4 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -20,7 +20,7 @@ pl.col("a").last().over("b", order_by="c"), ["a", "b"], ) -else: +else: # pragma: no cover OVER_CASE = (nwd.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) From 737c83a67a85c96acf663bdfddac488b68812ea7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 24 May 2025 18:02:13 +0100 Subject: [PATCH 120/368] feat: Split `IR` and `Expr` namespaces - The current version of `IRMetaNamespace` would be identical for both - The methods that return `ExprIR` aren't implemented - Need to repeat for the other namespaces --- narwhals/_plan/common.py | 30 ++++++++++++++++++++++++------ narwhals/_plan/dummy.py | 27 ++++++++++++++++++--------- narwhals/_plan/meta.py | 4 ++-- narwhals/_plan/name.py | 36 ++++++++++++++++++++++++++++++++---- narwhals/_plan/typing.py | 3 ++- 5 files changed, 78 insertions(+), 22 deletions(-) diff --git 
a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 3319be31a0..57e94b0a48 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -2,8 +2,9 @@ import datetime as dt from decimal import Decimal -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._plan.typing import IRNamespaceT from narwhals.utils import Version if TYPE_CHECKING: @@ -18,7 +19,7 @@ DummySeries, ) from narwhals._plan.expr import FunctionExpr - from narwhals._plan.meta import ExprIRMetaNamespace + from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedLiteral @@ -206,10 +207,10 @@ def iter_right(self) -> Iterator[ExprIR]: yield self @property - def meta(self) -> ExprIRMetaNamespace: - from narwhals._plan.meta import ExprIRMetaNamespace + def meta(self) -> IRMetaNamespace: + from narwhals._plan.meta import IRMetaNamespace - return ExprIRMetaNamespace(_ir=self) + return IRMetaNamespace(_ir=self) class SelectorIR(ExprIR): @@ -221,7 +222,7 @@ def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: return dummy.DummySelectorV1._from_ir(self) -class ExprIRNamespace(Immutable): +class IRNamespace(Immutable): __slots__ = ("_ir",) _ir: ExprIR @@ -231,6 +232,23 @@ def from_expr(cls, expr: DummyExpr, /) -> Self: return cls(_ir=expr._ir) +class ExprNamespace(Immutable, Generic[IRNamespaceT]): + __slots__ = ("_expr",) + + _expr: DummyExpr + + @property + def _ir_namespace(self) -> type[IRNamespaceT]: + raise NotImplementedError + + @property + def _ir(self) -> IRNamespaceT: + return self._ir_namespace.from_expr(self._expr) + + def _to_narwhals(self, ir: ExprIR, /) -> DummyExpr: + return self._expr._from_ir(ir) + + class Function(Immutable): """Shared by expr functions and namespace functions. diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index c638a76b44..ac39a103f4 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -30,8 +30,8 @@ from typing_extensions import Never, Self from narwhals._plan.common import ExprIR, IntoExpr, IntoExprColumn, Seq, Udf - from narwhals._plan.meta import ExprIRMetaNamespace - from narwhals._plan.name import ExprIRNameNamespace + from narwhals._plan.meta import IRMetaNamespace + from narwhals._plan.name import ExprNameNamespace from narwhals.typing import ( FillNullStrategy, NativeSeries, @@ -446,16 +446,25 @@ def __pow__(self, other: IntoExpr) -> Self: return self._from_ir(F.Pow().to_function_expr(base, exponent)) @property - def meta(self) -> ExprIRMetaNamespace: - from narwhals._plan.meta import ExprIRMetaNamespace + def meta(self) -> IRMetaNamespace: + from narwhals._plan.meta import IRMetaNamespace - return ExprIRMetaNamespace.from_expr(self) + return IRMetaNamespace.from_expr(self) @property - def name(self) -> ExprIRNameNamespace: - from narwhals._plan.name import ExprIRNameNamespace - - return ExprIRNameNamespace.from_expr(self) + def name(self) -> ExprNameNamespace: + """Specialized expressions for modifying the name of existing expressions. 
+ + Examples: + >>> from narwhals._plan import demo as nw + >>> + >>> renamed = nw.col("a", "b").name.suffix("_changed") + >>> str(renamed._ir) + "RenameAlias(expr=Columns(names=[a, b]), function=Suffix(suffix='_changed'))" + """ + from narwhals._plan.name import ExprNameNamespace + + return ExprNameNamespace(_expr=self) class DummySelector(DummyExpr): diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 216c4c919e..9454bd758d 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIRNamespace +from narwhals._plan.common import IRNamespace from narwhals.exceptions import ComputeError from narwhals.utils import Version @@ -22,7 +22,7 @@ from narwhals._plan.common import ExprIR -class ExprIRMetaNamespace(ExprIRNamespace): +class IRMetaNamespace(IRNamespace): """Methods to modify and traverse existing expressions.""" def has_multiple_outputs(self) -> bool: diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 6fad09034e..d382c49292 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIR, ExprIRNamespace, Immutable +from narwhals._plan.common import ExprIR, ExprNamespace, Immutable, IRNamespace if TYPE_CHECKING: from narwhals._compliant.typing import AliasName + from narwhals._plan.dummy import DummyExpr class KeepName(ExprIR): @@ -47,9 +48,7 @@ def __call__(self, name: str, /) -> str: return f"{name}{self.suffix}" -class ExprIRNameNamespace(ExprIRNamespace): - """Specialized expressions for modifying the name of existing expressions.""" - +class IRNameNamespace(IRNamespace): def keep(self) -> KeepName: return KeepName(expr=self._ir) @@ -72,3 +71,32 @@ def to_lowercase(self) -> RenameAlias: def to_uppercase(self) -> RenameAlias: """Update the root column name to use uppercase characters.""" return self.map(str.upper) + + +class ExprNameNamespace(ExprNamespace[IRNameNamespace]): + @property + def _ir_namespace(self) -> type[IRNameNamespace]: + return IRNameNamespace + + def keep(self) -> DummyExpr: + return self._to_narwhals(self._ir.keep()) + + def map(self, function: AliasName) -> DummyExpr: + """Define an alias by mapping a function over the original root column name.""" + return self._to_narwhals(self._ir.map(function)) + + def prefix(self, prefix: str) -> DummyExpr: + """Add a prefix to the root column name.""" + return self._to_narwhals(self._ir.prefix(prefix)) + + def suffix(self, suffix: str) -> DummyExpr: + """Add a suffix to the root column name.""" + return self._to_narwhals(self._ir.suffix(suffix)) + + def to_lowercase(self) -> DummyExpr: + """Update the root column name to use lowercase characters.""" + return self._to_narwhals(self._ir.to_lowercase()) + + def to_uppercase(self) -> DummyExpr: + """Update the root column name to use uppercase characters.""" + return self._to_narwhals(self._ir.to_uppercase()) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 1878c94268..add10565cb 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -6,7 +6,7 @@ if t.TYPE_CHECKING: from narwhals._plan import operators as ops - from narwhals._plan.common import ExprIR, Function, SelectorIR + from narwhals._plan.common import ExprIR, Function, IRNamespace, SelectorIR from narwhals._plan.functions import RollingWindow __all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT", "SelectorOperatorT"] @@ -23,3 +23,4 @@ SelectorOperatorT 
= TypeVar( "SelectorOperatorT", bound="ops.SelectorOperator", default="ops.SelectorOperator" ) +IRNamespaceT = TypeVar("IRNamespaceT", bound="IRNamespace") From a3e29d1c8da03ba078628772cf0d562390a335f0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 24 May 2025 18:25:24 +0100 Subject: [PATCH 121/368] feat: Add `cat`, `struct`, `list` namespaces to `DummyExpr` --- narwhals/_plan/categorical.py | 23 ++++++++++++++++++++++- narwhals/_plan/common.py | 7 +++++++ narwhals/_plan/dummy.py | 21 +++++++++++++++++++++ narwhals/_plan/lists.py | 21 ++++++++++++++++++++- narwhals/_plan/struct.py | 21 ++++++++++++++++++++- 5 files changed, 90 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 44c2c7023c..47e045d926 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -1,8 +1,13 @@ from __future__ import annotations -from narwhals._plan.common import Function +from typing import TYPE_CHECKING + +from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionOptions +if TYPE_CHECKING: + from narwhals._plan.dummy import DummyExpr + class CategoricalFunction(Function): ... @@ -17,3 +22,19 @@ def function_options(self) -> FunctionOptions: def __repr__(self) -> str: return "cat.get_categories" + + +class IRCatNamespace(IRNamespace): + def get_categories(self) -> GetCategories: + return GetCategories() + + +class ExprCatNamespace(ExprNamespace[IRCatNamespace]): + @property + def _ir_namespace(self) -> type[IRCatNamespace]: + return IRCatNamespace + + def get_categories(self) -> DummyExpr: + return self._to_narwhals( + self._ir.get_categories().to_function_expr(self._expr._ir) + ) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 57e94b0a48..0c6b575733 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -19,6 +19,7 @@ DummySeries, ) from narwhals._plan.expr import FunctionExpr + from narwhals._plan.lists import IRListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedLiteral @@ -212,6 +213,12 @@ def meta(self) -> IRMetaNamespace: return IRMetaNamespace(_ir=self) + @property + def list(self) -> IRListNamespace: + from narwhals._plan.lists import IRListNamespace + + return IRListNamespace(_ir=self) + class SelectorIR(ExprIR): def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index ac39a103f4..b7eac25457 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -29,9 +29,12 @@ if TYPE_CHECKING: from typing_extensions import Never, Self + from narwhals._plan.categorical import ExprCatNamespace from narwhals._plan.common import ExprIR, IntoExpr, IntoExprColumn, Seq, Udf + from narwhals._plan.lists import ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace + from narwhals._plan.struct import ExprStructNamespace from narwhals.typing import ( FillNullStrategy, NativeSeries, @@ -466,6 +469,24 @@ def name(self) -> ExprNameNamespace: return ExprNameNamespace(_expr=self) + @property + def cat(self) -> ExprCatNamespace: + from narwhals._plan.categorical import ExprCatNamespace + + return ExprCatNamespace(_expr=self) + + @property + def struct(self) -> ExprStructNamespace: + from narwhals._plan.struct import ExprStructNamespace + + return 
ExprStructNamespace(_expr=self) + + @property + def list(self) -> ExprListNamespace: + from narwhals._plan.lists import ExprListNamespace + + return ExprListNamespace(_expr=self) + class DummySelector(DummyExpr): """Selectors placeholder. diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index 88da2223ab..f64b98235c 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -1,8 +1,13 @@ from __future__ import annotations -from narwhals._plan.common import Function +from typing import TYPE_CHECKING + +from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionOptions +if TYPE_CHECKING: + from narwhals._plan.dummy import DummyExpr + class ListFunction(Function): ... @@ -16,3 +21,17 @@ def function_options(self) -> FunctionOptions: def __repr__(self) -> str: return "list.len" + + +class IRListNamespace(IRNamespace): + def len(self) -> Len: + return Len() + + +class ExprListNamespace(ExprNamespace[IRListNamespace]): + @property + def _ir_namespace(self) -> type[IRListNamespace]: + return IRListNamespace + + def len(self) -> DummyExpr: + return self._to_narwhals(self._ir.len().to_function_expr(self._expr._ir)) diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index fa7dd3dc07..6ad3770cf4 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -1,8 +1,13 @@ from __future__ import annotations -from narwhals._plan.common import Function +from typing import TYPE_CHECKING + +from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionOptions +if TYPE_CHECKING: + from narwhals._plan.dummy import DummyExpr + class StructFunction(Function): ... @@ -20,3 +25,17 @@ def function_options(self) -> FunctionOptions: def __repr__(self) -> str: return f"struct.field_by_name({self.name!r})" + + +class IRStructNamespace(IRNamespace): + def field(self, name: str) -> FieldByName: + return FieldByName(name=name) + + +class ExprStructNamespace(ExprNamespace[IRStructNamespace]): + @property + def _ir_namespace(self) -> type[IRStructNamespace]: + return IRStructNamespace + + def field(self, name: str) -> DummyExpr: + return self._to_narwhals(self._ir.field(name).to_function_expr(self._expr._ir)) From aee0a7e8976a1de0056b58881764d70084f1ab9e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 25 May 2025 14:21:50 +0100 Subject: [PATCH 122/368] feat: Add `dt` namespace --- narwhals/_plan/dummy.py | 7 ++ narwhals/_plan/temporal.py | 204 ++++++++++++++++++++++++++++++++++++- 2 files changed, 206 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index b7eac25457..7cf73cc3a6 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -35,6 +35,7 @@ from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace from narwhals._plan.struct import ExprStructNamespace + from narwhals._plan.temporal import ExprDateTimeNamespace from narwhals.typing import ( FillNullStrategy, NativeSeries, @@ -481,6 +482,12 @@ def struct(self) -> ExprStructNamespace: return ExprStructNamespace(_expr=self) + @property + def dt(self) -> ExprDateTimeNamespace: + from narwhals._plan.temporal import ExprDateTimeNamespace + + return ExprDateTimeNamespace(_expr=self) + @property def list(self) -> ExprListNamespace: from narwhals._plan.lists import ExprListNamespace diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 
a40ad2f63e..d0b4a2cd1c 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,13 +1,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, Literal, cast -from narwhals._plan.common import Function +from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: + from typing_extensions import TypeAlias, TypeIs + + from narwhals._duration import IntervalUnit + from narwhals._plan.dummy import DummyExpr from narwhals.typing import TimeUnit +PolarsTimeUnit: TypeAlias = Literal["ns", "us", "ms"] + + +def _is_polars_time_unit(obj: Any) -> TypeIs[PolarsTimeUnit]: + return obj in {"ns", "us", "ms"} + class TemporalFunction(Function): @property @@ -119,10 +129,194 @@ class ConvertTimeZone(TemporalFunction): class Timestamp(TemporalFunction): __slots__ = ("time_unit",) - time_unit: TimeUnit + time_unit: PolarsTimeUnit + + @staticmethod + def from_time_unit(time_unit: TimeUnit, /) -> Timestamp: + if not _is_polars_time_unit(time_unit): + from typing import get_args + + msg = ( + "invalid `time_unit`" + f"\n\nExpected one of {get_args(PolarsTimeUnit)}, got {time_unit!r}." + ) + raise ValueError(msg) + return Timestamp(time_unit=time_unit) class Truncate(TemporalFunction): - __slots__ = ("every",) + __slots__ = ("multiple", "unit") + + multiple: int + unit: IntervalUnit + + @staticmethod + def from_string(every: str, /) -> Truncate: + from narwhals._duration import parse_interval_string + + multiple, unit = parse_interval_string(every) + return Truncate(multiple=multiple, unit=unit) + + +class IRDateTimeNamespace(IRNamespace): + def date(self) -> Date: + return Date() + + def year(self) -> Year: + return Year() + + def month(self) -> Month: + return Month() + + def day(self) -> Day: + return Day() + + def hour(self) -> Hour: + return Hour() + + def minute(self) -> Minute: + return Minute() + + def second(self) -> Second: + return Second() - every: str + def millisecond(self) -> Millisecond: + return Millisecond() + + def microsecond(self) -> Microsecond: + return Microsecond() + + def nanosecond(self) -> Nanosecond: + return Nanosecond() + + def ordinal_day(self) -> OrdinalDay: + return OrdinalDay() + + def weekday(self) -> WeekDay: + return WeekDay() + + def total_minutes(self) -> TotalMinutes: + return TotalMinutes() + + def total_seconds(self) -> TotalSeconds: + return TotalSeconds() + + def total_milliseconds(self) -> TotalMilliseconds: + return TotalMilliseconds() + + def total_microseconds(self) -> TotalMicroseconds: + return TotalMicroseconds() + + def total_nanoseconds(self) -> TotalNanoseconds: + return TotalNanoseconds() + + def to_string(self, format: str) -> ToString: + return ToString(format=format) + + def replace_time_zone(self, time_zone: str | None) -> ReplaceTimeZone: + return ReplaceTimeZone(time_zone=time_zone) + + def convert_time_zone(self, time_zone: str) -> ConvertTimeZone: + return ConvertTimeZone(time_zone=time_zone) + + def timestamp(self, time_unit: TimeUnit = "us") -> Timestamp: + return Timestamp.from_time_unit(time_unit) + + def truncate(self, every: str) -> Truncate: + return Truncate.from_string(every) + + +class ExprDateTimeNamespace(ExprNamespace[IRDateTimeNamespace]): + @property + def _ir_namespace(self) -> type[IRDateTimeNamespace]: + return IRDateTimeNamespace + + def date(self) -> DummyExpr: + return self._to_narwhals(self._ir.date().to_function_expr(self._expr._ir)) + + def year(self) -> DummyExpr: 
+ return self._to_narwhals(self._ir.year().to_function_expr(self._expr._ir)) + + def month(self) -> DummyExpr: + return self._to_narwhals(self._ir.month().to_function_expr(self._expr._ir)) + + def day(self) -> DummyExpr: + return self._to_narwhals(self._ir.day().to_function_expr(self._expr._ir)) + + def hour(self) -> DummyExpr: + return self._to_narwhals(self._ir.hour().to_function_expr(self._expr._ir)) + + def minute(self) -> DummyExpr: + return self._to_narwhals(self._ir.minute().to_function_expr(self._expr._ir)) + + def second(self) -> DummyExpr: + return self._to_narwhals(self._ir.second().to_function_expr(self._expr._ir)) + + def millisecond(self) -> DummyExpr: + return self._to_narwhals(self._ir.millisecond().to_function_expr(self._expr._ir)) + + def microsecond(self) -> DummyExpr: + return self._to_narwhals(self._ir.microsecond().to_function_expr(self._expr._ir)) + + def nanosecond(self) -> DummyExpr: + return self._to_narwhals(self._ir.nanosecond().to_function_expr(self._expr._ir)) + + def ordinal_day(self) -> DummyExpr: + return self._to_narwhals(self._ir.ordinal_day().to_function_expr(self._expr._ir)) + + def weekday(self) -> DummyExpr: + return self._to_narwhals(self._ir.weekday().to_function_expr(self._expr._ir)) + + def total_minutes(self) -> DummyExpr: + return self._to_narwhals( + self._ir.total_minutes().to_function_expr(self._expr._ir) + ) + + def total_seconds(self) -> DummyExpr: + return self._to_narwhals( + self._ir.total_seconds().to_function_expr(self._expr._ir) + ) + + def total_milliseconds(self) -> DummyExpr: + return self._to_narwhals( + self._ir.total_milliseconds().to_function_expr(self._expr._ir) + ) + + def total_microseconds(self) -> DummyExpr: + return self._to_narwhals( + self._ir.total_microseconds().to_function_expr(self._expr._ir) + ) + + def total_nanoseconds(self) -> DummyExpr: + return self._to_narwhals( + self._ir.total_nanoseconds().to_function_expr(self._expr._ir) + ) + + def to_string(self, format: str) -> DummyExpr: + return self._to_narwhals( + self._ir.to_string(format=format).to_function_expr(self._expr._ir) + ) + + def replace_time_zone(self, time_zone: str | None) -> DummyExpr: + return self._to_narwhals( + self._ir.replace_time_zone(time_zone=time_zone).to_function_expr( + self._expr._ir + ) + ) + + def convert_time_zone(self, time_zone: str) -> DummyExpr: + return self._to_narwhals( + self._ir.convert_time_zone(time_zone=time_zone).to_function_expr( + self._expr._ir + ) + ) + + def timestamp(self, time_unit: TimeUnit = "us") -> DummyExpr: + return self._to_narwhals( + self._ir.timestamp(time_unit=time_unit).to_function_expr(self._expr._ir) + ) + + def truncate(self, every: str) -> DummyExpr: + return self._to_narwhals( + self._ir.truncate(every=every).to_function_expr(self._expr._ir) + ) From 72c33ce57d48768098ba9509604b073b891c97da Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 25 May 2025 15:06:51 +0100 Subject: [PATCH 123/368] feat: Add `str` namespace - Not fully convinced on this namespace abstraction - At the very least, something to handle the `Function` -> `FunctionExpr` would reduce the boilerplate a lot --- narwhals/_plan/dummy.py | 7 ++ narwhals/_plan/strings.py | 174 +++++++++++++++++++++++++++++++++----- 2 files changed, 160 insertions(+), 21 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 7cf73cc3a6..7382e32786 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -34,6 +34,7 @@ from narwhals._plan.lists import 
ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace + from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace from narwhals._plan.temporal import ExprDateTimeNamespace from narwhals.typing import ( @@ -494,6 +495,12 @@ def list(self) -> ExprListNamespace: return ExprListNamespace(_expr=self) + @property + def str(self) -> ExprStringNamespace: + from narwhals._plan.strings import ExprStringNamespace + + return ExprStringNamespace(_expr=self) + class DummySelector(DummyExpr): """Selectors placeholder. diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index ef64a3b79a..bb675e0896 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -1,8 +1,13 @@ from __future__ import annotations -from narwhals._plan.common import Function +from typing import TYPE_CHECKING + +from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionFlags, FunctionOptions +if TYPE_CHECKING: + from narwhals._plan.dummy import DummyExpr + class StringFunction(Function): @property @@ -30,8 +35,9 @@ def __repr__(self) -> str: class Contains(StringFunction): - __slots__ = ("literal",) + __slots__ = ("literal", "pattern") + pattern: str literal: bool def __repr__(self) -> str: @@ -39,6 +45,10 @@ def __repr__(self) -> str: class EndsWith(StringFunction): + __slots__ = ("suffix",) + + suffix: str + def __repr__(self) -> str: return "str.ends_with" @@ -49,9 +59,12 @@ def __repr__(self) -> str: class Replace(StringFunction): - __slots__ = ("literal",) + __slots__ = ("literal", "n", "pattern", "value") + pattern: str + value: str literal: bool + n: int def __repr__(self) -> str: return "str.replace" @@ -63,8 +76,10 @@ class ReplaceAll(StringFunction): https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L65-L70 """ - __slots__ = ("literal",) + __slots__ = ("literal", "pattern", "value") + pattern: str + value: str literal: bool def __repr__(self) -> str: @@ -88,35 +103,29 @@ def __repr__(self) -> str: return "str.slice" -class Head(StringFunction): - __slots__ = ("n",) - - n: int - - def __repr__(self) -> str: - return "str.head" - - -class Tail(StringFunction): - __slots__ = ("n",) - - n: int - - def __repr__(self) -> str: - return "str.tail" +class Split(StringFunction): + __slots__ = ("by",) + by: str -class Split(StringFunction): def __repr__(self) -> str: return "str.split" class StartsWith(StringFunction): + __slots__ = ("prefix",) + + prefix: str + def __repr__(self) -> str: return "str.startswith" class StripChars(StringFunction): + __slots__ = ("characters",) + + characters: str | None + def __repr__(self) -> str: return "str.strip_chars" @@ -133,6 +142,9 @@ class ToDatetime(StringFunction): format: str | None + def __repr__(self) -> str: + return "str.to_datetime" + class ToLowercase(StringFunction): def __repr__(self) -> str: @@ -142,3 +154,123 @@ def __repr__(self) -> str: class ToUppercase(StringFunction): def __repr__(self) -> str: return "str.to_uppercase" + + +class IRStringNamespace(IRNamespace): + def len_chars(self) -> LenChars: + return LenChars() + + def replace( + self, pattern: str, value: str, *, literal: bool = False, n: int = 1 + ) -> Replace: + return Replace(pattern=pattern, value=value, literal=literal, n=n) + + def replace_all( + self, pattern: str, value: str, *, literal: bool = False + ) -> ReplaceAll: + return 
ReplaceAll(pattern=pattern, value=value, literal=literal) + + def strip_chars(self, characters: str | None = None) -> StripChars: + return StripChars(characters=characters) + + def starts_with(self, prefix: str) -> StartsWith: + return StartsWith(prefix=prefix) + + def ends_with(self, suffix: str) -> EndsWith: + return EndsWith(suffix=suffix) + + def contains(self, pattern: str, *, literal: bool = False) -> Contains: + return Contains(pattern=pattern, literal=literal) + + def slice(self, offset: int, length: int | None = None) -> Slice: + return Slice(offset=offset, length=length) + + def head(self, n: int = 5) -> Slice: + return self.slice(0, n) + + def tail(self, n: int = 5) -> Slice: + return self.slice(-n) + + def split(self, by: str) -> Split: + return Split(by=by) + + def to_datetime(self, format: str | None = None) -> ToDatetime: + return ToDatetime(format=format) + + def to_lowercase(self) -> ToUppercase: + return ToUppercase() + + def to_uppercase(self) -> ToLowercase: + return ToLowercase() + + +class ExprStringNamespace(ExprNamespace[IRStringNamespace]): + @property + def _ir_namespace(self) -> type[IRStringNamespace]: + return IRStringNamespace + + def len_chars(self) -> DummyExpr: + return self._to_narwhals(self._ir.len_chars().to_function_expr(self._expr._ir)) + + def replace( + self, pattern: str, value: str, *, literal: bool = False, n: int = 1 + ) -> DummyExpr: + return self._to_narwhals( + self._ir.replace(pattern, value, literal=literal, n=n).to_function_expr( + self._expr._ir + ) + ) + + def replace_all( + self, pattern: str, value: str, *, literal: bool = False + ) -> DummyExpr: + return self._to_narwhals( + self._ir.replace_all(pattern, value, literal=literal).to_function_expr( + self._expr._ir + ) + ) + + def strip_chars(self, characters: str | None = None) -> DummyExpr: + return self._to_narwhals( + self._ir.strip_chars(characters).to_function_expr(self._expr._ir) + ) + + def starts_with(self, prefix: str) -> DummyExpr: + return self._to_narwhals( + self._ir.starts_with(prefix).to_function_expr(self._expr._ir) + ) + + def ends_with(self, suffix: str) -> DummyExpr: + return self._to_narwhals( + self._ir.ends_with(suffix).to_function_expr(self._expr._ir) + ) + + def contains(self, pattern: str, *, literal: bool = False) -> DummyExpr: + return self._to_narwhals( + self._ir.contains(pattern, literal=literal).to_function_expr(self._expr._ir) + ) + + def slice(self, offset: int, length: int | None = None) -> DummyExpr: + return self._to_narwhals( + self._ir.slice(offset, length).to_function_expr(self._expr._ir) + ) + + def head(self, n: int = 5) -> DummyExpr: + return self._to_narwhals(self._ir.head(n).to_function_expr(self._expr._ir)) + + def tail(self, n: int = 5) -> DummyExpr: + return self._to_narwhals(self._ir.tail(n).to_function_expr(self._expr._ir)) + + def split(self, by: str) -> DummyExpr: + return self._to_narwhals(self._ir.split(by).to_function_expr(self._expr._ir)) + + def to_datetime(self, format: str | None = None) -> DummyExpr: + return self._to_narwhals( + self._ir.to_datetime(format).to_function_expr(self._expr._ir) + ) + + def to_lowercase(self) -> DummyExpr: + return self._to_narwhals(self._ir.to_lowercase().to_function_expr(self._expr._ir)) + + def to_uppercase(self) -> DummyExpr: + return self._to_narwhals(self._ir.to_uppercase().to_function_expr(self._expr._ir)) From 35fb57821c3c155b862db56c92356df09799023c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 25 May 2025 16:27:41 +0100 Subject: 
[PATCH 124/368] =?UTF-8?q?feat:=20Implement=20chained=20`when-the?= =?UTF-8?q?n-otherwise`=20=F0=9F=A5=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Related (https://github.com/narwhals-dev/narwhals/issues/668#issuecomment-2904162550) - This would be how we should model it in *actual narwhals* - Almost identical to the `rust` version - See this in particular (https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130) --- narwhals/_plan/demo.py | 29 ++++++++++++- narwhals/_plan/expr.py | 22 +++++++++- narwhals/_plan/when_then.py | 84 +++++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 narwhals/_plan/when_then.py diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index afc54b2a3c..5a2de45b57 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -9,11 +9,12 @@ expr_parsing as parse, functions as F, # noqa: N812 ) -from narwhals._plan.common import ExprIR, IntoExpr, is_non_nested_literal +from narwhals._plan.common import ExprIR, IntoExpr, is_expr, is_non_nested_literal from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.strings import ConcatHorizontal +from narwhals._plan.when_then import When from narwhals.dtypes import DType from narwhals.exceptions import OrderDependentExprError from narwhals.utils import Version, flatten @@ -131,6 +132,32 @@ def concat_str( ) +def when(*predicates: IntoExpr | t.Iterable[IntoExpr]) -> When: + """Start a `when-then-otherwise` expression. + + Examples: + >>> from narwhals._plan import demo as nwd + + >>> when_then_many = ( + ... nwd.when(nwd.col("x") == "a") + ... .then(1) + ... .when(nwd.col("x") == "b") + ... .then(2) + ... .when(nwd.col("x") == "c") + ... .then(3) + ... .otherwise(4) + ... ) + >>> when_then_many + Narwhals DummyExpr (main): + .when([(col('x')) == (lit(str: a))]).then(lit(int: 1)).otherwise(.when([(col('x')) == (lit(str: b))]).then(lit(int: 2)).otherwise(.when([(col('x')) == (lit(str: c))]).then(lit(int: 3)).otherwise(lit(int: 4)))) + """ + if builtins.len(predicates) == 1 and is_expr(predicates[0]): + expr = predicates[0] + else: + expr = all_horizontal(*predicates) + return When._from_expr(expr) + + def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]: """In theory, we could add other nodes to this check.""" from narwhals._plan.expr import SortBy diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index fb0e179b63..906094fa30 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -13,7 +13,7 @@ import typing as t from narwhals._plan.aggregation import Agg, OrderableAgg -from narwhals._plan.common import ExprIR, SelectorIR +from narwhals._plan.common import ExprIR, SelectorIR, _field_str from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( FunctionT, @@ -475,3 +475,23 @@ class Ternary(ExprIR): Deferring this for now. 
""" + + __slots__ = ("falsy", "predicate", "truthy") + + predicate: ExprIR + truthy: ExprIR + falsy: ExprIR + + def __str__(self) -> str: + # NOTE: Default slot ordering made it difficult to read + fields = ( + _field_str("predicate", self.predicate), + _field_str("truthy", self.truthy), + _field_str("falsy", self.falsy), + ) + return f"{type(self).__name__}({', '.join(fields)})" + + def __repr__(self) -> str: + return ( + f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" + ) diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py new file mode 100644 index 0000000000..5c9690bc0c --- /dev/null +++ b/narwhals/_plan/when_then.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan.common import Immutable +from narwhals._plan.expr_parsing import parse_into_expr_ir + +if TYPE_CHECKING: + from narwhals._plan.common import ExprIR, IntoExpr, Seq + from narwhals._plan.dummy import DummyExpr + from narwhals._plan.expr import Ternary + + +class When(Immutable): + __slots__ = ("condition",) + + condition: ExprIR + + def then(self, expr: IntoExpr, /) -> Then: + return Then(condition=self.condition, statement=parse_into_expr_ir(expr)) + + @staticmethod + def _from_expr(expr: DummyExpr, /) -> When: + return When(condition=expr._ir) + + +class Then(Immutable): + __slots__ = ("condition", "statement") + + condition: ExprIR + statement: ExprIR + + def when(self, condition: IntoExpr, /) -> ChainedWhen: + return ChainedWhen( + conditions=(self.condition, parse_into_expr_ir(condition)), + statements=(self.statement,), + ) + + def otherwise(self, statement: IntoExpr, /) -> DummyExpr: + return ternary_expr( + self.condition, self.condition, parse_into_expr_ir(statement) + ).to_narwhals() + + +class ChainedWhen(Immutable): + __slots__ = ("conditions", "statements") + + conditions: Seq[ExprIR] + statements: Seq[ExprIR] + + def then(self, statement: IntoExpr, /) -> ChainedThen: + return ChainedThen( + conditions=self.conditions, + statements=(*self.statements, parse_into_expr_ir(statement)), + ) + + +class ChainedThen(Immutable): + """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130.""" + + __slots__ = ("conditions", "statements") + + conditions: Seq[ExprIR] + statements: Seq[ExprIR] + + def when(self, condition: IntoExpr, /) -> ChainedWhen: + return ChainedWhen( + conditions=(*self.conditions, parse_into_expr_ir(condition)), + statements=self.statements, + ) + + def otherwise(self, statement: IntoExpr, /) -> DummyExpr: + otherwise = parse_into_expr_ir(statement) + it_conditions = reversed(self.conditions) + it_statements = reversed(self.statements) + for e in it_conditions: + otherwise = ternary_expr(e, next(it_statements), otherwise) + return otherwise.to_narwhals() + + +def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> Ternary: + from narwhals._plan.expr import Ternary + + return Ternary(predicate=predicate, truthy=truthy, falsy=falsy) From 563076d5e22ce9faccfdce5f99ada865bd13f9db Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 25 May 2025 18:07:23 +0100 Subject: [PATCH 125/368] feat: Support optional `.otherwise(...)` Somewhat of typing nightmare but gets the job done for now --- narwhals/_plan/demo.py | 4 ++++ narwhals/_plan/when_then.py | 46 ++++++++++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git 
a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 5a2de45b57..ac0e25887f 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -150,6 +150,10 @@ def when(*predicates: IntoExpr | t.Iterable[IntoExpr]) -> When: >>> when_then_many Narwhals DummyExpr (main): .when([(col('x')) == (lit(str: a))]).then(lit(int: 1)).otherwise(.when([(col('x')) == (lit(str: b))]).then(lit(int: 2)).otherwise(.when([(col('x')) == (lit(str: c))]).then(lit(int: 3)).otherwise(lit(int: 4)))) + >>> + >>> nwd.when(nwd.col("y") == "b").then(1) + Narwhals DummyExpr (main): + .when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null)) """ if builtins.len(predicates) == 1 and is_expr(predicates[0]): expr = predicates[0] diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 5c9690bc0c..40eaae9971 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable +from narwhals._plan.common import Immutable, is_expr +from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr_parsing import parse_into_expr_ir if TYPE_CHECKING: from narwhals._plan.common import ExprIR, IntoExpr, Seq - from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import Ternary @@ -24,7 +24,7 @@ def _from_expr(expr: DummyExpr, /) -> When: return When(condition=expr._ir) -class Then(Immutable): +class Then(Immutable, DummyExpr): __slots__ = ("condition", "statement") condition: ExprIR @@ -37,9 +37,23 @@ def when(self, condition: IntoExpr, /) -> ChainedWhen: ) def otherwise(self, statement: IntoExpr, /) -> DummyExpr: - return ternary_expr( - self.condition, self.condition, parse_into_expr_ir(statement) - ).to_narwhals() + return self._from_ir(self._otherwise(statement)) + + def _otherwise(self, statement: IntoExpr = None, /) -> ExprIR: + return ternary_expr(self.condition, self.statement, parse_into_expr_ir(statement)) + + @property + def _ir(self) -> ExprIR: # type: ignore[override] + return self._otherwise() + + @classmethod + def _from_ir(cls, ir: ExprIR, /) -> DummyExpr: # type: ignore[override] + return DummyExpr._from_ir(ir) + + def __eq__(self, value: object) -> DummyExpr | bool: # type: ignore[override] + if is_expr(value): + return super(DummyExpr, self).__eq__(value) + return super().__eq__(value) class ChainedWhen(Immutable): @@ -55,7 +69,7 @@ def then(self, statement: IntoExpr, /) -> ChainedThen: ) -class ChainedThen(Immutable): +class ChainedThen(Immutable, DummyExpr): """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130.""" __slots__ = ("conditions", "statements") @@ -70,12 +84,28 @@ def when(self, condition: IntoExpr, /) -> ChainedWhen: ) def otherwise(self, statement: IntoExpr, /) -> DummyExpr: + return self._from_ir(self._otherwise(statement)) + + def _otherwise(self, statement: IntoExpr = None, /) -> ExprIR: otherwise = parse_into_expr_ir(statement) it_conditions = reversed(self.conditions) it_statements = reversed(self.statements) for e in it_conditions: otherwise = ternary_expr(e, next(it_statements), otherwise) - return otherwise.to_narwhals() + return otherwise + + @property + def _ir(self) -> ExprIR: # type: ignore[override] + return self._otherwise() + + @classmethod + def _from_ir(cls, ir: ExprIR, /) -> DummyExpr: # type: ignore[override] + return DummyExpr._from_ir(ir) + + def __eq__(self, value: object) -> DummyExpr | bool: # type: ignore[override] + if is_expr(value): + 
return super(DummyExpr, self).__eq__(value) + return super().__eq__(value) def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> Ternary: From b80e0e1d0268b3bcbf42ad8625fd59180ce51027 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 25 May 2025 20:51:35 +0100 Subject: [PATCH 126/368] feat: Infer `DType` in `lit` --- narwhals/_plan/common.py | 19 +++++++++++++++++++ narwhals/_plan/demo.py | 12 +++++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 0c6b575733..fd8c61d012 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -22,6 +22,7 @@ from narwhals._plan.lists import IRListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions + from narwhals.dtypes import DType from narwhals.typing import NonNestedLiteral else: @@ -315,3 +316,21 @@ def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | DummySeries]: from narwhals._plan.dummy import DummySeries return isinstance(obj, (str, bytes, DummySeries)) + + +def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: + dtypes = version.dtypes + mapping: dict[type[NonNestedLiteral], type[DType]] = { + int: dtypes.Int64, + float: dtypes.Float64, + str: dtypes.String, + bool: dtypes.Boolean, + dt.datetime: dtypes.Datetime, + dt.date: dtypes.Date, + dt.time: dtypes.Time, + dt.timedelta: dtypes.Duration, + bytes: dtypes.Binary, + Decimal: dtypes.Decimal, + type(None): dtypes.Unknown, + } + return mapping.get(type(obj), dtypes.Unknown)() diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index ac0e25887f..abdd6aa9ef 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -9,7 +9,13 @@ expr_parsing as parse, functions as F, # noqa: N812 ) -from narwhals._plan.common import ExprIR, IntoExpr, is_expr, is_non_nested_literal +from narwhals._plan.common import ( + ExprIR, + IntoExpr, + is_expr, + is_non_nested_literal, + py_to_narwhals_dtype, +) from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth from narwhals._plan.literal import ScalarLiteral, SeriesLiteral @@ -52,11 +58,11 @@ def lit( ) -> DummyExpr: if isinstance(value, DummySeries): return SeriesLiteral(value=value).to_literal().to_narwhals() - if dtype is None or not isinstance(dtype, DType): - dtype = Version.MAIN.dtypes.Unknown() if not is_non_nested_literal(value): msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." 
raise TypeError(msg) + if dtype is None or not isinstance(dtype, DType): + dtype = py_to_narwhals_dtype(value, Version.MAIN) return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() From 92694ce439ea6882e8ea5994142d37bb500042f0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 26 May 2025 19:27:18 +0100 Subject: [PATCH 127/368] feat(DRAFT): Mock up `to_compliant` as an adapter An experiment towards (https://github.com/narwhals-dev/narwhals/pull/2572#issuecomment-2897164339) --- narwhals/_namespace.py | 4 +++- narwhals/_plan/common.py | 15 ++++--------- narwhals/_plan/dummy.py | 4 ++++ narwhals/_plan/expr.py | 11 ++++++++++ narwhals/_plan/literal.py | 26 ++++++++++++++++++----- narwhals/_plan/typing.py | 11 ++++++++++ tests/plan/to_compliant_test.py | 37 +++++++++++++++++++++++++++++++++ 7 files changed, 91 insertions(+), 17 deletions(-) create mode 100644 tests/plan/to_compliant_test.py diff --git a/narwhals/_namespace.py b/narwhals/_namespace.py index 0c23d378c7..cf8137c159 100644 --- a/narwhals/_namespace.py +++ b/narwhals/_namespace.py @@ -203,7 +203,9 @@ def from_backend(cls, backend: EagerAllowed, /) -> EagerAllowedNamespace: ... @overload @classmethod - def from_backend(cls, backend: ModuleType, /) -> Namespace[CompliantNamespaceAny]: ... + def from_backend( + cls, backend: IntoBackend, / + ) -> Namespace[CompliantNamespaceAny]: ... @classmethod def from_backend( diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index fd8c61d012..ff4926da4c 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -4,7 +4,7 @@ from decimal import Decimal from typing import TYPE_CHECKING, Generic, TypeVar -from narwhals._plan.typing import IRNamespaceT +from narwhals._plan.typing import ExprT, IRNamespaceT, Ns from narwhals.utils import Version if TYPE_CHECKING: @@ -12,12 +12,7 @@ from typing_extensions import Never, Self, TypeAlias, TypeIs, dataclass_transform - from narwhals._plan.dummy import ( - DummyCompliantExpr, - DummyExpr, - DummySelector, - DummySeries, - ) + from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries from narwhals._plan.expr import FunctionExpr from narwhals._plan.lists import IRListNamespace from narwhals._plan.meta import IRMetaNamespace @@ -146,10 +141,8 @@ def to_narwhals(self, version: Version = Version.MAIN) -> DummyExpr: return dummy.DummyExpr._from_ir(self) return dummy.DummyExprV1._from_ir(self) - def to_compliant(self, version: Version = Version.MAIN) -> DummyCompliantExpr: - from narwhals._plan.dummy import DummyCompliantExpr - - return DummyCompliantExpr._from_ir(self, version) + def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: + raise NotImplementedError @property def is_scalar(self) -> bool: diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 7382e32786..a135959fd1 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -37,6 +37,7 @@ from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace from narwhals._plan.temporal import ExprDateTimeNamespace + from narwhals._plan.typing import ExprT, Ns from narwhals.typing import ( FillNullStrategy, NativeSeries, @@ -62,6 +63,9 @@ def _from_ir(cls, ir: ExprIR, /) -> Self: obj._ir = ir return obj + def _to_compliant(self, plx: Ns[ExprT], /) -> ExprT: + return self._ir.to_compliant(plx) + @property def version(self) -> Version: return self._version diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 
906094fa30..39b82fec92 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -16,9 +16,11 @@ from narwhals._plan.common import ExprIR, SelectorIR, _field_str from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( + ExprT, FunctionT, LeftSelectorT, LeftT, + Ns, OperatorT, RightSelectorT, RightT, @@ -97,6 +99,9 @@ class Column(ExprIR): def __repr__(self) -> str: return f"col({self.name!r})" + def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: + return plx.col(self.name) + class Columns(ExprIR): __slots__ = ("names",) @@ -106,6 +111,9 @@ class Columns(ExprIR): def __repr__(self) -> str: return f"cols({list(self.names)!r})" + def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: + return plx.col(*self.names) + class Literal(ExprIR): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" @@ -129,6 +137,9 @@ def name(self) -> str: def __repr__(self) -> str: return f"lit({self.value!r})" + def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: + return plx.lit(self.value.unwrap(), self.dtype) + class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): __slots__ = ("left", "op", "right") diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index f9a70219d4..03b0fef13d 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Generic from narwhals._plan.common import Immutable @@ -10,8 +10,15 @@ from narwhals.dtypes import DType from narwhals.typing import NonNestedLiteral +from narwhals._typing_compat import TypeVar -class LiteralValue(Immutable): +T = TypeVar("T", default=Any) +NonNestedLiteralT = TypeVar( + "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" +) + + +class LiteralValue(Immutable, Generic[T]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/lit.rs#L67-L73.""" @property @@ -31,11 +38,14 @@ def to_literal(self) -> Literal: return Literal(value=self) + def unwrap(self) -> T: + raise NotImplementedError + -class ScalarLiteral(LiteralValue): +class ScalarLiteral(LiteralValue[NonNestedLiteralT]): __slots__ = ("dtype", "value") - value: NonNestedLiteral + value: NonNestedLiteralT dtype: DType @property @@ -47,8 +57,11 @@ def __repr__(self) -> str: return f"{type(self.value).__name__}: {self.value!s}" return "null" + def unwrap(self) -> NonNestedLiteralT: + return self.value + -class SeriesLiteral(LiteralValue): +class SeriesLiteral(LiteralValue["DummySeries"]): """We already need this. https://github.com/narwhals-dev/narwhals/blob/e51eba891719a5eb1f7ce91c02a477af39c0baee/narwhals/_expression_parsing.py#L96-L97 @@ -69,6 +82,9 @@ def name(self) -> str: def __repr__(self) -> str: return "Series" + def unwrap(self) -> DummySeries: + return self.value + class RangeLiteral(LiteralValue): """Don't need yet, but might push forward the discussions. 
diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index add10565cb..5f21589132 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -5,6 +5,10 @@ from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals._compliant import CompliantNamespace as Namespace + from narwhals._compliant.typing import CompliantExprAny from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function, IRNamespace, SelectorIR from narwhals._plan.functions import RollingWindow @@ -24,3 +28,10 @@ "SelectorOperatorT", bound="ops.SelectorOperator", default="ops.SelectorOperator" ) IRNamespaceT = TypeVar("IRNamespaceT", bound="IRNamespace") +# NOTE: Shorter aliases of `_compliant.typing` +# - Aiming to try and preserve the types as much as possible +# - Recursion between `Expr` and `Frame` is an issue +Expr: TypeAlias = "CompliantExprAny" +ExprT = TypeVar("ExprT", bound="Expr") +Ns: TypeAlias = "Namespace[t.Any, ExprT]" +"""A `CompliantNamespace`, ignoring the `Frame` type.""" diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py new file mode 100644 index 0000000000..c6ee5c6706 --- /dev/null +++ b/tests/plan/to_compliant_test.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan.demo as nwd +from narwhals.utils import Version +from tests.namespace_test import backends + +if TYPE_CHECKING: + from narwhals._namespace import BackendName + from narwhals._plan.dummy import DummyExpr + + +def _ids_ir(expr: DummyExpr) -> str: + return repr(expr._ir) + + +@pytest.mark.parametrize( + ("expr"), + [ + nwd.col("a"), + nwd.col("a", "b"), + nwd.lit(1), + nwd.lit(2.0), + nwd.lit(None, nw.String()), + ], + ids=_ids_ir, +) +@backends +def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: + pytest.importorskip(backend) + namespace = Version.MAIN.namespace.from_backend(backend).compliant + compliant_expr = expr._to_compliant(namespace) + assert isinstance(compliant_expr, namespace._expr) From 0eada48cd2f2dbce554cc59fbee43ce8d2355a1d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 26 May 2025 21:01:34 +0100 Subject: [PATCH 128/368] feat(DRAFT): Add experimental `@singledispatch` version Somewhat of a bridge between the current `Compliant*` stuff and (https://github.com/narwhals-dev/narwhals/issues/2571#issuecomment-2907776908) --- narwhals/_plan/dummy.py | 3 +++ narwhals/_plan/impl_arrow.py | 50 ++++++++++++++++++++++++++++++++++++ narwhals/_plan/literal.py | 10 ++++++++ 3 files changed, 63 insertions(+) create mode 100644 narwhals/_plan/impl_arrow.py diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index a135959fd1..0d6c61f7e7 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -631,6 +631,9 @@ def from_native(cls, native: NativeSeries, /) -> Self: obj._compliant = DummyCompliantSeries.from_native(native, cls._version) return obj + def to_native(self) -> NativeSeries: + return self._compliant._native + class DummySeriesV1(DummySeries): _version: t.ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py new file mode 100644 index 0000000000..00c83d149b --- /dev/null +++ b/narwhals/_plan/impl_arrow.py @@ -0,0 +1,50 @@ +"""Translating `ExprIR` nodes for pyarrow. 
+ +Acting like a trimmed down, native-only `CompliantExpr`, `CompliantSeries`, etc. +""" + +from __future__ import annotations + +import typing as t +from functools import singledispatch + +from narwhals._plan import expr +from narwhals._plan.literal import is_scalar_literal, is_series_literal + +if t.TYPE_CHECKING: + import pyarrow as pa + from typing_extensions import TypeAlias + + from narwhals._plan.common import ExprIR + + NativeFrame: TypeAlias = pa.Table + NativeSeries: TypeAlias = pa.ChunkedArray[t.Any] + Evaluated: TypeAlias = t.Sequence[NativeSeries] + + +@singledispatch +def evaluate(node: ExprIR, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Column) +def col(node: expr.Column, frame: NativeFrame) -> Evaluated: + return [frame.column(node.name)] + + +@evaluate.register(expr.Columns) +def cols(node: expr.Columns, frame: NativeFrame) -> Evaluated: + return frame.select(list(node.names)).columns + + +@evaluate.register(expr.Literal) +def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated: # noqa: ARG001 + import pyarrow as pa + + if is_scalar_literal(node.value): + return [pa.chunked_array([node.value.unwrap()])] + elif is_series_literal(node.value): + ca = node.value.unwrap().to_native() + return [t.cast("NativeSeries", ca)] + else: + raise NotImplementedError(type(node.value)) diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 03b0fef13d..ac3c7d70b8 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -5,6 +5,8 @@ from narwhals._plan.common import Immutable if TYPE_CHECKING: + from typing_extensions import TypeIs + from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import Literal from narwhals.dtypes import DType @@ -98,3 +100,11 @@ class RangeLiteral(LiteralValue): low: int high: int dtype: DType + + +def is_scalar_literal(obj: Any) -> TypeIs[ScalarLiteral]: + return isinstance(obj, ScalarLiteral) + + +def is_series_literal(obj: Any) -> TypeIs[SeriesLiteral]: + return isinstance(obj, SeriesLiteral) From 9273510f78e18875fe117b23771dcf9ce0fa3468 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 26 May 2025 22:18:13 +0100 Subject: [PATCH 129/368] feat: Map each operator to an `operator` function --- narwhals/_plan/operators.py | 65 +++++++++++++++++++++++++++---------- narwhals/_plan/typing.py | 1 + 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 5d43c4b514..2577fb2bfd 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -1,19 +1,29 @@ from __future__ import annotations +import operator from typing import TYPE_CHECKING +from narwhals._plan.common import Immutable from narwhals._plan.expr import BinarySelector if TYPE_CHECKING: + from typing import Any, ClassVar + from typing_extensions import Self from narwhals._plan.expr import BinaryExpr, BinarySelector - from narwhals._plan.typing import LeftSelectorT, LeftT, RightSelectorT, RightT - -from narwhals._plan.common import Immutable + from narwhals._plan.typing import ( + LeftSelectorT, + LeftT, + OperatorFn, + RightSelectorT, + RightT, + ) class Operator(Immutable): + _op: ClassVar[OperatorFn] + def __repr__(self) -> str: tp = type(self) if tp in {Operator, SelectorOperator}: @@ -44,6 +54,10 @@ def to_binary_expr( return BinaryExpr(left=left, op=self, right=right) + def __call__(self, lhs: Any, rhs: Any) -> Any: + """Apply binary operator to `left`, 
`right` operands.""" + return self.__class__._op(lhs, rhs) + class SelectorOperator(Operator): """Operators that can *also* be used in selectors. @@ -61,46 +75,61 @@ def to_binary_selector( return BinarySelector(left=left, op=self, right=right) -class Eq(Operator): ... +class Eq(Operator): + _op = operator.eq -class NotEq(Operator): ... +class NotEq(Operator): + _op = operator.ne -class Lt(Operator): ... +class Lt(Operator): + _op = operator.le -class LtEq(Operator): ... +class LtEq(Operator): + _op = operator.lt -class Gt(Operator): ... +class Gt(Operator): + _op = operator.gt -class GtEq(Operator): ... +class GtEq(Operator): + _op = operator.ge -class Add(Operator): ... +class Add(Operator): + _op = operator.add -class Sub(SelectorOperator): ... +class Sub(SelectorOperator): + _op = operator.sub -class Multiply(Operator): ... +class Multiply(Operator): + _op = operator.mul -class TrueDivide(Operator): ... +class TrueDivide(Operator): + _op = operator.truediv -class FloorDivide(Operator): ... +class FloorDivide(Operator): + _op = operator.floordiv -class Modulus(Operator): ... +class Modulus(Operator): + _op = operator.mod -class And(SelectorOperator): ... +class And(SelectorOperator): + _op = operator.and_ -class Or(SelectorOperator): ... +class Or(SelectorOperator): + _op = operator.or_ -class ExclusiveOr(SelectorOperator): ... +class ExclusiveOr(SelectorOperator): + _op = operator.xor diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 5f21589132..644de58e4a 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -21,6 +21,7 @@ LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") +OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" LeftSelectorT = TypeVar("LeftSelectorT", bound="SelectorIR", default="SelectorIR") RightSelectorT = TypeVar("RightSelectorT", bound="SelectorIR", default="SelectorIR") From 0d1394ff22ee7ae74cf573d3b5f13a47457ff75b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 26 May 2025 23:01:01 +0100 Subject: [PATCH 130/368] test: Fix `lit` and add some tests Forgot one level of nesting on the lists --- narwhals/_plan/impl_arrow.py | 3 ++- tests/plan/to_compliant_test.py | 40 ++++++++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 00c83d149b..12d00bc8d9 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -42,7 +42,8 @@ def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated: # noqa: ARG001 import pyarrow as pa if is_scalar_literal(node.value): - return [pa.chunked_array([node.value.unwrap()])] + scalar = node.value.unwrap() + return [pa.chunked_array([[scalar]])] elif is_series_literal(node.value): ca = node.value.unwrap().to_native() return [t.cast("NativeSeries", ca)] diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index c6ee5c6706..3173623603 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -1,11 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest import narwhals as nw import narwhals._plan.demo as nwd +from narwhals._plan.common import is_expr +from narwhals._plan.impl_arrow import evaluate as evaluate_pyarrow from narwhals.utils import Version 
from tests.namespace_test import backends @@ -14,8 +16,10 @@ from narwhals._plan.dummy import DummyExpr -def _ids_ir(expr: DummyExpr) -> str: - return repr(expr._ir) +def _ids_ir(expr: DummyExpr | Any) -> str: + if is_expr(expr): + return repr(expr._ir) + return repr(expr) @pytest.mark.parametrize( @@ -35,3 +39,33 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: namespace = Version.MAIN.namespace.from_backend(backend).compliant compliant_expr = expr._to_compliant(namespace) assert isinstance(compliant_expr, namespace._expr) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwd.col("a"), ["A", "B", "A"]), + (nwd.col("a", "b"), [["A", "B", "A"], [1, 2, 3]]), + (nwd.lit(1), [1]), + (nwd.lit(2.0), [2.0]), + (nwd.lit(None, nw.String()), [None]), + ], + ids=_ids_ir, +) +def test_evaluate_pyarrow(expr: DummyExpr, expected: Any) -> None: + pytest.importorskip("pyarrow") + import pyarrow as pa + + data: dict[str, Any] = { + "a": ["A", "B", "A"], + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + } + frame = pa.table(data) + result = evaluate_pyarrow(expr._ir, frame) + if len(result) == 1: + assert result[0].to_pylist() == expected + else: + results = [col.to_pylist() for col in result] + assert results == expected From 3165da4315f0e86cb639990095d9c6028efb409b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 27 May 2025 11:08:42 +0100 Subject: [PATCH 131/368] fix: Handle `lit` broadcasting --- narwhals/_plan/impl_arrow.py | 7 ++++--- tests/plan/to_compliant_test.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 12d00bc8d9..6e3b4a73f6 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -38,12 +38,13 @@ def cols(node: expr.Columns, frame: NativeFrame) -> Evaluated: @evaluate.register(expr.Literal) -def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated: # noqa: ARG001 +def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated: import pyarrow as pa if is_scalar_literal(node.value): - scalar = node.value.unwrap() - return [pa.chunked_array([[scalar]])] + lit: t.Any = pa.scalar + array = pa.repeat(lit(node.value.unwrap()), len(frame)) + return [pa.chunked_array([array])] elif is_series_literal(node.value): ca = node.value.unwrap().to_native() return [t.cast("NativeSeries", ca)] diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 3173623603..a08dc0f5df 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -46,9 +46,9 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: [ (nwd.col("a"), ["A", "B", "A"]), (nwd.col("a", "b"), [["A", "B", "A"], [1, 2, 3]]), - (nwd.lit(1), [1]), - (nwd.lit(2.0), [2.0]), - (nwd.lit(None, nw.String()), [None]), + (nwd.lit(1), [1, 1, 1]), + (nwd.lit(2.0), [2.0, 2.0, 2.0]), + (nwd.lit(None, nw.String()), [None, None, None]), ], ids=_ids_ir, ) From ca1f8227a2e9da1ada8608db7a2c6aaeb93b4698 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 27 May 2025 12:26:39 +0100 Subject: [PATCH 132/368] feat: Improve `Literal` type unwrapping - 4th iteration of trying to get this working - A nice property of this is the unimplemented `RangeLiteral` case is (statically) unreachable in `impl_arrow` - Without the need to name *any* of the `LiteralValue` classes - Also the traversal is now hidden behind `unwrap`, which preserves the type --- 
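A minimal sketch of the same unwrapping pattern in isolation, assuming made-up stand-in names (`Wrapped`, `ScalarBox`, `SeriesBox`) rather than the real `LiteralValue` classes: a generic `unwrap()` preserves the wrapped type, and a `TypeIs` guard keeps the un-handled variant statically unreachable after narrowing.

from __future__ import annotations

from typing import Any, Generic, TypeVar

from typing_extensions import TypeIs

T = TypeVar("T")


class Wrapped(Generic[T]):
    """A generic box; `unwrap()` returns exactly the wrapped type."""

    def __init__(self, value: T) -> None:
        self._value = value

    def unwrap(self) -> T:
        return self._value


class ScalarBox(Wrapped[int]): ...


class SeriesBox(Wrapped[list]): ...


def is_scalar_box(obj: Wrapped[Any]) -> TypeIs[ScalarBox]:
    return isinstance(obj, ScalarBox)


def describe(obj: ScalarBox | SeriesBox) -> str:
    if is_scalar_box(obj):
        # narrowed: `unwrap()` is known to return `int` here
        return f"scalar: {obj.unwrap()}"
    # only `SeriesBox` remains to be handled; no other variant is reachable
    return f"series of length {len(obj.unwrap())}"

Because the guard takes `Wrapped[Any]` and narrows to the concrete subclass, callers can branch on the guard without ever having to spell out the wrapper types themselves.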
narwhals/_plan/expr.py | 15 +++++++++++---- narwhals/_plan/impl_arrow.py | 16 ++++++++++------ narwhals/_plan/literal.py | 35 +++++++++++++++++++++++------------ narwhals/_plan/typing.py | 8 ++++++++ 4 files changed, 52 insertions(+), 22 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 39b82fec92..9e94339ecc 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -13,13 +13,14 @@ import typing as t from narwhals._plan.aggregation import Agg, OrderableAgg -from narwhals._plan.common import ExprIR, SelectorIR, _field_str +from narwhals._plan.common import ExprIR, SelectorIR, _field_str, is_non_nested_literal from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( ExprT, FunctionT, LeftSelectorT, LeftT, + LiteralT, Ns, OperatorT, RightSelectorT, @@ -115,12 +116,12 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: return plx.col(*self.names) -class Literal(ExprIR): +class Literal(ExprIR, t.Generic[LiteralT]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" __slots__ = ("value",) - value: LiteralValue + value: LiteralValue[LiteralT] @property def is_scalar(self) -> bool: @@ -138,7 +139,13 @@ def __repr__(self) -> str: return f"lit({self.value!r})" def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: - return plx.lit(self.value.unwrap(), self.dtype) + value = self.unwrap() + if is_non_nested_literal(value): + return plx.lit(value, self.dtype) + raise NotImplementedError(type(self.value)) + + def unwrap(self) -> LiteralT: + return self.value.unwrap() class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 6e3b4a73f6..d080d1037b 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -9,13 +9,15 @@ from functools import singledispatch from narwhals._plan import expr -from narwhals._plan.literal import is_scalar_literal, is_series_literal +from narwhals._plan.literal import is_literal_scalar, is_literal_series if t.TYPE_CHECKING: import pyarrow as pa from typing_extensions import TypeAlias from narwhals._plan.common import ExprIR + from narwhals._plan.dummy import DummySeries + from narwhals.typing import NonNestedLiteral NativeFrame: TypeAlias = pa.Table NativeSeries: TypeAlias = pa.ChunkedArray[t.Any] @@ -38,15 +40,17 @@ def cols(node: expr.Columns, frame: NativeFrame) -> Evaluated: @evaluate.register(expr.Literal) -def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated: +def lit( + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries], frame: NativeFrame +) -> Evaluated: import pyarrow as pa - if is_scalar_literal(node.value): + if is_literal_scalar(node): lit: t.Any = pa.scalar - array = pa.repeat(lit(node.value.unwrap()), len(frame)) + array = pa.repeat(lit(node.unwrap()), len(frame)) return [pa.chunked_array([array])] - elif is_series_literal(node.value): - ca = node.value.unwrap().to_native() + elif is_literal_series(node): + ca = node.unwrap().to_native() return [t.cast("NativeSeries", ca)] else: raise NotImplementedError(type(node.value)) diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index ac3c7d70b8..16ddfbf6af 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Generic from narwhals._plan.common import Immutable +from narwhals._plan.typing import LiteralT, NonNestedLiteralT if TYPE_CHECKING: from typing_extensions 
import TypeIs @@ -10,17 +11,9 @@ from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import Literal from narwhals.dtypes import DType - from narwhals.typing import NonNestedLiteral -from narwhals._typing_compat import TypeVar -T = TypeVar("T", default=Any) -NonNestedLiteralT = TypeVar( - "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" -) - - -class LiteralValue(Immutable, Generic[T]): +class LiteralValue(Immutable, Generic[LiteralT]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/lit.rs#L67-L73.""" @property @@ -40,7 +33,7 @@ def to_literal(self) -> Literal: return Literal(value=self) - def unwrap(self) -> T: + def unwrap(self) -> LiteralT: raise NotImplementedError @@ -102,9 +95,27 @@ class RangeLiteral(LiteralValue): dtype: DType -def is_scalar_literal(obj: Any) -> TypeIs[ScalarLiteral]: +def _is_scalar( + obj: ScalarLiteral[NonNestedLiteralT] | Any, +) -> TypeIs[ScalarLiteral[NonNestedLiteralT]]: return isinstance(obj, ScalarLiteral) -def is_series_literal(obj: Any) -> TypeIs[SeriesLiteral]: +def _is_series(obj: Any) -> TypeIs[SeriesLiteral]: return isinstance(obj, SeriesLiteral) + + +def is_literal(obj: Literal[LiteralT] | Any) -> TypeIs[Literal[LiteralT]]: + from narwhals._plan.expr import Literal + + return isinstance(obj, Literal) + + +def is_literal_scalar( + obj: Literal[NonNestedLiteralT] | Any, +) -> TypeIs[Literal[NonNestedLiteralT]]: + return is_literal(obj) and _is_scalar(obj.value) + + +def is_literal_series(obj: Literal[DummySeries] | Any) -> TypeIs[Literal[DummySeries]]: + return is_literal(obj) and _is_series(obj.value) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 644de58e4a..292d76bf82 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -11,7 +11,9 @@ from narwhals._compliant.typing import CompliantExprAny from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function, IRNamespace, SelectorIR + from narwhals._plan.dummy import DummySeries from narwhals._plan.functions import RollingWindow + from narwhals.typing import NonNestedLiteral __all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT", "SelectorOperatorT"] @@ -29,6 +31,12 @@ "SelectorOperatorT", bound="ops.SelectorOperator", default="ops.SelectorOperator" ) IRNamespaceT = TypeVar("IRNamespaceT", bound="IRNamespace") + +NonNestedLiteralT = TypeVar( + "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" +) +LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | DummySeries", default=t.Any) + # NOTE: Shorter aliases of `_compliant.typing` # - Aiming to try and preserve the types as much as possible # - Recursion between `Expr` and `Frame` is an issue From 4839a909968e55d3351a137847250c3ace6a211c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 27 May 2025 12:50:14 +0100 Subject: [PATCH 133/368] chore: Fill out planned nodes in `impl_arrow` - Mostly ordered by priority/ease to implement - `@singledispatch` matches by **mro**, so the order is here doesn't impact that --- narwhals/_plan/impl_arrow.py | 110 +++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index d080d1037b..6b883318f1 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -54,3 +54,113 @@ def lit( return [t.cast("NativeSeries", ca)] else: raise 
NotImplementedError(type(node.value)) + + +@evaluate.register(expr.Alias) +def alias(node: expr.Alias, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Len) +def len_(node: expr.Len, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Nth) +def nth(node: expr.Nth, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.IndexColumns) +def index_columns(node: expr.IndexColumns, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.All) +def all_(node: expr.All, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Exclude) +def exclude(node: expr.Exclude, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Cast) +def cast_(node: expr.Cast, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Ternary) +def ternary(node: expr.Ternary, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Agg) +def agg(node: expr.Agg, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.OrderableAgg) +def orderable_agg(node: expr.OrderableAgg, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.BinaryExpr) +def binary_expr(node: expr.BinaryExpr, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.FunctionExpr) +def function_expr(node: expr.FunctionExpr[t.Any], frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.RollingExpr) +def rolling_expr(node: expr.RollingExpr[t.Any], frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.WindowExpr) +def window_expr(node: expr.WindowExpr, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.RootSelector) +def selector(node: expr.RootSelector, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.BinarySelector) +def binary_selector(node: expr.BinarySelector, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.RenameAlias) +def rename_alias(node: expr.RenameAlias, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Sort) +def sort(node: expr.Sort, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.SortBy) +def sort_by(node: expr.SortBy, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.Filter) +def filter_(node: expr.Filter, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.AnonymousExpr) +def anonymous_expr(node: expr.AnonymousExpr, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) + + +@evaluate.register(expr.KeepName) +def keep_name(node: expr.KeepName, frame: NativeFrame) -> Evaluated: + raise NotImplementedError(type(node)) From 6e7f9bce36fbad54353a6d89f41e1b6cc8a2b6f7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 27 May 2025 21:28:56 +0100 Subject: [PATCH 134/368] feat: Add missed `boolean` 
methods oops --- narwhals/_plan/dummy.py | 47 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 0d6c61f7e7..f51f5ea080 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -39,6 +39,7 @@ from narwhals._plan.temporal import ExprDateTimeNamespace from narwhals._plan.typing import ExprT, Ns from narwhals.typing import ( + ClosedInterval, FillNullStrategy, NativeSeries, NumericLiteral, @@ -376,6 +377,52 @@ def map_batches( ).to_function_expr(self._ir) ) + def any(self) -> Self: + return self._from_ir(boolean.Any().to_function_expr(self._ir)) + + def all(self) -> Self: + return self._from_ir(boolean.All().to_function_expr(self._ir)) + + def is_duplicated(self) -> Self: + return self._from_ir(boolean.IsDuplicated().to_function_expr(self._ir)) + + def is_finite(self) -> Self: + return self._from_ir(boolean.IsFinite().to_function_expr(self._ir)) + + def is_nan(self) -> Self: + return self._from_ir(boolean.IsNan().to_function_expr(self._ir)) + + def is_null(self) -> Self: + return self._from_ir(boolean.IsNull().to_function_expr(self._ir)) + + def is_first_distinct(self) -> Self: + return self._from_ir(boolean.IsFirstDistinct().to_function_expr(self._ir)) + + def is_last_distinct(self) -> Self: + return self._from_ir(boolean.IsLastDistinct().to_function_expr(self._ir)) + + def is_unique(self) -> Self: + return self._from_ir(boolean.IsUnique().to_function_expr(self._ir)) + + def is_between( + self, + lower_bound: IntoExpr, + upper_bound: IntoExpr, + closed: ClosedInterval = "both", + ) -> Self: + it = parse.parse_into_seq_of_expr_ir(lower_bound, upper_bound) + return self._from_ir( + boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) + ) + + def is_in(self, other: t.Any) -> Self: + msg = ( + "There's some special handling of iterables that I'm not sure on:\n" + "https://github.com/narwhals-dev/narwhals/blob/8975189cb2459f129017cf833075b28ec3d4dfa8/narwhals/expr.py#L1176-L1184" + ) + raise NotImplementedError(msg) + return self._from_ir(boolean.IsIn().to_function_expr(self._ir)) + def __eq__(self, other: IntoExpr) -> Self: # type: ignore[override] op = ops.Eq() rhs = parse.parse_into_expr_ir(other, str_as_lit=True) From 8b9b8d32ef0ef3d64041a6351eb2ac2f3678f302 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 27 May 2025 21:46:15 +0100 Subject: [PATCH 135/368] fix: Enforce no repeat aggs, fix flags Forgot that `rust` has the `contains` check the other way --- narwhals/_plan/aggregation.py | 9 ++++++++- narwhals/_plan/options.py | 4 ++-- tests/plan/expr_parsing_test.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 9fb4fa64f7..8a15711509 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from narwhals._plan.common import ExprIR +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from typing import Iterator @@ -35,6 +36,12 @@ def iter_right(self) -> Iterator[ExprIR]: yield self yield from self.expr.iter_right() + def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: + if expr.is_scalar: + msg = "Can't apply aggregations to scalar-like expressions." 
+ raise InvalidOperationError(msg) + super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue] + class Count(Agg): ... diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 7ddcd9f199..1deac6aa38 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -39,10 +39,10 @@ def is_elementwise(self) -> bool: return self in (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING) def returns_scalar(self) -> bool: - return self in FunctionFlags.RETURNS_SCALAR + return FunctionFlags.RETURNS_SCALAR in self def is_length_preserving(self) -> bool: - return self in FunctionFlags.LENGTH_PRESERVING + return FunctionFlags.LENGTH_PRESERVING in self @staticmethod def default() -> FunctionFlags: diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 84487a6e47..dd3f42bbc2 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -13,6 +13,7 @@ from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import FunctionExpr +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from narwhals._plan.common import IntoExpr, Seq @@ -78,3 +79,14 @@ def test_function_expr_horizontal( assert isinstance(variadic_node.function, ir_node) assert variadic_node == sequence_node assert sequence_node != unrelated_node + + +def test_invalid_repeat_agg() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").mean().mean() + with pytest.raises(InvalidOperationError): + nwd.col("a").first().max() + with pytest.raises(InvalidOperationError): + nwd.col("a").any().std() + with pytest.raises(InvalidOperationError): + nwd.col("a").all().quantile(0.5, "linear") From c203ef4dabf9643f0f5bca9925695c8deb8b8dbe Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 27 May 2025 21:58:11 +0100 Subject: [PATCH 136/368] chore: Remove/update comments/notes --- narwhals/_plan/__init__.py | 40 ----------------------------------- narwhals/_plan/aggregation.py | 2 -- narwhals/_plan/categorical.py | 1 - narwhals/_plan/expr.py | 13 ++---------- narwhals/_plan/functions.py | 6 +----- narwhals/_plan/meta.py | 2 -- narwhals/_plan/options.py | 1 - narwhals/_plan/selectors.py | 2 +- 8 files changed, 4 insertions(+), 63 deletions(-) diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index bfc32b789c..9d48db4f9f 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -1,41 +1 @@ -"""Brainstorming an `Expr` internal representation. 
- -Notes: -- Each `Expr` method should be representable by a single node - - But the node does not need to be unique to the method -- A chain of `Expr` methods should form a plan of operations -- We must be able to enforce rules on what plans are permitted: - - Must be flexible to both eager/lazy and individual backends - - Must be flexible to a given context (select, with_columns, filter, group_by) -- Nodes & plans are: - - Immutable, but - - Can be extended/re-written at both the Narwhals & Compliant levels - - Introspectable, but - - Store as little-as-needed for the common case - - Provide properties/methods for computing the less frequent metadata - -References: -- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs -- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs -- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/mod.rs -- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/options/mod.rs#L137-L172 -- https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L35-L106 -- https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L131-L236 -- https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs#L240-L267 -- https://github.com/pola-rs/polars/blob/6df23a09a81c640c21788607611e09d9f43b1abc/crates/polars-plan/src/plans/aexpr/mod.rs - -Related: -- https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2866902903 -- https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867331343 -- https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2867446959 -- https://github.com/narwhals-dev/narwhals/pull/2483#issuecomment-2869070157 -- (https://github.com/narwhals-dev/narwhals/pull/2538/commits/a7eeb0d23e67cb70e7cfa73cec2c7b69a15c8bef#r2083562677) -- https://github.com/narwhals-dev/narwhals/issues/2225 -- https://github.com/narwhals-dev/narwhals/issues/1848 -- https://github.com/narwhals-dev/narwhals/issues/2534#issuecomment-2875676729 -- https://github.com/narwhals-dev/narwhals/issues/2291 -- https://github.com/narwhals-dev/narwhals/issues/2522 -- https://github.com/narwhals-dev/narwhals/pull/2555 -""" - from __future__ import annotations diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 8a15711509..7047d96340 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -72,7 +72,6 @@ class Std(Agg): __slots__ = (*Agg.__slots__, "ddof") ddof: int - """https://github.com/narwhals-dev/narwhals/pull/2555""" class Sum(Agg): ... @@ -82,7 +81,6 @@ class Var(Agg): __slots__ = (*Agg.__slots__, "ddof") ddof: int - """https://github.com/narwhals-dev/narwhals/pull/2555""" class OrderableAgg(Agg): ... 
diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 47e045d926..0f8d490fba 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -17,7 +17,6 @@ class GetCategories(CategoricalFunction): @property def function_options(self) -> FunctionOptions: - """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/cat.rs#L41.""" return FunctionOptions.groupwise() def __repr__(self) -> str: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 9e94339ecc..4711dbc1b6 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -1,10 +1,4 @@ -"""Top-level `Expr` nodes. - -Todo: -- `Selector` -- `Ternary` -- `Window` (investigate variants) -""" +"""Top-level `Expr` nodes.""" from __future__ import annotations @@ -489,10 +483,7 @@ class BinarySelector( class Ternary(ExprIR): - """When-Then-Otherwise. - - Deferring this for now. - """ + """When-Then-Otherwise.""" __slots__ = ("falsy", "predicate", "truthy") diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 3fb7cfca14..00c8ee040f 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -1,8 +1,4 @@ -"""General functions that aren't namespaced. - -Todo: -- repr -""" +"""General functions that aren't namespaced.""" from __future__ import annotations diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 9454bd758d..ad2b30c055 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -1,7 +1,5 @@ """`pl.Expr.meta` namespace functionality. -- It seems like there might be a need to distinguish the top-level nodes for iterating - - polars_plan::dsl::expr::Expr - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/meta.rs#L11-L111 - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L10-L105 """ diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 1deac6aa38..3c393c8f27 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -115,7 +115,6 @@ def aggregation() -> FunctionOptions: return FunctionOptions.groupwise().with_flags(FunctionFlags.RETURNS_SCALAR) -# TODO @dangotbanned: Decide on constructors class SortOptions(Immutable): __slots__ = ("descending", "nulls_last") diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index f9ef444b60..56f8ec3f4c 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -1,7 +1,7 @@ """Deviations from `polars`. - A `Selector` corresponds to a `nw.selectors` function -- Binary ops are represented as a subtype of `BinaryExpr` +- Binary ops are represented as a `BinarySelector`, similar to `BinaryExpr`. 
""" from __future__ import annotations From b13ebc6690314be3d06b06607425531a1c74aff4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 11:49:54 +0100 Subject: [PATCH 137/368] test(DRAFT): Add (failing) `expression_parsing_test.py` tests Only have parity on one so far, need to do some more stuff like (https://github.com/narwhals-dev/narwhals/pull/2572/commits/8b9b8d32ef0ef3d64041a6351eb2ac2f3678f302) --- tests/plan/expr_parsing_test.py | 57 +++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index dd3f42bbc2..d2cb07c6e1 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -81,6 +81,21 @@ def test_function_expr_horizontal( assert sequence_node != unrelated_node +# TODO @dangotbanned: Get partity with the existing tests +# https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L48-L105 + + +def test_misleading_order_by() -> None: + with pytest.raises(InvalidOperationError): + nw.col("a").mean().over(order_by="b") + with pytest.raises(InvalidOperationError): + nw.col("a").rank().over(order_by="b") + + +# `test_double_over` is already covered in the later `test_nested_over` + + +# test_double_agg def test_invalid_repeat_agg() -> None: with pytest.raises(InvalidOperationError): nwd.col("a").mean().mean() @@ -90,3 +105,45 @@ def test_invalid_repeat_agg() -> None: nwd.col("a").any().std() with pytest.raises(InvalidOperationError): nwd.col("a").all().quantile(0.5, "linear") + + +def test_filter_aggregation() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").mean().drop_nulls() + + +# TODO @dangotbanned: Add `head`, `tail` +def test_head_aggregation() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").mean().head() # type: ignore[attr-defined] + + +def test_rank_aggregation() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").mean().rank() + + +def test_diff_aggregation() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").mean().diff() + + +def test_invalid_over() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").fill_null(3).over("b") + + +def test_nested_over() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").mean().over("b").over("c") + with pytest.raises(InvalidOperationError): + nwd.col("a").mean().over("b").over("c", order_by="i") + + +def test_filtration_over() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").drop_nulls().over("b") + with pytest.raises(InvalidOperationError): + nwd.col("a").drop_nulls().over("b", order_by="i") + with pytest.raises(InvalidOperationError): + nwd.col("a").diff().drop_nulls().over("b", order_by="i") From d085b3a40f470d6ce2955d503ef0dda932464a6e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 14:24:49 +0100 Subject: [PATCH 138/368] feat: Get 3x `over()` rules passing Aaaaand added comments on rule origin --- narwhals/_plan/expr.py | 29 +++++++++++++++++++++++ narwhals/_plan/options.py | 10 ++++++-- tests/plan/expr_parsing_test.py | 41 ++++++++++++++++++++++----------- 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 4711dbc1b6..8304d34e56 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -22,6 +22,7 @@ RollingT, 
SelectorOperatorT, ) +from narwhals.exceptions import InvalidOperationError if t.TYPE_CHECKING: from typing_extensions import Self @@ -403,6 +404,34 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from e.iter_right() yield from self.expr.iter_right() + def __init__( + self, + *, + expr: ExprIR, + partition_by: Seq[ExprIR], + order_by: tuple[Seq[ExprIR], SortOptions] | None, + options: Window, + ) -> None: + if isinstance(expr, WindowExpr): + msg = "Cannot nest `over` statements." + raise InvalidOperationError(msg) + + if isinstance(expr, FunctionExpr): + if expr.options.is_elementwise(): + msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}" + raise InvalidOperationError(msg) + if expr.options.is_row_separable(): + msg = f"Cannot use `over` on expressions which change length.\n{expr!r}" + raise InvalidOperationError(msg) + + kwds = { + "expr": expr, + "partition_by": partition_by, + "order_by": order_by, + "options": options, + } + super().__init__(**kwds) + class Len(ExprIR): @property diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 3c393c8f27..2e9410ec14 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -27,7 +27,7 @@ class FunctionFlags(enum.Flag): """Automatically explode on unit length if it ran as final aggregation.""" ROW_SEPARABLE = 1 << 8 - """Not sure lol. + """`drop_nulls` is the only one we've got that is *just* this. https://github.com/pola-rs/polars/pull/22573 """ @@ -36,7 +36,7 @@ class FunctionFlags(enum.Flag): """mutually exclusive with `RETURNS_SCALAR`""" def is_elementwise(self) -> bool: - return self in (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING) + return (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING) in self def returns_scalar(self) -> bool: return FunctionFlags.RETURNS_SCALAR in self @@ -44,6 +44,9 @@ def returns_scalar(self) -> bool: def is_length_preserving(self) -> bool: return FunctionFlags.LENGTH_PRESERVING in self + def is_row_separable(self) -> bool: + return FunctionFlags.ROW_SEPARABLE in self + @staticmethod def default() -> FunctionFlags: return FunctionFlags.ALLOW_GROUP_AWARE @@ -75,6 +78,9 @@ def returns_scalar(self) -> bool: def is_length_preserving(self) -> bool: return self.flags.is_length_preserving() + def is_row_separable(self) -> bool: + return self.flags.is_row_separable() + def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." 
diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index d2cb07c6e1..a0252935ee 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, Callable, Iterable import pytest @@ -81,17 +82,10 @@ def test_function_expr_horizontal( assert sequence_node != unrelated_node -# TODO @dangotbanned: Get partity with the existing tests +# TODO @dangotbanned: Get parity with the existing tests # https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L48-L105 -def test_misleading_order_by() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over(order_by="b") - with pytest.raises(InvalidOperationError): - nw.col("a").rank().over(order_by="b") - - # `test_double_over` is already covered in the later `test_nested_over` @@ -107,6 +101,9 @@ def test_invalid_repeat_agg() -> None: nwd.col("a").all().quantile(0.5, "linear") +# TODO @dangotbanned: Weirdly, `polars` suggestion **does** resolve it +# InvalidOperationError: Series idx, length 1 doesn't match the DataFrame height of 9 +# If you want expression: col("idx").mean().drop_nulls() to be broadcasted, ensure it is a scalar (for instance by adding '.first()') def test_filter_aggregation() -> None: with pytest.raises(InvalidOperationError): nwd.col("a").mean().drop_nulls() @@ -118,32 +115,48 @@ def test_head_aggregation() -> None: nwd.col("a").mean().head() # type: ignore[attr-defined] +# TODO @dangotbanned: (Same as `test_filter_aggregation`) def test_rank_aggregation() -> None: with pytest.raises(InvalidOperationError): nwd.col("a").mean().rank() +# TODO @dangotbanned: No error in `polars`, but results in all `null`s def test_diff_aggregation() -> None: with pytest.raises(InvalidOperationError): nwd.col("a").mean().diff() -def test_invalid_over() -> None: +# TODO @dangotbanned: Non-`polars`` rule +def test_misleading_order_by() -> None: + with pytest.raises(InvalidOperationError): + nwd.col("a").mean().over(order_by="b") with pytest.raises(InvalidOperationError): + nwd.col("a").rank().over(order_by="b") + + +# NOTE: Non-`polars`` rule +def test_invalid_over() -> None: + pattern = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) + with pytest.raises(InvalidOperationError, match=pattern): nwd.col("a").fill_null(3).over("b") def test_nested_over() -> None: - with pytest.raises(InvalidOperationError): + pattern = re.compile(r"cannot nest.+over", re.IGNORECASE) + with pytest.raises(InvalidOperationError, match=pattern): nwd.col("a").mean().over("b").over("c") - with pytest.raises(InvalidOperationError): + with pytest.raises(InvalidOperationError, match=pattern): nwd.col("a").mean().over("b").over("c", order_by="i") +# NOTE: This *can* error in polars, but only if the length **actualy changes** +# The rule then breaks down to needing the same length arrays in all parts of the over def test_filtration_over() -> None: - with pytest.raises(InvalidOperationError): + pattern = re.compile(r"cannot use.+over.+change length", re.IGNORECASE) + with pytest.raises(InvalidOperationError, match=pattern): nwd.col("a").drop_nulls().over("b") - with pytest.raises(InvalidOperationError): + with pytest.raises(InvalidOperationError, match=pattern): nwd.col("a").drop_nulls().over("b", order_by="i") - with pytest.raises(InvalidOperationError): + with pytest.raises(InvalidOperationError, match=pattern): 
nwd.col("a").diff().drop_nulls().over("b", order_by="i") From 0fa48a0c8e2bdacdd9fecf6c8238a2fcf04910c1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 15:11:01 +0100 Subject: [PATCH 139/368] chore: Add missing `cum_sum` --- narwhals/_plan/dummy.py | 3 +++ narwhals/_plan/functions.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index f51f5ea080..5ea3f74979 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -233,6 +233,9 @@ def cum_max(self, *, reverse: bool = False) -> Self: def cum_prod(self, *, reverse: bool = False) -> Self: return self._from_ir(F.CumProd(reverse=reverse).to_function_expr(self._ir)) + def cum_sum(self, *, reverse: bool = False) -> Self: + return self._from_ir(F.CumSum(reverse=reverse).to_function_expr(self._ir)) + def rolling_sum( self, window_size: int, *, min_samples: int | None = None, center: bool = False ) -> Self: diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 00c8ee040f..5bcdac00d9 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -202,6 +202,7 @@ def __repr__(self) -> str: CumMin: "min", CumMax: "max", CumProd: "prod", + CumSum: "sum", } return f"cum_{m[tp]}" @@ -241,6 +242,9 @@ class CumMax(CumAgg): ... class CumProd(CumAgg): ... +class CumSum(CumAgg): ... + + class RollingSum(RollingWindow): ... From 51925d6abec24fa94ed8a6c30b0cd4a6f09b07db Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 15:31:36 +0100 Subject: [PATCH 140/368] test: Identify horizontal/elementwise gap? Seems like allowing elementwise *only* here was unintentional? Also a note on head/tail/slice --- tests/plan/expr_parsing_test.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index a0252935ee..d0fb1bb70f 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -82,6 +82,34 @@ def test_function_expr_horizontal( assert sequence_node != unrelated_node +def test_valid_windows() -> None: + """Was planning to test this matched, but we seem to allow elementwise horizontal? 
+ + https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L10-L45 + """ + ELEMENTWISE_ERR = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) # noqa: N806 + a = nwd.col("a") + assert a.cum_sum() + assert a.cum_sum().over(order_by="id") + with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): + assert a.cum_sum().abs().over(order_by="id") + + assert (a.cum_sum() + 1).over(order_by="id") + assert a.cum_sum().cum_sum().over(order_by="id") + assert a.cum_sum().cum_sum() + assert nwd.sum_horizontal(a, a.cum_sum()) + with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): + assert nwd.sum_horizontal(a, a.cum_sum()).over(order_by="a") + + assert nwd.sum_horizontal(a, a.cum_sum().over(order_by="i")) + assert nwd.sum_horizontal(a.diff(), a.cum_sum().over(order_by="i")) + with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): + assert nwd.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i") + + with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): + assert nwd.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") + + # TODO @dangotbanned: Get parity with the existing tests # https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L48-L105 @@ -110,6 +138,9 @@ def test_filter_aggregation() -> None: # TODO @dangotbanned: Add `head`, `tail` +# head/tail are implemented in terms of `Expr::Slice` +# We don't support `Expr.slice`, seems odd to add it for a deprecation 🤔 +# polars allows this in `select`, but not `with_columns` def test_head_aggregation() -> None: with pytest.raises(InvalidOperationError): nwd.col("a").mean().head() # type: ignore[attr-defined] From 9bd10ad47acfa042d2c51455fd5218539bc08e1a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 15:33:38 +0100 Subject: [PATCH 141/368] spellcheck https://results.pre-commit.ci/run/github/760058710/1748442712.iM1jW7a-RO2mrst14HWeOg --- tests/plan/expr_parsing_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index d0fb1bb70f..3a81d355b3 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -181,7 +181,7 @@ def test_nested_over() -> None: nwd.col("a").mean().over("b").over("c", order_by="i") -# NOTE: This *can* error in polars, but only if the length **actualy changes** +# NOTE: This *can* error in polars, but only if the length **actually changes** # The rule then breaks down to needing the same length arrays in all parts of the over def test_filtration_over() -> None: pattern = re.compile(r"cannot use.+over.+change length", re.IGNORECASE) From 87ea25e1c0e11edd9dc7320438179791f4daed62 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 17:25:18 +0100 Subject: [PATCH 142/368] fix: Add `to_function_expr` overrides --- narwhals/_plan/functions.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 5bcdac00d9..5c3c55bfd6 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -11,7 +11,10 @@ if TYPE_CHECKING: from typing import Any - from narwhals._plan.common import Seq, Udf + from typing_extensions import Self + + from narwhals._plan.common import ExprIR, Seq, Udf + from narwhals._plan.expr 
import AnonymousExpr, RollingExpr from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow from narwhals.dtypes import DType from narwhals.typing import FillNullStrategy @@ -229,6 +232,12 @@ def __repr__(self) -> str: } return f"rolling_{m[tp]}" + def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: + from narwhals._plan.expr import RollingExpr + + options = self.function_options + return RollingExpr(input=inputs, function=self, options=options) + class CumCount(CumAgg): ... @@ -394,3 +403,9 @@ def function_options(self) -> FunctionOptions: def __repr__(self) -> str: return "map_batches" + + def to_function_expr(self, *inputs: ExprIR) -> AnonymousExpr: + from narwhals._plan.expr import AnonymousExpr + + options = self.function_options + return AnonymousExpr(input=inputs, function=self, options=options) From 2d04634fed582ecafec74044685a391488446077 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 17:31:01 +0100 Subject: [PATCH 143/368] refactor: Move validation to `window.py` --- narwhals/_plan/expr.py | 29 ----------------------------- narwhals/_plan/window.py | 15 ++++++++++++++- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8304d34e56..4711dbc1b6 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -22,7 +22,6 @@ RollingT, SelectorOperatorT, ) -from narwhals.exceptions import InvalidOperationError if t.TYPE_CHECKING: from typing_extensions import Self @@ -404,34 +403,6 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from e.iter_right() yield from self.expr.iter_right() - def __init__( - self, - *, - expr: ExprIR, - partition_by: Seq[ExprIR], - order_by: tuple[Seq[ExprIR], SortOptions] | None, - options: Window, - ) -> None: - if isinstance(expr, WindowExpr): - msg = "Cannot nest `over` statements." - raise InvalidOperationError(msg) - - if isinstance(expr, FunctionExpr): - if expr.options.is_elementwise(): - msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}" - raise InvalidOperationError(msg) - if expr.options.is_row_separable(): - msg = f"Cannot use `over` on expressions which change length.\n{expr!r}" - raise InvalidOperationError(msg) - - kwds = { - "expr": expr, - "partition_by": partition_by, - "order_by": order_by, - "options": options, - } - super().__init__(**kwds) - class Len(ExprIR): @property diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 861e7baff8..f2dc5214a5 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from narwhals._plan.common import Immutable +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from narwhals._plan.common import ExprIR, Seq @@ -30,7 +31,19 @@ def to_window_expr( order_by: tuple[Seq[ExprIR], SortOptions] | None, /, ) -> WindowExpr: - from narwhals._plan.expr import WindowExpr + from narwhals._plan.expr import FunctionExpr, WindowExpr + + if isinstance(expr, WindowExpr): + msg = "Cannot nest `over` statements." 
+ raise InvalidOperationError(msg) + + if isinstance(expr, FunctionExpr): + if expr.options.is_elementwise(): + msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}" + raise InvalidOperationError(msg) + if expr.options.is_row_separable(): + msg = f"Cannot use `over` on expressions which change length.\n{expr!r}" + raise InvalidOperationError(msg) return WindowExpr( expr=expr, partition_by=partition_by, order_by=order_by, options=self From 49083ced1ecd2ccad6b0fd601b6ecfff799c3aa6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 18:16:29 +0100 Subject: [PATCH 144/368] feat: Simplify validating aggregations are elementwise Collapsed 3 tests into one branch --- narwhals/_plan/expr.py | 15 +++++++++++++++ tests/plan/expr_parsing_test.py | 34 ++++++++++++++------------------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 4711dbc1b6..8c83714120 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -22,6 +22,7 @@ RollingT, SelectorOperatorT, ) +from narwhals.exceptions import InvalidOperationError if t.TYPE_CHECKING: from typing_extensions import Self @@ -300,6 +301,20 @@ def iter_right(self) -> t.Iterator[ExprIR]: for e in reversed(self.input): yield from e.iter_right() + def __init__( + self, + *, + input: Seq[ExprIR], # noqa: A002 + function: FunctionT, + options: FunctionOptions, + **kwds: t.Any, + ) -> None: + parent = input[0] + if parent.is_scalar and not options.is_elementwise(): + msg = f"Cannot use `{function!r}()` on aggregated expression `{parent!r}`." + raise InvalidOperationError(msg) + super().__init__(**dict(input=input, function=function, options=options, **kwds)) + class RollingExpr(FunctionExpr[RollingT]): ... 
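[editor aside, illustrative sketch, not part of the patch] The `FunctionExpr.__init__` guard added above is what the tests in the next hunk exercise: once an expression has been aggregated down to a scalar, only elementwise functions may be chained onto it. Using the dummy expression namespace (`nwd`) from those tests:

    nwd.col("a").mean().abs()    # allowed: `abs` is elementwise
    nwd.col("a").mean().rank()   # raises InvalidOperationError: `rank` is not elementwise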
diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 3a81d355b3..a0f741a624 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -129,14 +129,6 @@ def test_invalid_repeat_agg() -> None: nwd.col("a").all().quantile(0.5, "linear") -# TODO @dangotbanned: Weirdly, `polars` suggestion **does** resolve it -# InvalidOperationError: Series idx, length 1 doesn't match the DataFrame height of 9 -# If you want expression: col("idx").mean().drop_nulls() to be broadcasted, ensure it is a scalar (for instance by adding '.first()') -def test_filter_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nwd.col("a").mean().drop_nulls() - - # TODO @dangotbanned: Add `head`, `tail` # head/tail are implemented in terms of `Expr::Slice` # We don't support `Expr.slice`, seems odd to add it for a deprecation 🤔 @@ -146,18 +138,6 @@ def test_head_aggregation() -> None: nwd.col("a").mean().head() # type: ignore[attr-defined] -# TODO @dangotbanned: (Same as `test_filter_aggregation`) -def test_rank_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nwd.col("a").mean().rank() - - -# TODO @dangotbanned: No error in `polars`, but results in all `null`s -def test_diff_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nwd.col("a").mean().diff() - - # TODO @dangotbanned: Non-`polars`` rule def test_misleading_order_by() -> None: with pytest.raises(InvalidOperationError): @@ -166,6 +146,20 @@ def test_misleading_order_by() -> None: nwd.col("a").rank().over(order_by="b") +# NOTE: Previously multiple different errors, but they can be reduced to the same thing +# Once we are scalar, only elementwise is allowed +def test_invalid_agg_non_elementwise() -> None: + pattern = re.compile(r"cannot use.+rank.+aggregated.+mean", re.IGNORECASE) + with pytest.raises(InvalidOperationError, match=pattern): + nwd.col("a").mean().rank() + pattern = re.compile(r"cannot use.+drop_nulls.+aggregated.+max", re.IGNORECASE) + with pytest.raises(InvalidOperationError): + nwd.col("a").max().drop_nulls() + pattern = re.compile(r"cannot use.+diff.+aggregated.+min", re.IGNORECASE) + with pytest.raises(InvalidOperationError): + nwd.col("a").min().diff() + + # NOTE: Non-`polars`` rule def test_invalid_over() -> None: pattern = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) From 15402499e638165109c60479613ac52d2ef4302b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 18:19:33 +0100 Subject: [PATCH 145/368] revert: Remove unplanned `head` test --- tests/plan/expr_parsing_test.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index a0f741a624..6898282256 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -114,10 +114,6 @@ def test_valid_windows() -> None: # https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L48-L105 -# `test_double_over` is already covered in the later `test_nested_over` - - -# test_double_agg def test_invalid_repeat_agg() -> None: with pytest.raises(InvalidOperationError): nwd.col("a").mean().mean() @@ -129,15 +125,6 @@ def test_invalid_repeat_agg() -> None: nwd.col("a").all().quantile(0.5, "linear") -# TODO @dangotbanned: Add `head`, `tail` -# head/tail are implemented in terms of `Expr::Slice` -# We don't support `Expr.slice`, seems odd to add it for a 
deprecation 🤔 -# polars allows this in `select`, but not `with_columns` -def test_head_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nwd.col("a").mean().head() # type: ignore[attr-defined] - - # TODO @dangotbanned: Non-`polars`` rule def test_misleading_order_by() -> None: with pytest.raises(InvalidOperationError): From 4091f32a24285bb88cd62b01e6cfe3d431fea8e5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 18:40:18 +0100 Subject: [PATCH 146/368] feat: Add `arg_(min|max)` --- narwhals/_plan/dummy.py | 6 ++++++ tests/plan/expr_parsing_test.py | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 5ea3f74979..3c93b71b11 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -99,6 +99,12 @@ def n_unique(self) -> Self: def sum(self) -> Self: return self._from_ir(agg.Sum(expr=self._ir)) + def arg_min(self) -> Self: + return self._from_ir(agg.ArgMin(expr=self._ir)) + + def arg_max(self) -> Self: + return self._from_ir(agg.ArgMax(expr=self._ir)) + def first(self) -> Self: return self._from_ir(agg.First(expr=self._ir)) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 6898282256..9b377393f2 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -123,6 +123,10 @@ def test_invalid_repeat_agg() -> None: nwd.col("a").any().std() with pytest.raises(InvalidOperationError): nwd.col("a").all().quantile(0.5, "linear") + with pytest.raises(InvalidOperationError): + nwd.col("a").arg_max().min() + with pytest.raises(InvalidOperationError): + nwd.col("a").arg_min().arg_max() # TODO @dangotbanned: Non-`polars`` rule From 726987e8897275ff7ece0db87e4de5a201875e60 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 28 May 2025 19:24:55 +0100 Subject: [PATCH 147/368] revert: remove `test_misleading_order_by` Can revisit this later if important, but it seems more like a lint rule --- tests/plan/expr_parsing_test.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 9b377393f2..b30f2b0683 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -110,10 +110,6 @@ def test_valid_windows() -> None: assert nwd.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") -# TODO @dangotbanned: Get parity with the existing tests -# https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L48-L105 - - def test_invalid_repeat_agg() -> None: with pytest.raises(InvalidOperationError): nwd.col("a").mean().mean() @@ -129,14 +125,6 @@ def test_invalid_repeat_agg() -> None: nwd.col("a").arg_min().arg_max() -# TODO @dangotbanned: Non-`polars`` rule -def test_misleading_order_by() -> None: - with pytest.raises(InvalidOperationError): - nwd.col("a").mean().over(order_by="b") - with pytest.raises(InvalidOperationError): - nwd.col("a").rank().over(order_by="b") - - # NOTE: Previously multiple different errors, but they can be reduced to the same thing # Once we are scalar, only elementwise is allowed def test_invalid_agg_non_elementwise() -> None: From e91db3bfdee668bd3acf3ba90b66db7a587784e0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 29 May 2025 12:06:49 +0100 Subject: [PATCH 148/368] test: Add `immutable_test` and 
include class in hash --- narwhals/_plan/common.py | 3 +- tests/plan/immutable_test.py | 133 +++++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 tests/plan/immutable_test.py diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ff4926da4c..5e45df50bf 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -80,7 +80,8 @@ def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: def __hash__(self) -> int: slots: tuple[str, ...] = self.__slots__ - return hash(tuple(getattr(self, name) for name in slots)) + it = (getattr(self, name) for name in slots) + return hash((self.__class__, *it)) def __eq__(self, other: object) -> bool: if self is other: diff --git a/tests/plan/immutable_test.py b/tests/plan/immutable_test.py new file mode 100644 index 0000000000..ba15a50828 --- /dev/null +++ b/tests/plan/immutable_test.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from narwhals._plan.common import Immutable + + +class Empty(Immutable): ... + + +class EmptyDerived(Empty): + __slots__ = ("a",) + a: int + + +class OneSlot(Immutable): + __slots__ = ("a",) + a: int + + +class TwoSlot(Immutable): + __slots__ = ("a", "b") + a: int + b: str + + +@pytest.fixture +def empty() -> Empty: + return Empty() + + +@pytest.fixture +def empty_derived() -> EmptyDerived: + return EmptyDerived(a=1) + + +@pytest.fixture +def one() -> OneSlot: + return OneSlot(a=1) + + +@pytest.fixture +def two() -> TwoSlot: + return TwoSlot(a=1, b="two") + + +def test_immutable_really_immutable( + empty: Empty, empty_derived: EmptyDerived, one: OneSlot, two: TwoSlot +) -> None: + with pytest.raises(AttributeError, match=r"Empty.+immutable.+'a'"): + empty.a = 1 # type: ignore[assignment] + assert empty_derived.a == 1 + with pytest.raises(AttributeError, match=r"EmptyDerived.+immutable.+'a'"): + empty_derived.a = 2 # type: ignore[misc] + with pytest.raises(AttributeError, match=r"OneSlot.+immutable.+'a'"): + one.a = 2 # type: ignore[misc] + with pytest.raises(AttributeError, match=r"OneSlot.+immutable.+'b'"): + one.b = "two" # type: ignore[assignment] + with pytest.raises(AttributeError, match=r"TwoSlot.+immutable.+'a'"): + two.a += 2 # type: ignore[misc] + with pytest.raises(AttributeError, match=r"TwoSlot.+immutable.+'b'"): + two.b = 2 # type: ignore[assignment, misc] + + +def test_immutable_hash( + empty: Empty, empty_derived: EmptyDerived, one: OneSlot, two: TwoSlot +) -> None: + class EmptyAgain(Immutable): ... 
+ + assert empty == Empty() + assert empty_derived == EmptyDerived(a=1) + assert one == OneSlot(a=1) + assert two == TwoSlot(a=1, b="two") + + assert empty_derived != EmptyDerived(a=2) + assert one != OneSlot(a=2) + assert two != TwoSlot(a=2, b="two") + assert two != TwoSlot(a=1, b="three") + assert two != TwoSlot(a=2, b="three") + + assert empty != empty_derived + assert empty_derived != one + assert one != two + empty_again = EmptyAgain() + assert empty != empty_again + + mapping: dict[Any, Any] = {empty: empty} + mapping.update([(empty_derived, empty_derived), (one, one), (two, two)]) + assert len(mapping) == 4 + mapping[empty_again] = empty_again + assert len(mapping) == 5 + assert mapping[empty] is empty + assert mapping[EmptyDerived(a=1)] is empty_derived + assert mapping[OneSlot(a=1)] is one + assert mapping[OneSlot(a=1)] is not empty_derived + + assert hash(empty) != hash(empty_derived) + assert hash(empty_derived) != hash(one) + assert hash(one) != hash(two) + assert hash(empty_again) != hash(empty) + + +def test_immutable_invalid_constructor() -> None: + with pytest.raises(TypeError): + Empty(a=1) # pyright: ignore[reportCallIssue] + with pytest.raises(TypeError): + EmptyDerived(b="two") # type: ignore[call-arg] + with pytest.raises(TypeError): + EmptyDerived(a=1, b="two") # type: ignore[call-arg] + with pytest.raises(TypeError): + EmptyDerived() # type: ignore[call-arg] + with pytest.raises(TypeError): + OneSlot(b="two") # type: ignore[call-arg] + with pytest.raises(TypeError): + OneSlot(a=1, b="two") # type: ignore[call-arg] + with pytest.raises(TypeError): + OneSlot() # type: ignore[call-arg] + with pytest.raises(TypeError): + TwoSlot(b="two") # type: ignore[call-arg] + with pytest.raises(TypeError): + TwoSlot(a=1) # type: ignore[call-arg] + with pytest.raises(TypeError): + TwoSlot() # type: ignore[call-arg] + with pytest.raises(TypeError): + TwoSlot(a=1, b="two", c="huh?") # type: ignore[call-arg] + with pytest.raises(TypeError): + OneSlot(1) # type: ignore[misc] + with pytest.raises(TypeError): + OneSlot(1, 2, 3) # type: ignore[call-arg, misc] + with pytest.raises(TypeError): + OneSlot(1, a=1) # type: ignore[misc] From c6d777a2c6af33ca0cce24d672ca36b048095e69 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 29 May 2025 13:46:05 +0100 Subject: [PATCH 149/368] feat: Implement binary rhs `MultiOutputExpressionError` --- narwhals/_plan/expr.py | 2 +- narwhals/_plan/operators.py | 13 +++++++++++++ tests/plan/expr_parsing_test.py | 22 +++++++++++++++++++++- 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8c83714120..452d84b1d6 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -470,7 +470,7 @@ class All(ExprIR): """ def __repr__(self) -> str: - return "*" + return "all()" class RootSelector(SelectorIR): diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 2577fb2bfd..b7b2476d20 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -5,6 +5,7 @@ from narwhals._plan.common import Immutable from narwhals._plan.expr import BinarySelector +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from typing import Any, ClassVar @@ -52,6 +53,18 @@ def to_binary_expr( ) -> BinaryExpr[LeftT, Self, RightT]: from narwhals._plan.expr import BinaryExpr + if right.meta.has_multiple_outputs(): + lhs_op = f"{left!r} {self!r} " + rhs = repr(right) + indent = len(lhs_op) * " " + underline = 
len(rhs) * "^" + msg = ( + "Multi-output expressions are only supported on the " + f"left-hand side of a binary operation.\n" + f"{lhs_op}{rhs}\n{indent}{underline}" + ) + raise MultiOutputExpressionError(msg) + return BinaryExpr(left=left, op=self, right=right) def __call__(self, lhs: Any, rhs: Any) -> Any: diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index b30f2b0683..e3a0774c83 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -14,7 +14,7 @@ from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import FunctionExpr -from narwhals.exceptions import InvalidOperationError +from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError if TYPE_CHECKING: from narwhals._plan.common import IntoExpr, Seq @@ -164,3 +164,23 @@ def test_filtration_over() -> None: nwd.col("a").drop_nulls().over("b", order_by="i") with pytest.raises(InvalidOperationError, match=pattern): nwd.col("a").diff().drop_nulls().over("b", order_by="i") + + +def test_invalid_binary_expr() -> None: + pattern = re.escape("all() + cols(['b', 'c'])\n ^^^^^^^^^^^^^^^^") + with pytest.raises(MultiOutputExpressionError, match=pattern): + nwd.all() + nwd.col("b", "c") + pattern = re.escape( + "index_columns((1, 2, 3)) * index_columns((4, 5, 6)).max()\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" + ) + with pytest.raises(MultiOutputExpressionError, match=pattern): + nwd.nth(1, 2, 3) * nwd.nth(4, 5, 6).max() + pattern = re.escape( + "cols(['a', 'b', 'c']).abs().fill_null([lit(int: 0)]).round() * index_columns((9, 10)).cast(Int64).sort(asc)\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" + ) + with pytest.raises(MultiOutputExpressionError, match=pattern): + nwd.col("a", "b", "c").abs().fill_null(0).round(2) * nwd.nth(9, 10).cast( + nw.Int64() + ).sort() From e005cfa1f3d9661c45432681006530909a9b1be2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 29 May 2025 15:02:44 +0100 Subject: [PATCH 150/368] feat: Implement `exclude`, `Expr.exclude` --- narwhals/_plan/demo.py | 4 ++++ narwhals/_plan/dummy.py | 3 +++ narwhals/_plan/expr.py | 16 +++++++++++++++- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index abdd6aa9ef..d4a009b6b2 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -74,6 +74,10 @@ def all() -> DummyExpr: return All().to_narwhals() +def exclude(*names: str | t.Iterable[str]) -> DummyExpr: + return all().exclude(*names) + + def max(*columns: str) -> DummyExpr: return col(columns).max() diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 3c93b71b11..2bf0b39af8 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -78,6 +78,9 @@ def cast(self, dtype: DType | type[DType]) -> Self: dtype = dtype if isinstance(dtype, DType) else self.version.dtypes.Unknown() return self._from_ir(expr.Cast(expr=self._ir, dtype=dtype)) + def exclude(self, *names: str | t.Iterable[str]) -> Self: + return self._from_ir(expr.Exclude.from_names(self._ir, *names)) + def count(self) -> Self: return self._from_ir(agg.Count(expr=self._ir)) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 452d84b1d6..8cec683c4d 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -23,6 +23,7 @@ SelectorOperatorT, ) from narwhals.exceptions import InvalidOperationError +from narwhals.utils import flatten if t.TYPE_CHECKING: from 
typing_extensions import Self @@ -433,9 +434,22 @@ def __repr__(self) -> str: class Exclude(ExprIR): - __slots__ = ("names",) + __slots__ = ("expr", "names") + expr: ExprIR + """Default is `all()`.""" names: Seq[str] + """We're using a `frozenset` in main. + + Might want to switch to that later. + """ + + @staticmethod + def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: + return Exclude(expr=expr, names=tuple(flatten(names))) + + def __repr__(self) -> str: + return f"{self.expr!r}.exclude({list(self.names)!r})" class Nth(ExprIR): From eb7605eea55e18c4a784e6dfb7404ef3dac22fb3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 29 May 2025 16:58:42 +0100 Subject: [PATCH 151/368] feat: Implement binary multi `LengthChangingExprError` - Also added valid variants of the failures - `ShapeError` is todo --- narwhals/_plan/operators.py | 26 ++++++++++++-- tests/plan/expr_parsing_test.py | 60 +++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index b7b2476d20..c66189fc70 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -3,9 +3,9 @@ import operator from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable -from narwhals._plan.expr import BinarySelector -from narwhals.exceptions import MultiOutputExpressionError +from narwhals._plan.common import ExprIR, Immutable +from narwhals._plan.expr import BinarySelector, FunctionExpr +from narwhals.exceptions import LengthChangingExprError, MultiOutputExpressionError if TYPE_CHECKING: from typing import Any, ClassVar @@ -65,6 +65,19 @@ def to_binary_expr( ) raise MultiOutputExpressionError(msg) + if not any(_is_not_filtration(e) for e in (left, right)): + lhs, rhs = repr(left), repr(right) + op = f" {self!r} " + underline_left = len(lhs) * "^" + underline_right = len(rhs) * "^" + pad_middle = len(op) * " " + msg = ( + "Length-changing expressions can only be used in isolation, " + "or followed by an aggregation.\n" + f"{lhs}{op}{rhs}\n{underline_left}{pad_middle}{underline_right}" + ) + raise LengthChangingExprError(msg) + return BinaryExpr(left=left, op=self, right=right) def __call__(self, lhs: Any, rhs: Any) -> Any: @@ -72,6 +85,13 @@ def __call__(self, lhs: Any, rhs: Any) -> Any: return self.__class__._op(lhs, rhs) +def _is_not_filtration(ir: ExprIR) -> bool: + # NOTE: Strange naming/negation is to short-circuit on the `any` + if not ir.is_scalar and isinstance(ir, FunctionExpr): + return ir.options.is_elementwise() + return True + + class SelectorOperator(Operator): """Operators that can *also* be used in selectors. 
diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index e3a0774c83..f46a13aff9 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -13,8 +13,13 @@ ) from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr -from narwhals._plan.expr import FunctionExpr -from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError +from narwhals._plan.expr import BinaryExpr, FunctionExpr +from narwhals.exceptions import ( + InvalidOperationError, + LengthChangingExprError, + MultiOutputExpressionError, + ShapeError, +) if TYPE_CHECKING: from narwhals._plan.common import IntoExpr, Seq @@ -166,7 +171,7 @@ def test_filtration_over() -> None: nwd.col("a").diff().drop_nulls().over("b", order_by="i") -def test_invalid_binary_expr() -> None: +def test_invalid_binary_expr_multi() -> None: pattern = re.escape("all() + cols(['b', 'c'])\n ^^^^^^^^^^^^^^^^") with pytest.raises(MultiOutputExpressionError, match=pattern): nwd.all() + nwd.col("b", "c") @@ -184,3 +189,52 @@ def test_invalid_binary_expr() -> None: nwd.col("a", "b", "c").abs().fill_null(0).round(2) * nwd.nth(9, 10).cast( nw.Int64() ).sort() + + +def test_invalid_binary_expr_length_changing() -> None: + a = nwd.col("a") + b = nwd.col("b") + + with pytest.raises(LengthChangingExprError): + a.unique() + b.unique() + + with pytest.raises(LengthChangingExprError): + a.mode() * b.unique() + + with pytest.raises(LengthChangingExprError): + a.drop_nulls() - b.mode() + + with pytest.raises(LengthChangingExprError): + a.gather_every(2, 1) / b.drop_nulls() + + with pytest.raises(LengthChangingExprError): + a.map_batches(lambda x: x) / b.gather_every(1, 0) + + +def _is_expr_ir_binary_expr(expr: DummyExpr) -> bool: + return isinstance(expr._ir, BinaryExpr) + + +def test_binary_expr_length_changing_agg() -> None: + a = nwd.col("a") + b = nwd.col("b") + + assert _is_expr_ir_binary_expr(a.unique().first() + b.unique()) + assert _is_expr_ir_binary_expr(a.mode().last() * b.unique()) + assert _is_expr_ir_binary_expr(a.drop_nulls().min() - b.mode()) + assert _is_expr_ir_binary_expr(a.gather_every(2, 1) / b.drop_nulls().max()) + assert _is_expr_ir_binary_expr( + b.gather_every(1, 0) / a.map_batches(lambda x: x, returns_scalar=True) + ) + assert _is_expr_ir_binary_expr( + a.map_batches(lambda x: x, is_elementwise=True) * b.gather_every(1, 0) + ) + assert _is_expr_ir_binary_expr(b.unique() * a.map_batches(lambda x: x).first()) + + +# TODO @dangotbanned: Figure out how to fit this in +@pytest.mark.xfail(reason="Did not raise, haven't added a check to raise for it yet") +def test_invalid_binary_expr_shape() -> None: + """Cannot combine length-changing expressions with length-preserving ones or aggregations.""" + with pytest.raises(ShapeError): + nwd.col("a").unique() + nwd.col("b") From 5dfaa485c6af4cba0f717627a02aa87fc8cfdea3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 31 May 2025 18:48:28 +0100 Subject: [PATCH 152/368] feat: Implement binary filtration `ShapeError` Resolves (https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2118071252) --- narwhals/_plan/operators.py | 91 +++++++++++++++++++++++---------- tests/plan/expr_parsing_test.py | 21 +++++--- 2 files changed, 76 insertions(+), 36 deletions(-) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index c66189fc70..74d019cbdf 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -5,7 +5,11 @@ 
from narwhals._plan.common import ExprIR, Immutable from narwhals._plan.expr import BinarySelector, FunctionExpr -from narwhals.exceptions import LengthChangingExprError, MultiOutputExpressionError +from narwhals.exceptions import ( + LengthChangingExprError, + MultiOutputExpressionError, + ShapeError, +) if TYPE_CHECKING: from typing import Any, ClassVar @@ -54,29 +58,16 @@ def to_binary_expr( from narwhals._plan.expr import BinaryExpr if right.meta.has_multiple_outputs(): - lhs_op = f"{left!r} {self!r} " - rhs = repr(right) - indent = len(lhs_op) * " " - underline = len(rhs) * "^" - msg = ( - "Multi-output expressions are only supported on the " - f"left-hand side of a binary operation.\n" - f"{lhs_op}{rhs}\n{indent}{underline}" - ) - raise MultiOutputExpressionError(msg) - - if not any(_is_not_filtration(e) for e in (left, right)): - lhs, rhs = repr(left), repr(right) - op = f" {self!r} " - underline_left = len(lhs) * "^" - underline_right = len(rhs) * "^" - pad_middle = len(op) * " " - msg = ( - "Length-changing expressions can only be used in isolation, " - "or followed by an aggregation.\n" - f"{lhs}{op}{rhs}\n{underline_left}{pad_middle}{underline_right}" - ) - raise LengthChangingExprError(msg) + raise _bin_op_multi_output_error(left, self, right) + + if _is_filtration(left): + if _is_filtration(right): + raise _bin_op_length_changing_error(left, self, right) + if not right.is_scalar: + raise _bin_op_shape_error(left, self, right) + elif _is_filtration(right): + if not left.is_scalar: + raise _bin_op_shape_error(left, self, right) return BinaryExpr(left=left, op=self, right=right) @@ -85,11 +76,55 @@ def __call__(self, lhs: Any, rhs: Any) -> Any: return self.__class__._op(lhs, rhs) -def _is_not_filtration(ir: ExprIR) -> bool: - # NOTE: Strange naming/negation is to short-circuit on the `any` +# NOTE: Always underlining `right`, since the message refers to both types of exprs +# Assuming the most recent as the issue +def _bin_op_shape_error(left: ExprIR, op: Operator, right: ExprIR) -> ShapeError: + lhs_op = f"{left!r} {op!r} " + rhs = repr(right) + indent = len(lhs_op) * " " + underline = len(rhs) * "^" + msg = ( + f"Cannot combine length-changing expressions with length-preserving ones.\n" + f"{lhs_op}{rhs}\n{indent}{underline}" + ) + return ShapeError(msg) + + +def _bin_op_multi_output_error( + left: ExprIR, op: Operator, right: ExprIR +) -> MultiOutputExpressionError: + lhs_op = f"{left!r} {op!r} " + rhs = repr(right) + indent = len(lhs_op) * " " + underline = len(rhs) * "^" + msg = ( + "Multi-output expressions are only supported on the " + f"left-hand side of a binary operation.\n" + f"{lhs_op}{rhs}\n{indent}{underline}" + ) + return MultiOutputExpressionError(msg) + + +def _bin_op_length_changing_error( + left: ExprIR, op: Operator, right: ExprIR +) -> LengthChangingExprError: + lhs, rhs = repr(left), repr(right) + op_s = f" {op!r} " + underline_left = len(lhs) * "^" + underline_right = len(rhs) * "^" + pad_middle = len(op_s) * " " + msg = ( + "Length-changing expressions can only be used in isolation, " + "or followed by an aggregation.\n" + f"{lhs}{op_s}{rhs}\n{underline_left}{pad_middle}{underline_right}" + ) + return LengthChangingExprError(msg) + + +def _is_filtration(ir: ExprIR) -> bool: if not ir.is_scalar and isinstance(ir, FunctionExpr): - return ir.options.is_elementwise() - return True + return not ir.options.is_elementwise() + return False class SelectorOperator(Operator): diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 
f46a13aff9..645f91f56b 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -226,15 +226,20 @@ def test_binary_expr_length_changing_agg() -> None: assert _is_expr_ir_binary_expr( b.gather_every(1, 0) / a.map_batches(lambda x: x, returns_scalar=True) ) - assert _is_expr_ir_binary_expr( - a.map_batches(lambda x: x, is_elementwise=True) * b.gather_every(1, 0) - ) assert _is_expr_ir_binary_expr(b.unique() * a.map_batches(lambda x: x).first()) -# TODO @dangotbanned: Figure out how to fit this in -@pytest.mark.xfail(reason="Did not raise, haven't added a check to raise for it yet") def test_invalid_binary_expr_shape() -> None: - """Cannot combine length-changing expressions with length-preserving ones or aggregations.""" - with pytest.raises(ShapeError): - nwd.col("a").unique() + nwd.col("b") + pattern = re.compile( + re.escape("Cannot combine length-changing expressions with length-preserving"), + re.IGNORECASE, + ) + a = nwd.col("a") + b = nwd.col("b") + + with pytest.raises(ShapeError, match=pattern): + a.unique() + b + with pytest.raises(ShapeError, match=pattern): + a.map_batches(lambda x: x, is_elementwise=True) * b.gather_every(1, 0) + with pytest.raises(ShapeError, match=pattern): + a / b.drop_nulls() From da1f29345c92146a7899e546eb00f9933d5d52fd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 1 Jun 2025 13:20:32 +0100 Subject: [PATCH 153/368] refactor: Split out exceptions to a new module - Aiming for more consistent messages - Will be easier now they are all together --- narwhals/_plan/aggregation.py | 5 +- narwhals/_plan/exceptions.py | 159 +++++++++++++++++++++++++++++++++ narwhals/_plan/expr.py | 5 +- narwhals/_plan/expr_parsing.py | 33 +++---- narwhals/_plan/functions.py | 5 +- narwhals/_plan/operators.py | 68 +++----------- narwhals/_plan/window.py | 17 ++-- 7 files changed, 194 insertions(+), 98 deletions(-) create mode 100644 narwhals/_plan/exceptions.py diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 7047d96340..b4cf4fb44f 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any from narwhals._plan.common import ExprIR -from narwhals.exceptions import InvalidOperationError +from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: from typing import Iterator @@ -38,8 +38,7 @@ def iter_right(self) -> Iterator[ExprIR]: def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: if expr.is_scalar: - msg = "Can't apply aggregations to scalar-like expressions." 
- raise InvalidOperationError(msg) + raise agg_scalar_error(self, expr) super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue] diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py new file mode 100644 index 0000000000..6c448ad433 --- /dev/null +++ b/narwhals/_plan/exceptions.py @@ -0,0 +1,159 @@ +"""Exceptions and tools to format them.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals.exceptions import ( + ComputeError, + InvalidIntoExprError, + InvalidOperationError, + LengthChangingExprError, + MultiOutputExpressionError, + ShapeError, +) + +if TYPE_CHECKING: + from typing import Any, Iterable + + import pandas as pd + import polars as pl + + from narwhals._plan.aggregation import Agg + from narwhals._plan.common import ExprIR, Function, IntoExpr, Seq + from narwhals._plan.expr import FunctionExpr, WindowExpr + from narwhals._plan.operators import Operator + from narwhals._plan.options import SortOptions + + +# NOTE: Using verbose names to start with +# TODO @dangotbanned: Think about something better/more consistent once the new messages are finalized + + +# TODO @dangotbanned: Use arguments in error message +def agg_scalar_error(agg: Agg, scalar: ExprIR, /) -> InvalidOperationError: # noqa: ARG001 + msg = "Can't apply aggregations to scalar-like expressions." + return InvalidOperationError(msg) + + +def function_expr_invalid_operation_error( + function: Function, parent: ExprIR +) -> InvalidOperationError: + msg = f"Cannot use `{function!r}()` on aggregated expression `{parent!r}`." + return InvalidOperationError(msg) + + +# TODO @dangotbanned: Use arguments in error message +def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001 + msg = "bins must increase monotonically" + return ComputeError(msg) + + +# NOTE: Always underlining `right`, since the message refers to both types of exprs +# Assuming the most recent as the issue +def binary_expr_shape_error(left: ExprIR, op: Operator, right: ExprIR) -> ShapeError: + lhs_op = f"{left!r} {op!r} " + rhs = repr(right) + indent = len(lhs_op) * " " + underline = len(rhs) * "^" + msg = ( + f"Cannot combine length-changing expressions with length-preserving ones.\n" + f"{lhs_op}{rhs}\n{indent}{underline}" + ) + return ShapeError(msg) + + +# TODO @dangotbanned: Share the right underline code w/ `binary_expr_shape_error` +def binary_expr_multi_output_error( + left: ExprIR, op: Operator, right: ExprIR +) -> MultiOutputExpressionError: + lhs_op = f"{left!r} {op!r} " + rhs = repr(right) + indent = len(lhs_op) * " " + underline = len(rhs) * "^" + msg = ( + "Multi-output expressions are only supported on the " + f"left-hand side of a binary operation.\n" + f"{lhs_op}{rhs}\n{indent}{underline}" + ) + return MultiOutputExpressionError(msg) + + +def binary_expr_length_changing_error( + left: ExprIR, op: Operator, right: ExprIR +) -> LengthChangingExprError: + lhs, rhs = repr(left), repr(right) + op_s = f" {op!r} " + underline_left = len(lhs) * "^" + underline_right = len(rhs) * "^" + pad_middle = len(op_s) * " " + msg = ( + "Length-changing expressions can only be used in isolation, " + "or followed by an aggregation.\n" + f"{lhs}{op_s}{rhs}\n{underline_left}{pad_middle}{underline_right}" + ) + return LengthChangingExprError(msg) + + +# TODO @dangotbanned: Use arguments in error message +def over_nested_error( + expr: WindowExpr, # noqa: ARG001 + partition_by: Seq[ExprIR], # noqa: ARG001 + order_by: tuple[Seq[ExprIR], SortOptions] | None, # noqa: ARG001 
+) -> InvalidOperationError: + msg = "Cannot nest `over` statements." + return InvalidOperationError(msg) + + +# TODO @dangotbanned: Use arguments in error message +def over_elementwise_error( + expr: FunctionExpr[Function], + partition_by: Seq[ExprIR], # noqa: ARG001 + order_by: tuple[Seq[ExprIR], SortOptions] | None, # noqa: ARG001 +) -> InvalidOperationError: + msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}" + return InvalidOperationError(msg) + + +# TODO @dangotbanned: Use arguments in error message +def over_row_separable_error( + expr: FunctionExpr[Function], + partition_by: Seq[ExprIR], # noqa: ARG001 + order_by: tuple[Seq[ExprIR], SortOptions] | None, # noqa: ARG001 +) -> InvalidOperationError: + msg = f"Cannot use `over` on expressions which change length.\n{expr!r}" + return InvalidOperationError(msg) + + +def invalid_into_expr_error( + first_input: Iterable[IntoExpr], + more_inputs: tuple[IntoExpr, ...], + named_inputs: dict[str, IntoExpr], + /, +) -> InvalidIntoExprError: + msg = ( + f"Passing both iterable and positional inputs is not supported.\n" + f"Hint:\nInstead try collecting all arguments into a {type(first_input).__name__!r}\n" + f"{first_input!r}\n{more_inputs!r}\n{named_inputs!r}" + ) + return InvalidIntoExprError(msg) + + +def is_iterable_pandas_error(obj: pd.DataFrame | pd.Series[Any], /) -> TypeError: + msg = ( + f"Expected Narwhals class or scalar, got: {type(obj)}. " + "Perhaps you forgot a `nw.from_native` somewhere?" + ) + return TypeError(msg) + + +def is_iterable_polars_error( + obj: pl.Series | pl.Expr | pl.DataFrame | pl.LazyFrame, / +) -> TypeError: + msg = ( + f"Expected Narwhals class or scalar, got: {type(obj)}.\n\n" + "Hint: Perhaps you\n" + "- forgot a `nw.from_native` somewhere?\n" + "- used `pl.col` instead of `nw.col`?" + ) + return TypeError(msg) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8cec683c4d..f68f3d2849 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -8,6 +8,7 @@ from narwhals._plan.aggregation import Agg, OrderableAgg from narwhals._plan.common import ExprIR, SelectorIR, _field_str, is_non_nested_literal +from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( ExprT, @@ -22,7 +23,6 @@ RollingT, SelectorOperatorT, ) -from narwhals.exceptions import InvalidOperationError from narwhals.utils import flatten if t.TYPE_CHECKING: @@ -312,8 +312,7 @@ def __init__( ) -> None: parent = input[0] if parent.is_scalar and not options.is_elementwise(): - msg = f"Cannot use `{function!r}()` on aggregated expression `{parent!r}`." 
- raise InvalidOperationError(msg) + raise function_expr_invalid_operation_error(function, parent) super().__init__(**dict(input=input, function=function, options=options, **kwds)) diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index a9944fceac..8af7db1f55 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -4,12 +4,17 @@ from typing import TYPE_CHECKING, Iterable, Sequence, TypeVar from narwhals._plan.common import is_expr, is_iterable_reject +from narwhals._plan.exceptions import ( + invalid_into_expr_error, + is_iterable_pandas_error, + is_iterable_polars_error, +) from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series -from narwhals.exceptions import InvalidIntoExprError if TYPE_CHECKING: from typing import Any, Iterator + import polars as pl from typing_extensions import TypeAlias, TypeIs from narwhals._plan.common import ExprIR, IntoExpr, Seq @@ -108,7 +113,7 @@ def _parse_into_iter_expr_ir( # Otherwise, `str | bytes` always passes through typing if _is_iterable(first_input) and not is_iterable_reject(first_input): if more_inputs: - raise _invalid_into_expr_error(first_input, more_inputs, named_inputs) + raise invalid_into_expr_error(first_input, more_inputs, named_inputs) else: yield from _parse_positional_inputs(first_input) else: @@ -136,16 +141,9 @@ def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR def _is_iterable(obj: Iterable[T] | Any) -> TypeIs[Iterable[T]]: if is_pandas_dataframe(obj) or is_pandas_series(obj): - msg = f"Expected Narwhals class or scalar, got: {type(obj)}. Perhaps you forgot a `nw.from_native` somewhere?" - raise TypeError(msg) + raise is_iterable_pandas_error(obj) if _is_polars(obj): - msg = ( - f"Expected Narwhals class or scalar, got: {type(obj)}.\n\n" - "Hint: Perhaps you\n" - "- forgot a `nw.from_native` somewhere?\n" - "- used `pl.col` instead of `nw.col`?" 
- ) - raise TypeError(msg) + raise is_iterable_polars_error(obj) return isinstance(obj, Iterable) @@ -153,18 +151,7 @@ def _is_empty_sequence(obj: Any) -> bool: return isinstance(obj, Sequence) and not obj -def _is_polars(obj: Any) -> bool: +def _is_polars(obj: Any) -> TypeIs[pl.Series | pl.Expr | pl.DataFrame | pl.LazyFrame]: return (pl := get_polars()) is not None and isinstance( obj, (pl.Series, pl.Expr, pl.DataFrame, pl.LazyFrame) ) - - -def _invalid_into_expr_error( - first_input: Any, more_inputs: Any, named_inputs: Any -) -> InvalidIntoExprError: - msg = ( - f"Passing both iterable and positional inputs is not supported.\n" - f"Hint:\nInstead try collecting all arguments into a {type(first_input).__name__!r}\n" - f"{first_input!r}\n{more_inputs!r}\n{named_inputs!r}" - ) - return InvalidIntoExprError(msg) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 5c3c55bfd6..1a74a58285 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -5,8 +5,8 @@ from typing import TYPE_CHECKING from narwhals._plan.common import Function +from narwhals._plan.exceptions import hist_bins_monotonic_error from narwhals._plan.options import FunctionFlags, FunctionOptions -from narwhals.exceptions import ComputeError if TYPE_CHECKING: from typing import Any @@ -54,8 +54,7 @@ class HistBins(Hist): def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None: for i in range(1, len(bins)): if bins[i - 1] >= bins[i]: - msg = "bins must increase monotonically" - raise ComputeError(msg) + raise hist_bins_monotonic_error(bins) object.__setattr__(self, "bins", bins) object.__setattr__(self, "include_breakpoint", include_breakpoint) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 74d019cbdf..a702a60d85 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -3,19 +3,20 @@ import operator from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIR, Immutable -from narwhals._plan.expr import BinarySelector, FunctionExpr -from narwhals.exceptions import ( - LengthChangingExprError, - MultiOutputExpressionError, - ShapeError, +from narwhals._plan.common import Immutable +from narwhals._plan.exceptions import ( + binary_expr_length_changing_error, + binary_expr_multi_output_error, + binary_expr_shape_error, ) +from narwhals._plan.expr import BinarySelector, FunctionExpr if TYPE_CHECKING: from typing import Any, ClassVar from typing_extensions import Self + from narwhals._plan.common import ExprIR from narwhals._plan.expr import BinaryExpr, BinarySelector from narwhals._plan.typing import ( LeftSelectorT, @@ -58,17 +59,15 @@ def to_binary_expr( from narwhals._plan.expr import BinaryExpr if right.meta.has_multiple_outputs(): - raise _bin_op_multi_output_error(left, self, right) - + raise binary_expr_multi_output_error(left, self, right) if _is_filtration(left): if _is_filtration(right): - raise _bin_op_length_changing_error(left, self, right) + raise binary_expr_length_changing_error(left, self, right) if not right.is_scalar: - raise _bin_op_shape_error(left, self, right) + raise binary_expr_shape_error(left, self, right) elif _is_filtration(right): if not left.is_scalar: - raise _bin_op_shape_error(left, self, right) - + raise binary_expr_shape_error(left, self, right) return BinaryExpr(left=left, op=self, right=right) def __call__(self, lhs: Any, rhs: Any) -> Any: @@ -76,51 +75,6 @@ def __call__(self, lhs: Any, rhs: Any) -> Any: return self.__class__._op(lhs, rhs) -# NOTE: Always 
underlining `right`, since the message refers to both types of exprs -# Assuming the most recent as the issue -def _bin_op_shape_error(left: ExprIR, op: Operator, right: ExprIR) -> ShapeError: - lhs_op = f"{left!r} {op!r} " - rhs = repr(right) - indent = len(lhs_op) * " " - underline = len(rhs) * "^" - msg = ( - f"Cannot combine length-changing expressions with length-preserving ones.\n" - f"{lhs_op}{rhs}\n{indent}{underline}" - ) - return ShapeError(msg) - - -def _bin_op_multi_output_error( - left: ExprIR, op: Operator, right: ExprIR -) -> MultiOutputExpressionError: - lhs_op = f"{left!r} {op!r} " - rhs = repr(right) - indent = len(lhs_op) * " " - underline = len(rhs) * "^" - msg = ( - "Multi-output expressions are only supported on the " - f"left-hand side of a binary operation.\n" - f"{lhs_op}{rhs}\n{indent}{underline}" - ) - return MultiOutputExpressionError(msg) - - -def _bin_op_length_changing_error( - left: ExprIR, op: Operator, right: ExprIR -) -> LengthChangingExprError: - lhs, rhs = repr(left), repr(right) - op_s = f" {op!r} " - underline_left = len(lhs) * "^" - underline_right = len(rhs) * "^" - pad_middle = len(op_s) * " " - msg = ( - "Length-changing expressions can only be used in isolation, " - "or followed by an aggregation.\n" - f"{lhs}{op_s}{rhs}\n{underline_left}{pad_middle}{underline_right}" - ) - return LengthChangingExprError(msg) - - def _is_filtration(ir: ExprIR) -> bool: if not ir.is_scalar and isinstance(ir, FunctionExpr): return not ir.options.is_elementwise() diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index f2dc5214a5..b742c0e3e3 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -3,7 +3,11 @@ from typing import TYPE_CHECKING from narwhals._plan.common import Immutable -from narwhals.exceptions import InvalidOperationError +from narwhals._plan.exceptions import ( + over_elementwise_error, + over_nested_error, + over_row_separable_error, +) if TYPE_CHECKING: from narwhals._plan.common import ExprIR, Seq @@ -34,17 +38,12 @@ def to_window_expr( from narwhals._plan.expr import FunctionExpr, WindowExpr if isinstance(expr, WindowExpr): - msg = "Cannot nest `over` statements." 
- raise InvalidOperationError(msg) - + raise over_nested_error(expr, partition_by, order_by) if isinstance(expr, FunctionExpr): if expr.options.is_elementwise(): - msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}" - raise InvalidOperationError(msg) + raise over_elementwise_error(expr, partition_by, order_by) if expr.options.is_row_separable(): - msg = f"Cannot use `over` on expressions which change length.\n{expr!r}" - raise InvalidOperationError(msg) - + raise over_row_separable_error(expr, partition_by, order_by) return WindowExpr( expr=expr, partition_by=partition_by, order_by=order_by, options=self ) From ff1d61b5c4526b6c92108d1e2c9c748e7555250c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 1 Jun 2025 16:24:52 +0100 Subject: [PATCH 154/368] feat: Add `Expr.is_in` --- narwhals/_plan/boolean.py | 48 ++++++++++++++++++++++---- narwhals/_plan/dummy.py | 23 +++++++++---- narwhals/_plan/literal.py | 2 +- tests/plan/expr_parsing_test.py | 60 +++++++++++++++++++++++++++++++-- 4 files changed, 117 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 48c2f7db76..e8ccdbc54b 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -6,16 +6,25 @@ from narwhals._plan.common import Function from narwhals._plan.options import FunctionFlags, FunctionOptions +from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: + from narwhals._plan.common import ExprIR, Seq # noqa: F401 + from narwhals._plan.dummy import DummySeries + from narwhals._plan.expr import Literal # noqa: F401 from narwhals.typing import ClosedInterval +OtherT = TypeVar("OtherT") +ExprT = TypeVar("ExprT", bound="ExprIR", default="ExprIR") + class BooleanFunction(Function): def __repr__(self) -> str: tp = type(self) - if tp is BooleanFunction: + if tp in {BooleanFunction, IsIn}: return tp.__name__ + if isinstance(self, IsIn): + return "is_in" m: dict[type[BooleanFunction], str] = { All: "all", Any: "any", @@ -29,7 +38,6 @@ def __repr__(self) -> str: IsFirstDistinct: "is_first_distinct", IsLastDistinct: "is_last_distinct", IsUnique: "is_unique", - IsIn: "is_in", Not: "not", } return m[tp] @@ -63,6 +71,7 @@ def function_options(self) -> FunctionOptions: ) +# NOTE: `lower_bound`, `upper_bound` aren't spec'd in the function enum. class IsBetween(BooleanFunction): """`lower_bound`, `upper_bound` aren't spec'd in the function enum. @@ -98,17 +107,44 @@ def function_options(self) -> FunctionOptions: return FunctionOptions.length_preserving() -class IsIn(BooleanFunction): - """``other` isn't spec'd in the function enum. +class IsIn(BooleanFunction, t.Generic[OtherT]): + __slots__ = ("other",) - See `IsBetween` comment. 
- """ + other: OtherT @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() +class IsInSeq(IsIn["Seq[t.Any]"]): + @classmethod + def from_iterable(cls, other: t.Iterable[t.Any], /) -> IsInSeq: + if not isinstance(other, (str, bytes)): + return IsInSeq(other=tuple(other)) + msg = f"`is_in` doesn't accept `str | bytes` as iterables, got: {type(other).__name__}" + raise TypeError(msg) + + +# NOTE: Shouldn't be allowed for lazy backends (maybe besides `polars`) +class IsInSeries(IsIn["Literal[DummySeries]"]): + @classmethod + def from_series(cls, other: DummySeries, /) -> IsInSeries: + from narwhals._plan.literal import SeriesLiteral + + return IsInSeries(other=SeriesLiteral(value=other).to_literal()) + + +# NOTE: Placeholder for allowing `Expr` iff it passes `.meta.is_column()` +class IsInExpr(IsIn[ExprT], t.Generic[ExprT]): + def __init__(self, *, other: ExprT) -> None: + msg = ( + "`is_in` doesn't accept expressions as an argument, as opposed to Polars. " + "You should provide an iterable instead." + ) + raise NotImplementedError(msg) + + class IsLastDistinct(BooleanFunction): @property def function_options(self) -> FunctionOptions: diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 2bf0b39af8..d0ce50bf10 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -13,6 +13,7 @@ functions as F, # noqa: N812 operators as ops, ) +from narwhals._plan.common import is_expr, is_series from narwhals._plan.options import ( EWMOptions, RankOptions, @@ -427,13 +428,18 @@ def is_between( boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) ) - def is_in(self, other: t.Any) -> Self: - msg = ( - "There's some special handling of iterables that I'm not sure on:\n" - "https://github.com/narwhals-dev/narwhals/blob/8975189cb2459f129017cf833075b28ec3d4dfa8/narwhals/expr.py#L1176-L1184" - ) - raise NotImplementedError(msg) - return self._from_ir(boolean.IsIn().to_function_expr(self._ir)) + def is_in(self, other: t.Iterable[t.Any]) -> Self: + node: boolean.IsIn[t.Any] + if is_series(other): + node = boolean.IsInSeries.from_series(other) + elif isinstance(other, t.Iterable): + node = boolean.IsInSeq.from_iterable(other) + elif is_expr(other): + node = boolean.IsInExpr(other=other._ir) + else: + msg = f"`is_in` only supports iterables, got: {type(other).__name__}" + raise TypeError(msg) + return self._from_ir(node.to_function_expr(self._ir)) def __eq__(self, other: IntoExpr) -> Self: # type: ignore[override] op = ops.Eq() @@ -693,6 +699,9 @@ def from_native(cls, native: NativeSeries, /) -> Self: def to_native(self) -> NativeSeries: return self._compliant._native + def __iter__(self) -> t.Iterator[t.Any]: + yield from self.to_native() + class DummySeriesV1(DummySeries): _version: t.ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 16ddfbf6af..b878a02fce 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -28,7 +28,7 @@ def name(self) -> str: def is_scalar(self) -> bool: return False - def to_literal(self) -> Literal: + def to_literal(self) -> Literal[LiteralT]: from narwhals._plan.expr import Literal return Literal(value=self) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 645f91f56b..4fa567386e 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -1,7 +1,8 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Callable, Iterable +from collections import deque 
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import pytest @@ -12,7 +13,7 @@ functions as F, # noqa: N812 ) from narwhals._plan.common import ExprIR, Function -from narwhals._plan.dummy import DummyExpr +from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import BinaryExpr, FunctionExpr from narwhals.exceptions import ( InvalidOperationError, @@ -22,9 +23,16 @@ ) if TYPE_CHECKING: + from typing import ContextManager + + from typing_extensions import TypeAlias + from narwhals._plan.common import IntoExpr, Seq +IntoIterable: TypeAlias = Callable[[Sequence[Any]], Iterable[Any]] + + @pytest.mark.parametrize( ("exprs", "named_exprs"), [ @@ -243,3 +251,51 @@ def test_invalid_binary_expr_shape() -> None: a.map_batches(lambda x: x, is_elementwise=True) * b.gather_every(1, 0) with pytest.raises(ShapeError, match=pattern): a / b.drop_nulls() + + +@pytest.mark.parametrize("into_iter", [list, tuple, deque, iter, dict.fromkeys, set]) +def test_is_in_seq(into_iter: IntoIterable) -> None: + expected = 1, 2, 3 + other = into_iter(list(expected)) + expr = nwd.col("a").is_in(other) + ir = expr._ir + assert isinstance(ir, FunctionExpr) + assert isinstance(ir.function, boolean.IsInSeq) + assert ir.function.other == expected + + +def test_is_in_series() -> None: + pytest.importorskip("polars") + import polars as pl + + native = pl.Series([1, 2, 3]) + other = DummySeries.from_native(native) + expr = nwd.col("a").is_in(other) + ir = expr._ir + assert isinstance(ir, FunctionExpr) + assert isinstance(ir.function, boolean.IsInSeries) + assert ir.function.other.unwrap().to_native() is native + + +@pytest.mark.parametrize( + ("other", "context"), + [ + ("words", pytest.raises(TypeError, match=r"str \| bytes.+str")), + (b"words", pytest.raises(TypeError, match=r"str \| bytes.+bytes")), + ( + nwd.col("b"), + pytest.raises( + NotImplementedError, match=re.compile(r"iterable instead", re.IGNORECASE) + ), + ), + ( + 999, + pytest.raises( + TypeError, match=re.compile(r"only.+iterable.+int", re.IGNORECASE) + ), + ), + ], +) +def test_invalid_is_in(other: Any, context: ContextManager[Any]) -> None: + with context: + nwd.col("a").is_in(other) From 4503b52c0a0465681defd8b915817181897260df Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 1 Jun 2025 17:16:31 +0100 Subject: [PATCH 155/368] feat: Add `InvertSelector` --- narwhals/_plan/dummy.py | 8 ++++++-- narwhals/_plan/expr.py | 11 +++++++++++ narwhals/_plan/typing.py | 11 ++++++++++- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index d0ce50bf10..8516fb0c2b 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -580,6 +580,10 @@ class DummySelector(DummyExpr): >>> (ncs.matches("[^z]a") & ncs.string()) | ncs.datetime("us", None) Narwhals DummySelector (main): [([(ncs.matches(pattern='[^z]a')) & (ncs.string())]) | (ncs.datetime(time_unit=['us'], time_zone=[None]))] + >>> + >>> ~(ncs.boolean() | ncs.matches(r"is_.*")) + Narwhals DummySelector (main): + ~[(ncs.boolean()) | (ncs.matches(pattern='is_.*'))] """ _ir: expr.SelectorIR @@ -620,8 +624,8 @@ def __xor__(self, other: t.Any) -> Self | t.Any: return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() ^ other - def __invert__(self) -> Never: - raise NotImplementedError + def __invert__(self) -> Self: + return self._from_ir(expr.InvertSelector(selector=self._ir)) def __add__(self, other: t.Any) -> DummyExpr: # 
type: ignore[override] if isinstance(other, type(self)): diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index f68f3d2849..7aaf65bbd8 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -22,6 +22,7 @@ RightT, RollingT, SelectorOperatorT, + SelectorT, ) from narwhals.utils import flatten @@ -510,6 +511,16 @@ class BinarySelector( """ +class InvertSelector(SelectorIR, t.Generic[SelectorT]): + __slots__ = ("selector",) + + selector: SelectorT + """`(Root|Binary)Selector`.""" + + def __repr__(self) -> str: + return f"~{self.selector!r}" + + class Ternary(ExprIR): """When-Then-Otherwise.""" diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 292d76bf82..1297974bea 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -15,7 +15,15 @@ from narwhals._plan.functions import RollingWindow from narwhals.typing import NonNestedLiteral -__all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT", "SelectorOperatorT"] +__all__ = [ + "FunctionT", + "LeftT", + "OperatorT", + "RightT", + "RollingT", + "SelectorOperatorT", + "SelectorT", +] FunctionT = TypeVar("FunctionT", bound="Function") @@ -25,6 +33,7 @@ RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" +SelectorT = TypeVar("SelectorT", bound="SelectorIR", default="SelectorIR") LeftSelectorT = TypeVar("LeftSelectorT", bound="SelectorIR", default="SelectorIR") RightSelectorT = TypeVar("RightSelectorT", bound="SelectorIR", default="SelectorIR") SelectorOperatorT = TypeVar( From a461cb5a2913d154d55d3d6e6e99ef43cf3eb9f8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 1 Jun 2025 17:35:38 +0100 Subject: [PATCH 156/368] feat(typing): Add selectors `@overload`s --- narwhals/_plan/dummy.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 8516fb0c2b..9c2279265f 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -511,6 +511,11 @@ def __or__(self, other: IntoExpr) -> Self: rhs = parse.parse_into_expr_ir(other, str_as_lit=True) return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __xor__(self, other: IntoExpr) -> Self: + op = ops.ExclusiveOr() + rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __invert__(self) -> Self: return self._from_ir(boolean.Not().to_function_expr(self._ir)) @@ -600,25 +605,41 @@ def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] def _to_expr(self) -> DummyExpr: return self._ir.to_narwhals(self.version) - def __or__(self, other: t.Any) -> Self | t.Any: + @t.overload # type: ignore[override] + def __or__(self, other: Self) -> Self: ... + @t.overload + def __or__(self, other: IntoExpr) -> DummyExpr: ... + def __or__(self, other: IntoExpr) -> Self | DummyExpr: if isinstance(other, type(self)): op = ops.Or() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() | other - def __and__(self, other: t.Any) -> Self | t.Any: + @t.overload # type: ignore[override] + def __and__(self, other: Self) -> Self: ... + @t.overload + def __and__(self, other: IntoExpr) -> DummyExpr: ... 
+ def __and__(self, other: IntoExpr) -> Self | DummyExpr: if isinstance(other, type(self)): op = ops.And() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() & other - def __sub__(self, other: t.Any) -> Self | t.Any: + @t.overload # type: ignore[override] + def __sub__(self, other: Self) -> Self: ... + @t.overload + def __sub__(self, other: IntoExpr) -> DummyExpr: ... + def __sub__(self, other: IntoExpr) -> Self | DummyExpr: if isinstance(other, type(self)): op = ops.Sub() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() - other - def __xor__(self, other: t.Any) -> Self | t.Any: + @t.overload # type: ignore[override] + def __xor__(self, other: Self) -> Self: ... + @t.overload + def __xor__(self, other: IntoExpr) -> DummyExpr: ... + def __xor__(self, other: IntoExpr) -> Self | DummyExpr: if isinstance(other, type(self)): op = ops.ExclusiveOr() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) From 6baa808ae9e704674ee8398de05502840509f1bb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 1 Jun 2025 18:17:28 +0100 Subject: [PATCH 157/368] feat: Implement rhs selector ops, add `selectors.by_name` `polars` implements them this way and we can too! https://github.com/pola-rs/polars/blob/a3d6a3a7863b4d42e720a05df69ff6b6f5fc551f/py-polars/polars/selectors.py#L420-L423 --- narwhals/_plan/common.py | 8 ++++++ narwhals/_plan/dummy.py | 54 +++++++++++++++++++++++++++++-------- narwhals/_plan/selectors.py | 11 ++++++++ 3 files changed, 62 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 5e45df50bf..65efa1eba1 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -300,6 +300,14 @@ def is_expr(obj: Any) -> TypeIs[DummyExpr]: return isinstance(obj, DummyExpr) +def is_column(obj: Any) -> TypeIs[DummyExpr]: + """Indicate if the given object is a basic/unaliased column. + + https://github.com/pola-rs/polars/blob/a3d6a3a7863b4d42e720a05df69ff6b6f5fc551f/py-polars/polars/_utils/various.py#L164-L168. 
+ """ + return is_expr(obj) and obj.meta.is_column() + + def is_series(obj: Any) -> TypeIs[DummySeries]: from narwhals._plan.dummy import DummySeries diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 9c2279265f..f5959e7c5b 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -13,7 +13,7 @@ functions as F, # noqa: N812 operators as ops, ) -from narwhals._plan.common import is_expr, is_series +from narwhals._plan.common import is_column, is_expr, is_series from narwhals._plan.options import ( EWMOptions, RankOptions, @@ -22,6 +22,7 @@ SortMultipleOptions, SortOptions, ) +from narwhals._plan.selectors import by_name from narwhals._plan.window import Over from narwhals.dtypes import DType from narwhals.exceptions import ComputeError @@ -506,16 +507,25 @@ def __and__(self, other: IntoExpr) -> Self: rhs = parse.parse_into_expr_ir(other, str_as_lit=True) return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __rand__(self, other: IntoExpr) -> Self: + return (self & other).alias("literal") + def __or__(self, other: IntoExpr) -> Self: op = ops.Or() rhs = parse.parse_into_expr_ir(other, str_as_lit=True) return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __ror__(self, other: IntoExpr) -> Self: + return (self | other).alias("literal") + def __xor__(self, other: IntoExpr) -> Self: op = ops.ExclusiveOr() rhs = parse.parse_into_expr_ir(other, str_as_lit=True) return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __rxor__(self, other: IntoExpr) -> Self: + return (self ^ other).alias("literal") + def __invert__(self) -> Self: return self._from_ir(boolean.Not().to_function_expr(self._ir)) @@ -620,6 +630,8 @@ def __and__(self, other: Self) -> Self: ... @t.overload def __and__(self, other: IntoExpr) -> DummyExpr: ... def __and__(self, other: IntoExpr) -> Self | DummyExpr: + if is_column(other) and (name := other.meta.output_name()): + other = by_name(name) if isinstance(other, type(self)): op = ops.And() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) @@ -654,20 +666,40 @@ def __add__(self, other: t.Any) -> DummyExpr: # type: ignore[override] raise TypeError(msg) return self._to_expr() + other # type: ignore[no-any-return] - def __rsub__(self, other: t.Any) -> Never: - raise NotImplementedError + def __radd__(self, other: t.Any) -> Never: + msg = "unsupported operand type(s) for op: ('Expr' + 'Selector')" + raise TypeError(msg) - def __rand__(self, other: t.Any) -> Never: - raise NotImplementedError + def __rsub__(self, other: t.Any) -> Never: + msg = "unsupported operand type(s) for op: ('Expr' - 'Selector')" + raise TypeError(msg) - def __ror__(self, other: t.Any) -> Never: - raise NotImplementedError + @t.overload # type: ignore[override] + def __rand__(self, other: Self) -> Self: ... + @t.overload + def __rand__(self, other: IntoExpr) -> DummyExpr: ... + def __rand__(self, other: IntoExpr) -> Self | DummyExpr: + if is_column(other) and (name := other.meta.output_name()): + return by_name(name) & self + return self._to_expr().__rand__(other) - def __rxor__(self, other: t.Any) -> Never: - raise NotImplementedError + @t.overload # type: ignore[override] + def __ror__(self, other: Self) -> Self: ... + @t.overload + def __ror__(self, other: IntoExpr) -> DummyExpr: ... 
+ def __ror__(self, other: IntoExpr) -> Self | DummyExpr: + if is_column(other) and (name := other.meta.output_name()): + return by_name(name) | self + return self._to_expr().__ror__(other) - def __radd__(self, other: t.Any) -> Never: - raise NotImplementedError + @t.overload # type: ignore[override] + def __rxor__(self, other: Self) -> Self: ... + @t.overload + def __rxor__(self, other: IntoExpr) -> DummyExpr: ... + def __rxor__(self, other: IntoExpr) -> Self | DummyExpr: + if is_column(other) and (name := other.meta.output_name()): + return by_name(name) ^ self + return self._to_expr().__rxor__(other) class DummyExprV1(DummyExpr): diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 56f8ec3f4c..4360902e56 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -99,6 +99,13 @@ class Matches(Selector): def from_string(pattern: str, /) -> Matches: return Matches(pattern=re.compile(pattern)) + @staticmethod + def from_names(*names: str | Iterable[str]) -> Matches: + """Implements `cs.by_name` to support `__r__` with column selections.""" + it: Iterator[str] = _flatten_hash_safe(names) + pattern = f"^({'|'.join(re.escape(name) for name in it)})$" + return Matches.from_string(pattern) + def __repr__(self) -> str: return f"ncs.matches(pattern={self.pattern.pattern!r})" @@ -123,6 +130,10 @@ def by_dtype( return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals() +def by_name(*names: str | Iterable[str]) -> DummySelector: + return Matches.from_names(*names).to_selector().to_narwhals() + + def boolean() -> DummySelector: return Boolean().to_selector().to_narwhals() From 5eb20af2f2dbb22d9c98fdfcf19692dde4b2aee4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 1 Jun 2025 19:32:09 +0100 Subject: [PATCH 158/368] feat: Ban multi-output expressions from `Alias` --- narwhals/_plan/exceptions.py | 6 ++++++ narwhals/_plan/expr.py | 11 ++++++++++- tests/plan/expr_parsing_test.py | 22 ++++++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 6c448ad433..0f10bd410b 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -6,6 +6,7 @@ from narwhals.exceptions import ( ComputeError, + DuplicateError, InvalidIntoExprError, InvalidOperationError, LengthChangingExprError, @@ -157,3 +158,8 @@ def is_iterable_polars_error( "- used `pl.col` instead of `nw.col`?" 
) return TypeError(msg) + + +def alias_duplicate_error(expr: ExprIR, name: str) -> DuplicateError: + msg = f"Cannot apply alias {name!r} to multi-output expression:\n{expr!r}" + return DuplicateError(msg) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 7aaf65bbd8..1aab360be6 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -8,7 +8,10 @@ from narwhals._plan.aggregation import Agg, OrderableAgg from narwhals._plan.common import ExprIR, SelectorIR, _field_str, is_non_nested_literal -from narwhals._plan.exceptions import function_expr_invalid_operation_error +from narwhals._plan.exceptions import ( + alias_duplicate_error, + function_expr_invalid_operation_error, +) from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( ExprT, @@ -88,6 +91,12 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield self yield from self.expr.iter_right() + def __init__(self, *, expr: ExprIR, name: str) -> None: + if expr.meta.has_multiple_outputs(): + raise alias_duplicate_error(expr, name) + kwds = {"expr": expr, "name": name} + super().__init__(**kwds) + class Column(ExprIR): __slots__ = ("name",) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 4fa567386e..b7e71ad86a 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -11,11 +11,13 @@ from narwhals._plan import ( boolean, functions as F, # noqa: N812 + selectors as ndcs, ) from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import BinaryExpr, FunctionExpr from narwhals.exceptions import ( + DuplicateError, InvalidOperationError, LengthChangingExprError, MultiOutputExpressionError, @@ -299,3 +301,23 @@ def test_is_in_series() -> None: def test_invalid_is_in(other: Any, context: ContextManager[Any]) -> None: with context: nwd.col("a").is_in(other) + + +@pytest.mark.parametrize( + "expr", + [ + nwd.all(), + nwd.nth(1, 2, 3), + nwd.col("a", "b", "c"), + ndcs.boolean(), + (ndcs.by_name("a", "b") | ndcs.string()), + (nwd.col("b", "c") & nwd.col("a")), + nwd.col("a", "b").min().over("c", order_by="e"), + (~ndcs.by_dtype(nw.Int64()) - ndcs.datetime()), + nwd.nth(6, 2).abs().cast(nw.Int32()) + 10, + ], +) +def test_invalid_alias(expr: DummyExpr) -> None: + pattern = re.compile(r"alias.+dupe.+multi\-output") + with pytest.raises(DuplicateError, match=pattern): + expr.alias("dupe") From 01f4d26346f3046db5d34c330b372356540ab9da Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 1 Jun 2025 22:08:15 +0100 Subject: [PATCH 159/368] feat(DRAFT): Start working on `expr_expansion` Now this'll be quite the feat to pull off https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs --- narwhals/_plan/expr_expansion.py | 75 ++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 narwhals/_plan/expr_expansion.py diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py new file mode 100644 index 0000000000..798c6947ac --- /dev/null +++ b/narwhals/_plan/expr_expansion.py @@ -0,0 +1,75 @@ +"""Based on [polars-plan/src/plans/conversion/expr_expansion.rs]. + +- Goal is to expand every selection into a named column. +- Most will require only the column names of the schema. 
+ +[polars-plan/src/plans/conversion/expr_expansion.rs]: https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan.common import Immutable + +if TYPE_CHECKING: + from narwhals._plan.common import ExprIR + from narwhals._plan.dummy import DummyExpr + + +class ExpansionFlags(Immutable): + """`polars` uses a struct, but we may want to use `enum.Flag`.""" + + __slots__ = ( + "has_exclude", + "has_nth", + "has_selector", + "has_wildcard", + "multiple_columns", + ) + multiple_columns: bool + has_nth: bool + has_wildcard: bool + has_selector: bool + has_exclude: bool + + @property + def expands(self) -> bool: + """If we add struct stuff, that would slot in here as well.""" + return self.multiple_columns + + @staticmethod + def from_ir(ir: ExprIR, /) -> ExpansionFlags: + """Subset of [`find_flags`]. + + [`find_flags`]: https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L607-L660 + """ + from narwhals._plan import expr + + multiple_columns: bool = False + has_nth: bool = False + has_wildcard: bool = False + has_selector: bool = False + has_exclude: bool = False + for e in ir.iter_left(): + if isinstance(e, (expr.Columns, expr.IndexColumns)): + multiple_columns = True + elif isinstance(e, expr.Nth): + has_nth = True + elif isinstance(e, expr.All): + has_wildcard = True + elif isinstance(e, expr.SelectorIR): + has_selector = True + elif isinstance(e, expr.Exclude): + has_exclude = True + return ExpansionFlags( + multiple_columns=multiple_columns, + has_nth=has_nth, + has_wildcard=has_wildcard, + has_selector=has_selector, + has_exclude=has_exclude, + ) + + @classmethod + def from_expr(cls, expr: DummyExpr, /) -> ExpansionFlags: + return cls.from_ir(expr._ir) From 5a1d15f18b2fb754a4b2716f9667f9d740a14e44 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 2 Jun 2025 16:53:36 +0100 Subject: [PATCH 160/368] start stubbing out `expr_expansion` - Quite a lot going on here - Turning the `result` parameter into the returned value seems like a logical step --- narwhals/_plan/expr_expansion.py | 206 ++++++++++++++++++++++++++++++- 1 file changed, 204 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 798c6947ac..d614629bee 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -6,15 +6,45 @@ [polars-plan/src/plans/conversion/expr_expansion.rs]: https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs """ +# ruff: noqa: A002 from __future__ import annotations -from typing import TYPE_CHECKING +from copy import deepcopy +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, Mapping, Sequence from narwhals._plan.common import Immutable if TYPE_CHECKING: - from narwhals._plan.common import ExprIR + from typing_extensions import TypeAlias + + from narwhals._plan import expr, selectors + from narwhals._plan.common import ExprIR, Seq from narwhals._plan.dummy import DummyExpr + from narwhals.dtypes import DType + + +FrozenSchema: TypeAlias = "MappingProxyType[str, DType]" +FrozenColumns: TypeAlias = "Seq[str]" +Excluded: TypeAlias = "frozenset[str]" +"""Internally use a `set`, then freeze before 
returning.""" + +Inplace: TypeAlias = Any +"""Functions where `polars` does in-place mutations on `Expr`. + +Very likely that **we won't** do this in `narwhals`, instead return a new object. +""" + + +# NOTE: Both `_freeze` functions will probably want to be cached +# In the traversal/expand/replacement functions, their returns will be hashable -> safe to cache those as well +def _freeze_schema(**schema: DType) -> FrozenSchema: + copied = deepcopy(schema) + return MappingProxyType(copied) + + +def _freeze_columns(schema: FrozenSchema, /) -> FrozenColumns: + return tuple(schema) class ExpansionFlags(Immutable): @@ -73,3 +103,175 @@ def from_ir(ir: ExprIR, /) -> ExpansionFlags: @classmethod def from_expr(cls, expr: DummyExpr, /) -> ExpansionFlags: return cls.from_ir(expr._ir) + + +def prepare_projection( + exprs: Sequence[ExprIR], schema: Mapping[str, DType] +) -> tuple[Seq[ExprIR], FrozenSchema]: + frozen_schema = _freeze_schema(**schema) + rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) + # NOTE: There's an `expressions_to_schema` step that I'm skipping for now + # seems too big of a rabbit hole to go down + return rewritten, frozen_schema + + +# NOTE: Parameters have been re-ordered, renamed, changed types +# - `origin` is the `Expr` that's being iterated over +# - `result` *haven't got to yet* +# - Couldn't this just be the return type? +# - Certainly less complicated in python +# - `` is the current child of `origin` +# - `col_names: FrozenColumns` is used when we don't need the dtypes +# - `exclude` is the return of `prepare_excluded` + + +def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: + raise NotImplementedError + + +def rewrite_projections( + input: Seq[ExprIR], # `FunctionExpr.input` + /, + keys: Seq[ExprIR], + *, + schema: FrozenSchema, +) -> Seq[ExprIR]: + raise NotImplementedError + + +def replace_selector( + ir: ExprIR, # an element of `FunctionExpr.input` + /, + keys: Seq[ExprIR], + *, + schema: FrozenSchema, +) -> ExprIR: + raise NotImplementedError + + +def expand_selector( + s: expr.SelectorIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema +) -> Seq[str]: + """Converts into input of `Columns(...)`.""" + raise NotImplementedError + + +def replace_selector_inner( + s: expr.SelectorIR, + /, + keys: Seq[ExprIR], + members: Any, # mutable, insertion order preserving set `PlIndexSet` + scratch: Seq[ExprIR], # passed as `result` into `replace_and_add_to_results` + *, + schema: FrozenSchema, +) -> Inplace: + raise NotImplementedError + + +def replace_and_add_to_results( + origin: ExprIR, + /, + result: Seq[ExprIR], + keys: Seq[ExprIR], + *, + schema: FrozenSchema, + flags: ExpansionFlags, +) -> Inplace: + raise NotImplementedError + + +# NOTE: See how far we can get with just the direct node replacements +# - `polars` is using `map_expr`, but I haven't implemented that (yet?) 
+def replace_nth(nth: expr.Nth, /, col_names: FrozenColumns) -> expr.Column: + from narwhals._plan import expr + + return expr.Column(name=col_names[nth.index]) + + +def prepare_excluded( + origin: ExprIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema, has_exclude: bool +) -> Excluded: + raise NotImplementedError + + +def expand_columns( + origin: ExprIR, + /, + result: Seq[ExprIR], + columns: expr.Columns, # `polars` uses columns.names + *, + col_names: FrozenColumns, + exclude: Excluded, +) -> Inplace: + raise NotImplementedError + + +def expand_dtypes( + origin: ExprIR, + /, + result: Seq[ExprIR], + dtypes: selectors.ByDType, # we haven't got `DtypeColumn` + *, + schema: FrozenSchema, + exclude: Excluded, +) -> Inplace: + raise NotImplementedError + + +def expand_indices( + origin: ExprIR, + /, + result: Seq[ExprIR], + indices: expr.IndexColumns, + *, + schema: FrozenSchema, + exclude: Excluded, +) -> Inplace: + raise NotImplementedError + + +def replace_wildcard( + origin: ExprIR, /, result: Seq[ExprIR], *, col_names: FrozenColumns, exclude: Excluded +) -> Inplace: + raise NotImplementedError + + +def replace_wildcard_with_column(origin: ExprIR, /, column_name: str) -> ExprIR: + """`expr.All` and `Exclude`.""" + raise NotImplementedError + + +def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: + """`KeepName` and `RenameAlias`. + + Reuses some of the `meta` functions to traverse the names. + """ + raise NotImplementedError + + +def replace_dtype_or_index_with_column( + origin: ExprIR, /, column_name: str, *, replace_dtype: bool +) -> ExprIR: + raise NotImplementedError + + +def dtypes_match(left: DType, right: DType | type[DType]) -> bool: + return left == right + + +def replace_regex( + origin: ExprIR, + /, + result: Seq[ExprIR], + pattern: selectors.Matches, + *, + col_names: FrozenColumns, + exclude: Excluded, +) -> Inplace: + raise NotImplementedError + + +def expand_regex( + origin: ExprIR, /, result: Seq[ExprIR], *, col_names: FrozenColumns, exclude: Excluded +) -> Inplace: + raise NotImplementedError From 8c50c578be79620ce49fb884fd68858f8b62c61b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 3 Jun 2025 15:02:11 +0100 Subject: [PATCH 161/368] docs: Explain where `expr_expansion` fits in - Currently the main hurdle towards integrating `ExprIR` into `narwhals` proper - Resolving this will give us a backend agnostic solution to converting all of these nodes into simpler versions: 1. `Alias(Column, str)` 2. `RenameAlias(Columns, function)` 3. `KeepName(Column)` Each backend then just needs to provide a `Schema` and a way to select columns by name --- narwhals/_plan/expr_expansion.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index d614629bee..0dbe3c11ac 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -1,9 +1,38 @@ """Based on [polars-plan/src/plans/conversion/expr_expansion.rs]. +## Notes - Goal is to expand every selection into a named column. - Most will require only the column names of the schema. +## Current `narwhals` +As of [6e57eff4f059c748cf84ddcae276a74318720b85], many of the problems +this module would solve *currently* have solutions distributed throughout `narwhals`. 
+ +Their dependencies are **quite** complex, with the main ones being: +- `CompliantExpr` + - _evaluate_output_names + - `CompliantSelector.__(sub|or|and|invert)__` + - `CompliantThen._evaluate_output_names` + - _alias_output_names + - from_column_names, from_column_indices + - `CompliantNamespace.(all|col|exclude|nth)` + - _eval_names_indices + - _evaluate_aliases + - `Compliant*Frame._evaluate_aliases` + - `EagerDataFrame._evaluate_into_expr(s)` +- `CompliantExprNameNamespace` + - EagerExprNameNamespace + - LazyExprNameNamespace +- `_expression_parsing.py` + - combine_evaluate_output_names + - 6-7x per `CompliantNamespace` + - combine_alias_output_names + - 6-7x per `CompliantNamespace` + - evaluate_output_names_and_aliases + - Depth tracking (`Expr.over`, `GroupyBy.agg`) + [polars-plan/src/plans/conversion/expr_expansion.rs]: https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs +[6e57eff4f059c748cf84ddcae276a74318720b85]: https://github.com/narwhals-dev/narwhals/commit/6e57eff4f059c748cf84ddcae276a74318720b85 """ # ruff: noqa: A002 From cb79cb63dd5de7b41eedf4c86de734ba4a4c3ef9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 5 Jun 2025 13:31:33 +0100 Subject: [PATCH 162/368] perf: Cache the hash of `Immutable` instances - I noticed something similar in both `cudf_polars` and `ibis` - Redcues the cost of comparisons as the size of the graph grows https://github.com/rapidsai/cudf/blob/f97ff6c952e16327503ad8bc2e1ece89899e1acb/python/cudf_polars/cudf_polars/dsl/nodebase.py#L68-L97 https://github.com/ibis-project/ibis/blob/e943f49890b0de8a5472a0651751d12a090d62bd/ibis/common/bases.py#L132-L246 --- narwhals/_plan/common.py | 58 ++++++++++++++++++++++++++---------- tests/plan/immutable_test.py | 16 ++++++++++ 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 65efa1eba1..a53cc1a0e4 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -8,7 +8,7 @@ from narwhals.utils import Version if TYPE_CHECKING: - from typing import Any, Callable, Iterator + from typing import Any, Callable, Iterator, Literal from typing_extensions import Never, Self, TypeAlias, TypeIs, dataclass_transform @@ -62,10 +62,38 @@ def decorator(cls_or_fn: T) -> T: IntoExprColumn: TypeAlias = "DummyExpr | DummySeries | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" +_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" + @dataclass_transform(kw_only_default=True, frozen_default=True) class Immutable: - __slots__ = () + __slots__ = (_IMMUTABLE_HASH_NAME,) + __immutable_hash_value__: int + + @property + def __immutable_keys__(self) -> Iterator[str]: + slots: tuple[str, ...] 
= self.__slots__ + for name in slots: + if name != _IMMUTABLE_HASH_NAME: + yield name + + @property + def __immutable_values__(self) -> Iterator[Any]: + for name in self.__immutable_keys__: + yield getattr(self, name) + + @property + def __immutable_items__(self) -> Iterator[tuple[str, Any]]: + for name in self.__immutable_keys__: + yield name, getattr(self, name) + + @property + def __immutable_hash__(self) -> int: + if hasattr(self, _IMMUTABLE_HASH_NAME): + return self.__immutable_hash_value__ + hash_value = hash((self.__class__, *self.__immutable_values__)) + object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) + return self.__immutable_hash_value__ def __setattr__(self, name: str, value: Never) -> Never: msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." @@ -79,45 +107,43 @@ def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: cls.__slots__ = () def __hash__(self) -> int: - slots: tuple[str, ...] = self.__slots__ - it = (getattr(self, name) for name in slots) - return hash((self.__class__, *it)) + return self.__immutable_hash__ def __eq__(self, other: object) -> bool: if self is other: return True elif type(self) is not type(other): return False - slots: tuple[str, ...] = self.__slots__ - return all(getattr(self, name) == getattr(other, name) for name in slots) + return all( + getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ + ) def __str__(self) -> str: # NOTE: Debug repr, closer to constructor - slots: tuple[str, ...] = self.__slots__ - fields = ", ".join(f"{_field_str(name, getattr(self, name))}" for name in slots) + fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__) return f"{type(self).__name__}({fields})" def __init__(self, **kwds: Any) -> None: # NOTE: DUMMY CONSTRUCTOR - don't use beyond prototyping! # Just need a quick way to demonstrate `ExprIR` and interactions - slots: set[str] = set(self.__slots__) - if not slots and not kwds: + required: set[str] = set(self.__immutable_keys__) + if not required and not kwds: # NOTE: Fastpath for empty slots ... 
- elif slots == set(kwds): + elif required == set(kwds): # NOTE: Everything is as expected for name, value in kwds.items(): object.__setattr__(self, name, value) - elif missing := slots.difference(kwds): + elif missing := required.difference(kwds): msg = ( - f"{type(self).__name__!r} requires attributes {sorted(slots)!r}, \n" + f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n" f"but missing values for {sorted(missing)!r}" ) raise TypeError(msg) else: - extra = set(kwds).difference(slots) + extra = set(kwds).difference(required) msg = ( - f"{type(self).__name__!r} only supports attributes {sorted(slots)!r}, \n" + f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n" f"but got unknown arguments {sorted(extra)!r}" ) raise TypeError(msg) diff --git a/tests/plan/immutable_test.py b/tests/plan/immutable_test.py index ba15a50828..3c5e97439e 100644 --- a/tests/plan/immutable_test.py +++ b/tests/plan/immutable_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import string +from itertools import repeat from typing import Any import pytest @@ -131,3 +133,17 @@ def test_immutable_invalid_constructor() -> None: OneSlot(1, 2, 3) # type: ignore[call-arg, misc] with pytest.raises(TypeError): OneSlot(1, a=1) # type: ignore[misc] + + +def test_immutable_hash_cache() -> None: + int_long = 9999999999999999999999999999999999999999999999999999999999 + str_long = "\n".join(repeat(string.printable, 100)) + obj = TwoSlot(a=int_long, b=str_long) + + with pytest.raises(AttributeError): + uncached = obj.__immutable_hash_value__ # noqa: F841 + + hash_cache_miss = hash(obj) + cached = obj.__immutable_hash_value__ + hash_cache_hit = hash(obj) + assert hash_cache_miss == cached == hash_cache_hit From 570d6e17c6ca0ed072dd16a8a70bd6b74688a4f4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 5 Jun 2025 16:53:04 +0100 Subject: [PATCH 163/368] feat(DRAFT): Add `_ColumnSelection`, `expand_columns` - Very basic to start - Skipping lots of obvious bounds checks - Trying to get an idea of how best to handle `RenameAlias` on a multi-selection --- narwhals/_plan/exceptions.py | 14 ++++ narwhals/_plan/expr.py | 158 ++++++++++++++++++++++------------- narwhals/_plan/meta.py | 13 +-- 3 files changed, 117 insertions(+), 68 deletions(-) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 0f10bd410b..5c165e89d9 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from narwhals.exceptions import ( + ColumnNotFoundError, ComputeError, DuplicateError, InvalidIntoExprError, @@ -163,3 +164,16 @@ def is_iterable_polars_error( def alias_duplicate_error(expr: ExprIR, name: str) -> DuplicateError: msg = f"Cannot apply alias {name!r} to multi-output expression:\n{expr!r}" return DuplicateError(msg) + + +def column_not_found_error( + subset: Iterable[str], /, available: Iterable[str] +) -> ColumnNotFoundError: + """Similar to `utils.check_columns_exist`, but when we already know there are missing. + + Signature differs to allow passing in a schema to `available`. + That form is what we're working with here. 
+ """ + available = tuple(available) + missing = set(subset).difference(available) + return ColumnNotFoundError.from_missing_and_available_column_names(missing, available) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 1aab360be6..a566535598 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -10,6 +10,7 @@ from narwhals._plan.common import ExprIR, SelectorIR, _field_str, is_non_nested_literal from narwhals._plan.exceptions import ( alias_duplicate_error, + column_not_found_error, function_expr_invalid_operation_error, ) from narwhals._plan.name import KeepName, RenameAlias @@ -30,7 +31,7 @@ from narwhals.utils import flatten if t.TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan.common import Seq from narwhals._plan.functions import MapBatches # noqa: F401 @@ -69,6 +70,12 @@ "WindowExpr", ] +_Schema: TypeAlias = "t.Mapping[str, DType]" +"""Equivalent to `expr_expansion.FrozenSchema`. + +Using temporarily before adding caching into the mix. +""" + class Alias(ExprIR): __slots__ = ("expr", "name") @@ -110,7 +117,23 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: return plx.col(self.name) -class Columns(ExprIR): +def _col(name: str, /) -> Column: + return Column(name=name) + + +def _cols(names: t.Iterable[str], /) -> Seq[Column]: + return tuple(_col(name) for name in names) + + +class _ColumnSelection(ExprIR): + """Nodes which can resolve to `Column`(s) with a `Schema`.""" + + def expand_columns(self, schema: _Schema, /) -> Seq[Column]: + """Transform selection in context of `schema` into simpler nodes.""" + raise NotImplementedError + + +class Columns(_ColumnSelection): __slots__ = ("names",) names: Seq[str] @@ -121,6 +144,83 @@ def __repr__(self) -> str: def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: return plx.col(*self.names) + def expand_columns(self, schema: _Schema) -> Seq[Column]: + if set(schema).issuperset(self.names): + return _cols(self.names) + raise column_not_found_error(self.names, schema) + + +class Nth(_ColumnSelection): + __slots__ = ("index",) + + index: int + + def __repr__(self) -> str: + return f"nth({self.index})" + + def expand_columns(self, schema: _Schema) -> Seq[Column]: + name = tuple(schema)[self.index] + return (_col(name),) + + +class IndexColumns(_ColumnSelection): + """Renamed from `IndexColumn`. + + `Nth` provides the singular variant. + + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L80 + """ + + __slots__ = ("indices",) + + indices: Seq[int] + + def __repr__(self) -> str: + return f"index_columns({self.indices!r})" + + def expand_columns(self, schema: _Schema) -> Seq[Column]: + names = tuple(schema) + return _cols(names[index] for index in self.indices) + + +class All(_ColumnSelection): + """Aka Wildcard (`pl.all()` or `pl.col("*")`). + + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L137 + """ + + def __repr__(self) -> str: + return "all()" + + def expand_columns(self, schema: _Schema) -> Seq[Column]: + return _cols(schema) + + +class Exclude(_ColumnSelection): + __slots__ = ("expr", "names") + + expr: ExprIR + """Default is `all()`.""" + names: Seq[str] + """Excluded names. + + - We're using a `frozenset` in main. + - Might want to switch to that later. 
+ """ + + @staticmethod + def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: + return Exclude(expr=expr, names=tuple(flatten(names))) + + def __repr__(self) -> str: + return f"{self.expr!r}.exclude({list(self.names)!r})" + + def expand_columns(self, schema: _Schema) -> Seq[Column]: + if not isinstance(self.expr, All): + msg = f"Only {All()!r} is currently supported with `exclude()`" + raise NotImplementedError(msg) + return _cols(name for name in schema if name not in self.names) + class Literal(ExprIR, t.Generic[LiteralT]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" @@ -442,60 +542,6 @@ def __repr__(self) -> str: return "len()" -class Exclude(ExprIR): - __slots__ = ("expr", "names") - - expr: ExprIR - """Default is `all()`.""" - names: Seq[str] - """We're using a `frozenset` in main. - - Might want to switch to that later. - """ - - @staticmethod - def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: - return Exclude(expr=expr, names=tuple(flatten(names))) - - def __repr__(self) -> str: - return f"{self.expr!r}.exclude({list(self.names)!r})" - - -class Nth(ExprIR): - __slots__ = ("index",) - - index: int - - def __repr__(self) -> str: - return f"nth({self.index})" - - -class IndexColumns(ExprIR): - """Renamed from `IndexColumn`. - - `Nth` provides the singular variant. - - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L80 - """ - - __slots__ = ("indices",) - - indices: Seq[int] - - def __repr__(self) -> str: - return f"index_columns({self.indices!r})" - - -class All(ExprIR): - """Aka Wildcard (`pl.all()` or `pl.col("*")`). - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L137 - """ - - def __repr__(self) -> str: - return "all()" - - class RootSelector(SelectorIR): """A single selector expression.""" diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index ad2b30c055..330320fac7 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -196,18 +196,7 @@ def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan import expr - if isinstance( - ir, - ( - expr.Column, - expr.Columns, - expr.Exclude, - expr.Nth, - expr.IndexColumns, - expr.SelectorIR, - expr.All, - ), - ): + if isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)): return True if isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)): return allow_aliasing From 590af55073cad8d7a2835d53c87772f98faa20ef Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 5 Jun 2025 20:19:56 +0100 Subject: [PATCH 164/368] test: Add `expr_expansion_test` Would be good know these stay working while implementing the more complex cases --- tests/plan/expr_expansion_test.py | 144 ++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 tests/plan/expr_expansion_test.py diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py new file mode 100644 index 0000000000..0b54f8801b --- /dev/null +++ b/tests/plan/expr_expansion_test.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +import pytest + +import narwhals as nw +import narwhals._plan.demo as nwd +from narwhals._plan.expr import Column, _ColumnSelection 
+from narwhals.exceptions import ColumnNotFoundError + +if TYPE_CHECKING: + from typing_extensions import TypeIs + + from narwhals._plan.common import ExprIR + from narwhals._plan.dummy import DummyExpr + from narwhals.dtypes import DType + + +@pytest.fixture +def schema_1() -> dict[str, DType]: + return { + "a": nw.Int64(), + "b": nw.Int32(), + "c": nw.Int16(), + "d": nw.Int8(), + "e": nw.UInt64(), + "f": nw.UInt32(), + "g": nw.UInt16(), + "h": nw.UInt8(), + "i": nw.Float64(), + "j": nw.Float32(), + "k": nw.String(), + "l": nw.Datetime(), + "m": nw.Boolean(), + "n": nw.Date(), + "o": nw.Datetime(), + "p": nw.Categorical(), + "q": nw.Duration(), + "r": nw.Enum(["A", "B", "C"]), + "s": nw.List(nw.String()), + "u": nw.Struct({"a": nw.Int64(), "k": nw.String()}), + } + + +# NOTE: The meta check doesn't provide typing and describes a superset of `_ColumnSelection` +def is_column_selection(obj: ExprIR) -> TypeIs[_ColumnSelection]: + return obj.meta.is_column_selection(allow_aliasing=False) and isinstance( + obj, _ColumnSelection + ) + + +def seq_column_from_names(names: Sequence[str]) -> tuple[Column, ...]: + return tuple(Column(name=name) for name in names) + + +@pytest.mark.parametrize( + ("expr", "into_expected"), + [ + (nwd.col("a", "c"), ["a", "c"]), + (nwd.col("o", "k", "b"), ["o", "k", "b"]), + (nwd.nth(5), ["f"]), + (nwd.nth(0, 1, 2, 3, 4), ["a", "b", "c", "d", "e"]), + (nwd.nth(-1), ["u"]), + (nwd.nth([-2, -3, -4]), ["s", "r", "q"]), + ( + nwd.all(), + [ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "u", + ], + ), + ( + nwd.exclude("a", "c", "e", "l", "q"), + ["b", "d", "f", "g", "h", "i", "j", "k", "m", "n", "o", "p", "r", "s", "u"], + ), + ], +) +def test_expand_columns_root( + expr: DummyExpr, into_expected: Sequence[str], schema_1: dict[str, DType] +) -> None: + expected = seq_column_from_names(into_expected) + selection = expr._ir + assert is_column_selection(selection) + actual = selection.expand_columns(schema_1) + assert actual == expected + + +@pytest.mark.parametrize( + "expr", + [ + nwd.col("y", "z"), + nwd.col("a", "b", "z"), + nwd.col("x", "b", "a"), + nwd.col( + [ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "FIVE", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "u", + ] + ), + ], +) +def test_invalid_expand_columns(expr: DummyExpr, schema_1: dict[str, DType]) -> None: + selection = expr._ir + assert is_column_selection(selection) + with pytest.raises(ColumnNotFoundError): + selection.expand_columns(schema_1) From da2fa94aa51115443a3642d52cf8076d84cec3a8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 5 Jun 2025 21:46:30 +0100 Subject: [PATCH 165/368] feat(DRAFT): Fill out some of `rewrite_special_aliases` --- narwhals/_plan/expr_expansion.py | 20 +++++++++++++++++++- narwhals/_plan/meta.py | 19 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 0dbe3c11ac..255f667ace 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -43,6 +43,7 @@ from typing import TYPE_CHECKING, Any, Mapping, Sequence from narwhals._plan.common import Immutable +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -270,12 +271,29 @@ def replace_wildcard_with_column(origin: ExprIR, /, column_name: str) -> ExprIR: 
raise NotImplementedError +# TODO @dangotbanned: `meta.get_single_leaf_name` def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: """`KeepName` and `RenameAlias`. Reuses some of the `meta` functions to traverse the names. """ - raise NotImplementedError + from narwhals._plan import expr, meta + + if meta.has_expr_ir(origin, expr.KeepName, expr.RenameAlias): + if isinstance(origin, expr.KeepName): + parent = origin.expr + roots = parent.meta.root_names() + alias = next(iter(roots)) + return expr.Alias(expr=parent, name=alias) + elif isinstance(origin, expr.RenameAlias): + parent = origin.expr + leaf_name = meta.get_single_leaf_name(parent) + alias = origin.function(leaf_name) + return expr.Alias(expr=parent, name=alias) + else: + msg = "`keep`, `suffix`, `prefix` should be last expression" + raise InvalidOperationError(msg) + return origin def replace_dtype_or_index_with_column( diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 330320fac7..09b0be9f6f 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -176,6 +176,25 @@ def _has_multiple_outputs(ir: ExprIR) -> bool: return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.SelectorIR, expr.All)) +def has_expr_ir(ir: ExprIR, *matches: type[ExprIR]) -> bool: + """Return True if any node in the tree is in type `matches`. + + Based on [`polars_plan::utils::has_expr`] + + [`polars_plan::utils::has_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L70-L77 + """ + return any(isinstance(e, matches) for e in ir.iter_right()) + + +# TODO @dangotbanned: Adapt this one for `rewrite_special_aliases` +def get_single_leaf_name(ir: ExprIR) -> str: + """Not yet implemented! + + https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168. + """ + raise NotImplementedError + + def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan import expr from narwhals._plan.literal import ScalarLiteral From bcc8ba7f6cca9a45df3ff48f1435100311d47616 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 5 Jun 2025 21:59:00 +0100 Subject: [PATCH 166/368] revert: Remove unplanned `meta` methods --- narwhals/_plan/meta.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 09b0be9f6f..0762a53f24 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -80,31 +80,6 @@ def root_names(self) -> list[str]: """Get the root column names.""" return _expr_to_leaf_column_names(self._ir) - # NOTE: Seems too complex to do whilst keeping things immutable - def undo_aliases(self) -> ExprIR: - """Investigate components. 
- - Seems like it unnests each of these: - - `Alias.expr` - - `KeepName.expr` - - `RenameAlias.expr` - - Notes: - - [`meta.undo_aliases`] - - [`Expr.map_expr`] - - [`TreeWalker.rewrite`] - - [`Expr.map_expr`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/plans/iterator.rs#L146-L149 - [`meta.undo_aliases`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L45-L53 - [`TreeWalker.rewrite`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/plans/visitor/visitors.rs#L46-L68 - """ - raise NotImplementedError - - # NOTE: Less important for us, but maybe nice to have - def pop(self) -> list[ExprIR]: - """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L14-L25.""" - raise NotImplementedError - def _expr_to_leaf_column_names(ir: ExprIR) -> list[str]: """After a lot of indirection, [root_names] resolves [here]. From e4ac0d7e3bd6975f309fd4f4b75fecc986b2e38c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 5 Jun 2025 22:01:02 +0100 Subject: [PATCH 167/368] revert: Remove debug `meta` functions Haven't needed for a while now --- narwhals/_plan/meta.py | 39 +-------------------------------------- 1 file changed, 1 insertion(+), 38 deletions(-) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 0762a53f24..20da92af6c 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -13,9 +13,7 @@ from narwhals.utils import Version if TYPE_CHECKING: - from typing import Any, Iterator - - import polars as pl + from typing import Iterator from narwhals._plan.common import ExprIR @@ -195,38 +193,3 @@ def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: if isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)): return allow_aliasing return False - - -def polars_expr_metadata(expr: pl.Expr) -> dict[str, Any]: - """Gather all metadata for a native `Expr`. - - Eventual goal would be that a `nw.Expr` matches a `pl.Expr` in as much of this as possible. - """ - return { - "has_multiple_outputs": expr.meta.has_multiple_outputs(), - "is_column": expr.meta.is_column(), - "is_regex_projection": expr.meta.is_regex_projection(), - "is_column_selection": expr.meta.is_column_selection(), - "is_column_selection(allow_aliasing=True)": expr.meta.is_column_selection( - allow_aliasing=True - ), - "is_literal": expr.meta.is_literal(), - "is_literal(allow_aliasing=True)": expr.meta.is_literal(allow_aliasing=True), - "output_name": expr.meta.output_name(raise_if_undetermined=False), - "root_names": expr.meta.root_names(), - "pop": expr.meta.pop(), - "undo_aliases": expr.meta.undo_aliases(), - "expr": expr, - } - - -def polars_expr_to_dict(expr: pl.Expr) -> dict[str, Any]: - """Serialize a native `Expr`, roundtrip back to `dict`. - - Using to inspect [`FunctionOptions`] and ensure we combine them in a similar way. 
- - [`FunctionOptions`]: https://github.com/narwhals-dev/narwhals/pull/2572#issuecomment-2891577685 - """ - import json - - return json.loads(expr.meta.serialize(format="json")) # type: ignore[no-any-return] From ac19b0401f84456838c845a0e56c9f86ea2ac4ec Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 5 Jun 2025 22:03:12 +0100 Subject: [PATCH 168/368] revert: Remove unused `list` accessor from `Expr.IR` --- narwhals/_plan/common.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index a53cc1a0e4..84ace8c374 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -14,7 +14,6 @@ from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries from narwhals._plan.expr import FunctionExpr - from narwhals._plan.lists import IRListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions from narwhals.dtypes import DType @@ -234,12 +233,6 @@ def meta(self) -> IRMetaNamespace: return IRMetaNamespace(_ir=self) - @property - def list(self) -> IRListNamespace: - from narwhals._plan.lists import IRListNamespace - - return IRListNamespace(_ir=self) - class SelectorIR(ExprIR): def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: From 9955bdb674d241e2e5a4e8d0ffa0a9f218f2095e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 6 Jun 2025 13:50:29 +0100 Subject: [PATCH 169/368] feat: Support single `col` rewrites for `.name` methods Some big caveats, but after getting everything into `Column` - hopefully this won't need to be changed --- narwhals/_plan/expr_expansion.py | 14 ++++--- narwhals/_plan/meta.py | 36 +++++++++++++----- tests/plan/expr_expansion_test.py | 61 ++++++++++++++++++++++++++++++- 3 files changed, 95 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 255f667ace..d3d2580a6d 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -271,11 +271,13 @@ def replace_wildcard_with_column(origin: ExprIR, /, column_name: str) -> ExprIR: raise NotImplementedError -# TODO @dangotbanned: `meta.get_single_leaf_name` def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: - """`KeepName` and `RenameAlias`. + """Expand `KeepName` and `RenameAlias` into `Alias`. - Reuses some of the `meta` functions to traverse the names. 
+ Warning: + Only valid **after** + - Expanding all selections into `Column` + - Dealing with `FunctionExpr.input` """ from narwhals._plan import expr, meta @@ -287,8 +289,10 @@ def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: return expr.Alias(expr=parent, name=alias) elif isinstance(origin, expr.RenameAlias): parent = origin.expr - leaf_name = meta.get_single_leaf_name(parent) - alias = origin.function(leaf_name) + leaf_name_or_err = meta.get_single_leaf_name(parent) + if not isinstance(leaf_name_or_err, str): + raise leaf_name_or_err + alias = origin.function(leaf_name_or_err) return expr.Alias(expr=parent, name=alias) else: msg = "`keep`, `suffix`, `prefix` should be last expression" diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 20da92af6c..18461f5900 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, overload from narwhals._plan.common import IRNamespace from narwhals.exceptions import ComputeError @@ -40,6 +40,10 @@ def is_literal(self, *, allow_aliasing: bool = False) -> bool: _is_literal(e, allow_aliasing=allow_aliasing) for e in self._ir.iter_left() ) + @overload + def output_name(self, *, raise_if_undetermined: Literal[True] = True) -> str: ... + @overload + def output_name(self, *, raise_if_undetermined: Literal[False]) -> str | None: ... def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: """Get the output name of this expression. @@ -143,6 +147,27 @@ def _expr_output_name(ir: ExprIR) -> str | ComputeError: return ComputeError(msg) +def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: + """Find the name at the start of an expression. + + Normal iteration would just return the first root column it found. + + Based on [`polars_plan::utils::get_single_leaf`] + + [`polars_plan::utils::get_single_leaf`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168 + """ + from narwhals._plan import expr + + for e in ir.iter_right(): + if isinstance(e, (expr.WindowExpr, expr.SortBy, expr.Filter)): + return get_single_leaf_name(e.expr) + # NOTE: `polars` doesn't include `Literal` here + if isinstance(e, (expr.Column, expr.Len)): + return e.name + msg = f"unable to find a single leaf column in expr '{ir!r}'" + return ComputeError(msg) + + def _has_multiple_outputs(ir: ExprIR) -> bool: from narwhals._plan import expr @@ -159,15 +184,6 @@ def has_expr_ir(ir: ExprIR, *matches: type[ExprIR]) -> bool: return any(isinstance(e, matches) for e in ir.iter_right()) -# TODO @dangotbanned: Adapt this one for `rewrite_special_aliases` -def get_single_leaf_name(ir: ExprIR) -> str: - """Not yet implemented! - - https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168. 
- """ - raise NotImplementedError - - def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan import expr from narwhals._plan.literal import ScalarLiteral diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 0b54f8801b..d7b825087e 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -7,7 +7,8 @@ import narwhals as nw import narwhals._plan.demo as nwd from narwhals._plan.expr import Column, _ColumnSelection -from narwhals.exceptions import ColumnNotFoundError +from narwhals._plan.expr_expansion import rewrite_special_aliases +from narwhals.exceptions import ColumnNotFoundError, ComputeError if TYPE_CHECKING: from typing_extensions import TypeIs @@ -142,3 +143,61 @@ def test_invalid_expand_columns(expr: DummyExpr, schema_1: dict[str, DType]) -> assert is_column_selection(selection) with pytest.raises(ColumnNotFoundError): selection.expand_columns(schema_1) + + +def udf_name_map(name: str) -> str: + original = name + upper = name.upper() + lower = name.lower() + title = name.title() + return f"{original=} | {upper=} | {lower=} | {title=}" + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwd.col("a").name.to_uppercase(), "A"), + (nwd.col("B").name.to_lowercase(), "b"), + (nwd.col("c").name.suffix("_after"), "c_after"), + (nwd.col("d").name.prefix("before_"), "before_d"), + ( + nwd.col("aBcD EFg hi").name.map(udf_name_map), + "original='aBcD EFg hi' | upper='ABCD EFG HI' | lower='abcd efg hi' | title='Abcd Efg Hi'", + ), + (nwd.col("a").min().alias("b").over("c").alias("d").max().name.keep(), "a"), + ( + ( + nwd.col("hello") + .sort_by(nwd.col("ignore me")) + .max() + .over("ignore me as well") + .first() + .name.to_uppercase() + ), + "HELLO", + ), + ( + ( + nwd.col("start") + .alias("next") + .sort() + .round() + .fill_null(5) + .alias("noise") + .name.suffix("_end") + ), + "start_end", + ), + ], +) +def test_rewrite_special_aliases_single(expr: DummyExpr, expected: str) -> None: + # NOTE: We can't use `output_name()` without resolving these rewrites + # Once they're done, `output_name()` just peeks into `Alias(name=...)` + ir_input = expr._ir + with pytest.raises(ComputeError): + ir_input.meta.output_name() + + ir_output = rewrite_special_aliases(ir_input) + assert ir_input != ir_output + actual = ir_output.meta.output_name() + assert actual == expected From 61a51037231aa03fb9c7204d29039d71c15ff77a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 6 Jun 2025 16:06:27 +0100 Subject: [PATCH 170/368] feat(DRAFT): Sketch out parts of `rewrite_projections`, `expand_function_inputs` --- narwhals/_plan/common.py | 11 ++++++ narwhals/_plan/expr.py | 5 +++ narwhals/_plan/expr_expansion.py | 60 ++++++++++++++++++++++++++------ narwhals/_plan/options.py | 6 ++++ 4 files changed, 71 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 84ace8c374..11398b7b19 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -174,6 +174,17 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: def is_scalar(self) -> bool: return False + def map_ir(self, function: Callable[[ExprIR], ExprIR], /) -> ExprIR: + """Apply `function` to each child node, returning a new `ExprIR`. + + See [`polars_plan::plans::iterator::Expr.map_expr`] and [`polars_plan::plans::visitor::visitors`]. 
+ + [`polars_plan::plans::iterator::Expr.map_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/iterator.rs#L152-L159 + [`polars_plan::plans::visitor::visitors`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/visitor/visitors.rs + """ + msg = "Need to handle recursive visiting first!" + raise NotImplementedError(msg) + def iter_left(self) -> Iterator[ExprIR]: """Yield nodes root->leaf. diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index a566535598..271ba04dd5 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -393,6 +393,11 @@ def with_options(self, options: FunctionOptions, /) -> Self: options = self.options.with_flags(options.flags) return type(self)(input=self.input, function=self.function, options=options) + def with_input(self, input: t.Iterable[ExprIR], /) -> Self: # noqa: A002 + if not isinstance(input, tuple): + input = tuple(input) + return type(self)(input=input, function=self.function, options=self.options) + def __repr__(self) -> str: if self.input: first = self.input[0] diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index d3d2580a6d..9e72135f01 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -38,6 +38,7 @@ # ruff: noqa: A002 from __future__ import annotations +from collections import deque from copy import deepcopy from types import MappingProxyType from typing import TYPE_CHECKING, Any, Mapping, Sequence @@ -65,6 +66,8 @@ Very likely that **we won't** do this in `narwhals`, instead return a new object. """ +ResultIRs: TypeAlias = "deque[ExprIR]" + # NOTE: Both `_freeze` functions will probably want to be cached # In the traversal/expand/replacement functions, their returns will be hashable -> safe to cache those as well @@ -134,6 +137,15 @@ def from_ir(ir: ExprIR, /) -> ExpansionFlags: def from_expr(cls, expr: DummyExpr, /) -> ExpansionFlags: return cls.from_ir(expr._ir) + def with_multiple_columns(self) -> ExpansionFlags: + return ExpansionFlags( + multiple_columns=True, + has_nth=self.has_nth, + has_wildcard=self.has_wildcard, + has_selector=self.has_selector, + has_exclude=self.has_exclude, + ) + def prepare_projection( exprs: Sequence[ExprIR], schema: Mapping[str, DType] @@ -155,18 +167,44 @@ def prepare_projection( # - `exclude` is the return of `prepare_excluded` +# NOTE: The inner function is ready def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: - raise NotImplementedError + from narwhals._plan import expr + + def fn(child: ExprIR, /) -> ExprIR: + if not ( + isinstance(child, expr.FunctionExpr) + and child.options.is_input_wildcard_expansion() + ): + return child + return child.with_input(rewrite_projections(child.input, keys=(), schema=schema)) + + return origin.map_ir(fn) def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, - keys: Seq[ExprIR], - *, + keys: Seq[ + ExprIR + ], # NOTE: Mutable (empty) array initialized on call (except in `polars_plan::plans::conversion::dsl_to_ir::resolve_group_by`) + *, # NOTE: Represents group_by keys schema: FrozenSchema, ) -> Seq[ExprIR]: - raise NotImplementedError + # NOTE: This is where the mutable `result` is initialized + result_length = len(input) + len(schema) + result: deque[ExprIR] = deque(maxlen=result_length) + for expr in input: + expanded = expand_function_inputs(expr, schema=schema) + flags = ExpansionFlags.from_ir(expanded) + if 
flags.has_selector: + expanded = replace_selector(expanded, keys, schema=schema) + flags = flags.with_multiple_columns() + # NOTE: `result` is what I'd want as a return, rather than inplace + replace_and_add_to_results( + expanded, result, keys=keys, schema=schema, flags=flags + ) + return tuple(result) def replace_selector( @@ -201,7 +239,7 @@ def replace_selector_inner( def replace_and_add_to_results( origin: ExprIR, /, - result: Seq[ExprIR], + result: ResultIRs, keys: Seq[ExprIR], *, schema: FrozenSchema, @@ -227,7 +265,7 @@ def prepare_excluded( def expand_columns( origin: ExprIR, /, - result: Seq[ExprIR], + result: ResultIRs, columns: expr.Columns, # `polars` uses columns.names *, col_names: FrozenColumns, @@ -239,7 +277,7 @@ def expand_columns( def expand_dtypes( origin: ExprIR, /, - result: Seq[ExprIR], + result: ResultIRs, dtypes: selectors.ByDType, # we haven't got `DtypeColumn` *, schema: FrozenSchema, @@ -251,7 +289,7 @@ def expand_dtypes( def expand_indices( origin: ExprIR, /, - result: Seq[ExprIR], + result: ResultIRs, indices: expr.IndexColumns, *, schema: FrozenSchema, @@ -261,7 +299,7 @@ def expand_indices( def replace_wildcard( - origin: ExprIR, /, result: Seq[ExprIR], *, col_names: FrozenColumns, exclude: Excluded + origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded ) -> Inplace: raise NotImplementedError @@ -313,7 +351,7 @@ def dtypes_match(left: DType, right: DType | type[DType]) -> bool: def replace_regex( origin: ExprIR, /, - result: Seq[ExprIR], + result: ResultIRs, pattern: selectors.Matches, *, col_names: FrozenColumns, @@ -323,6 +361,6 @@ def replace_regex( def expand_regex( - origin: ExprIR, /, result: Seq[ExprIR], *, col_names: FrozenColumns, exclude: Excluded + origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded ) -> Inplace: raise NotImplementedError diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 2e9410ec14..6283c764d2 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -47,6 +47,9 @@ def is_length_preserving(self) -> bool: def is_row_separable(self) -> bool: return FunctionFlags.ROW_SEPARABLE in self + def is_input_wildcard_expansion(self) -> bool: + return FunctionFlags.INPUT_WILDCARD_EXPANSION in self + @staticmethod def default() -> FunctionFlags: return FunctionFlags.ALLOW_GROUP_AWARE @@ -81,6 +84,9 @@ def is_length_preserving(self) -> bool: def is_row_separable(self) -> bool: return self.flags.is_row_separable() + def is_input_wildcard_expansion(self) -> bool: + return self.flags.is_input_wildcard_expansion() + def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." 
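A note on the `ExprIR.map_ir` contract sketched in the patch above (its body is still a `NotImplementedError` at this point): the intended scheme is a post-order rewrite, where children are rebuilt first and the supplied function is then applied exactly once to the rebuilt node, which is how later commits in this series end up implementing it. Below is a minimal, self-contained sketch of that scheme; `Col`, `Alias`, `map_node` and `rename` are throwaway stand-ins for this illustration only, not the `narwhals._plan` classes.

from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Callable, Union


@dataclass(frozen=True)
class Col:
    name: str

    def map_node(self, fn: "Callable[[Node], Node]") -> "Node":
        return fn(self)  # leaf node: nothing to recurse into


@dataclass(frozen=True)
class Alias:
    expr: "Node"
    name: str

    def map_node(self, fn: "Callable[[Node], Node]") -> "Node":
        # rebuild the child first, then hand the rebuilt node to `fn` once
        return fn(replace(self, expr=self.expr.map_node(fn)))


Node = Union[Col, Alias]


def rename(to: str) -> "Callable[[Node], Node]":
    def fn(node: Node) -> Node:
        return replace(node, name=to) if isinstance(node, Alias) else node

    return fn


tree = Alias(expr=Alias(expr=Col(name="a"), name="b"), name="c")
assert tree.map_node(rename("z")) == Alias(
    expr=Alias(expr=Col(name="a"), name="z"), name="z"
)

Because `fn` runs once per node in a single bottom-up pass, even an "unguarded" rewrite like the one exercised in the next commit terminates; whether an extra pass is needed to detect cycles is the open question those test docstrings raise.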
From eb3298675f7fb652b40675ed39259a7edff74f20 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 6 Jun 2025 16:56:12 +0100 Subject: [PATCH 171/368] test: Plan out tests for `ExprIR.map_ir` --- tests/plan/expr_expansion_test.py | 68 +++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index d7b825087e..742df17e9d 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -1,23 +1,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Callable, Sequence import pytest import narwhals as nw import narwhals._plan.demo as nwd -from narwhals._plan.expr import Column, _ColumnSelection +from narwhals._plan.expr import Alias, Column, _ColumnSelection from narwhals._plan.expr_expansion import rewrite_special_aliases from narwhals.exceptions import ColumnNotFoundError, ComputeError if TYPE_CHECKING: - from typing_extensions import TypeIs + from typing_extensions import TypeAlias, TypeIs from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyExpr from narwhals.dtypes import DType +MapIR: TypeAlias = "Callable[[ExprIR], ExprIR]" + + @pytest.fixture def schema_1() -> dict[str, DType]: return { @@ -201,3 +204,62 @@ def test_rewrite_special_aliases_single(expr: DummyExpr, expected: str) -> None: assert ir_input != ir_output actual = ir_output.meta.output_name() assert actual == expected + + +def alias_replace_guarded(name: str) -> MapIR: + """Guards against repeatedly creating the same alias.""" + + def fn(ir: ExprIR) -> ExprIR: + if isinstance(ir, Alias) and ir.name != name: + return Alias(expr=ir.expr, name=name) + return ir + + return fn + + +def alias_replace_unguarded(name: str) -> MapIR: + """**Does not guard against recursion**! + + Handling the recursion stopping **should be** part of the impl of `ExprIR.map_ir`. 
+ + - *Ideally*, return an identical result to `alias_replace_guarded` (after the same number of iterations) + - *Pragmatically*, it might require an extra iteration to detect a cycle + """ + + def fn(ir: ExprIR) -> ExprIR: + if isinstance(ir, Alias): + return Alias(expr=ir.expr, name=name) + return ir + + return fn + + +@pytest.mark.xfail( + reason="Not implemented `ExprIR.map_ir` yet", raises=NotImplementedError +) +@pytest.mark.parametrize( + ("expr", "function", "into_expected"), + [ + (nwd.col("a"), alias_replace_guarded("never"), nwd.col("a")), + (nwd.col("a"), alias_replace_unguarded("never"), nwd.col("a")), + (nwd.col("a").alias("b"), alias_replace_guarded("c"), nwd.col("a").alias("c")), + (nwd.col("a").alias("b"), alias_replace_unguarded("c"), nwd.col("a").alias("c")), + ( + nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), + alias_replace_guarded("d"), + nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), + ), + ( + nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), + alias_replace_unguarded("d"), + nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), + ), + ], +) +def test_map_ir_recursive( + expr: DummyExpr, function: MapIR, into_expected: DummyExpr +) -> None: + ir = expr._ir + expected = into_expected._ir + actual = ir.map_ir(function) + assert actual == expected From 77c6321623cf1c87a41dd4244e63682c9a941f54 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 6 Jun 2025 19:35:52 +0100 Subject: [PATCH 172/368] feat: Add missing `DummyExpr.__str__` --- narwhals/_plan/dummy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index f5959e7c5b..97905c49db 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -60,6 +60,10 @@ class DummyExpr: def __repr__(self) -> str: return f"Narwhals DummyExpr ({self.version.name.lower()}):\n{self._ir!r}" + def __str__(self) -> str: + """Use `print(self)` for formatting.""" + return f"Narwhals DummyExpr ({self.version.name.lower()}):\n{self._ir!s}" + @classmethod def _from_ir(cls, ir: ExprIR, /) -> Self: obj = cls.__new__(cls) From 781dd275e2066c1dd39dafb928cc07dd0fac2b05 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 6 Jun 2025 19:38:22 +0100 Subject: [PATCH 173/368] cov --- tests/plan/expr_expansion_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 742df17e9d..cba1f35366 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -206,7 +206,7 @@ def test_rewrite_special_aliases_single(expr: DummyExpr, expected: str) -> None: assert actual == expected -def alias_replace_guarded(name: str) -> MapIR: +def alias_replace_guarded(name: str) -> MapIR: # pragma: no cover """Guards against repeatedly creating the same alias.""" def fn(ir: ExprIR) -> ExprIR: @@ -217,7 +217,7 @@ def fn(ir: ExprIR) -> ExprIR: return fn -def alias_replace_unguarded(name: str) -> MapIR: +def alias_replace_unguarded(name: str) -> MapIR: # pragma: no cover """**Does not guard against recursion**! Handling the recursion stopping **should be** part of the impl of `ExprIR.map_ir`. 
@@ -262,4 +262,4 @@ def test_map_ir_recursive( ir = expr._ir expected = into_expected._ir actual = ir.map_ir(function) - assert actual == expected + assert actual == expected # pragma: no cover From d6965c6b799740736a68f1440e7a69bb60bfa3a9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 7 Jun 2025 18:55:41 +0100 Subject: [PATCH 174/368] chore: fix `_utils` imports --- narwhals/_plan/demo.py | 2 +- narwhals/_plan/dummy.py | 2 +- narwhals/_plan/expr.py | 2 +- narwhals/_plan/selectors.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index d4a009b6b2..290a358b3c 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -21,9 +21,9 @@ from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.strings import ConcatHorizontal from narwhals._plan.when_then import When +from narwhals._utils import Version, flatten from narwhals.dtypes import DType from narwhals.exceptions import OrderDependentExprError -from narwhals.utils import Version, flatten if t.TYPE_CHECKING: from typing_extensions import TypeIs diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 97905c49db..96fb7f25c4 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -24,9 +24,9 @@ ) from narwhals._plan.selectors import by_name from narwhals._plan.window import Over +from narwhals._utils import Version, _hasattr_static from narwhals.dtypes import DType from narwhals.exceptions import ComputeError -from narwhals.utils import Version, _hasattr_static if TYPE_CHECKING: from typing_extensions import Never, Self diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 271ba04dd5..0f31319a4c 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -28,7 +28,7 @@ SelectorOperatorT, SelectorT, ) -from narwhals.utils import flatten +from narwhals._utils import flatten if t.TYPE_CHECKING: from typing_extensions import Self, TypeAlias diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 4360902e56..e2e3936ba7 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Iterable from narwhals._plan.common import Immutable, is_iterable_reject -from narwhals.utils import _parse_time_unit_and_time_zone +from narwhals._utils import _parse_time_unit_and_time_zone if TYPE_CHECKING: from datetime import timezone From 392015800ba4a9482250cbd58e28869c65bf3fd9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 12:27:34 +0100 Subject: [PATCH 175/368] feat(DRAFT): Fill out more of `expr_expansion` - Starting off with a *mostly* direct translation - There are likely going to be things that get factored out later --- narwhals/_plan/expr.py | 3 + narwhals/_plan/expr_expansion.py | 168 +++++++++++++++++++++++++------ narwhals/_plan/meta.py | 13 ++- 3 files changed, 153 insertions(+), 31 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 0f31319a4c..da3f12a373 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -116,6 +116,9 @@ def __repr__(self) -> str: def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: return plx.col(self.name) + def with_name(self, name: str, /) -> Column: + return self if name == self.name else Column(name=name) + def _col(name: str, /) -> Column: return Column(name=name) diff --git a/narwhals/_plan/expr_expansion.py 
b/narwhals/_plan/expr_expansion.py index 9e72135f01..4de7c38d48 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -41,12 +41,14 @@ from collections import deque from copy import deepcopy from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence from narwhals._plan.common import Immutable -from narwhals.exceptions import InvalidOperationError +from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: + import re + from typing_extensions import TypeAlias from narwhals._plan import expr, selectors @@ -182,6 +184,28 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) +def replace_nth(origin: ExprIR, /, schema: FrozenSchema) -> ExprIR: + from narwhals._plan import expr + + def fn(child: ExprIR, /) -> ExprIR: + if isinstance(child, expr.Nth): + return expr.Column(name=_freeze_columns(schema)[child.index]) + return child + + return origin.map_ir(fn) + + +def remove_exclude(origin: ExprIR, /) -> ExprIR: + from narwhals._plan import expr + + def fn(child: ExprIR, /) -> ExprIR: + if isinstance(child, expr.Exclude): + return child.expr + return child + + return origin.map_ir(fn) + + def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, @@ -207,6 +231,52 @@ def rewrite_projections( return tuple(result) +def replace_and_add_to_results( + origin: ExprIR, + /, + result: ResultIRs, + keys: Seq[ExprIR], + *, + schema: FrozenSchema, + flags: ExpansionFlags, +) -> Inplace: + from narwhals._plan import expr + + if flags.has_nth: + origin = replace_nth(origin, schema) + if flags.expands: + it = ( + e + for e in origin.iter_left() + if isinstance(e, (expr.Columns, expr.IndexColumns)) + ) + if e := next(it, None): + if isinstance(e, expr.Columns): + exclude = prepare_excluded( + origin, keys=(), schema=schema, has_exclude=flags.has_exclude + ) + expand_columns( + origin, result, e, col_names=_freeze_columns(schema), exclude=exclude + ) + else: + exclude = prepare_excluded( + origin, keys=keys, schema=schema, has_exclude=flags.has_exclude + ) + expand_indices(origin, result, e, schema=schema, exclude=exclude) + elif flags.has_wildcard: + exclude = prepare_excluded( + origin, keys=keys, schema=schema, has_exclude=flags.has_exclude + ) + replace_wildcard( + origin, result, col_names=_freeze_columns(schema), exclude=exclude + ) + else: + exclude = prepare_excluded( + origin, keys=keys, schema=schema, has_exclude=flags.has_exclude + ) + replace_regex(origin, result, col_names=_freeze_columns(schema), exclude=exclude) + + def replace_selector( ir: ExprIR, # an element of `FunctionExpr.input` /, @@ -217,6 +287,7 @@ def replace_selector( raise NotImplementedError +# TODO @dangotbanned: Huge def expand_selector( s: expr.SelectorIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema ) -> Seq[str]: @@ -236,32 +307,14 @@ def replace_selector_inner( raise NotImplementedError -def replace_and_add_to_results( - origin: ExprIR, - /, - result: ResultIRs, - keys: Seq[ExprIR], - *, - schema: FrozenSchema, - flags: ExpansionFlags, -) -> Inplace: - raise NotImplementedError - - -# NOTE: See how far we can get with just the direct node replacements -# - `polars` is using `map_expr`, but I haven't implemented that (yet?) 
-def replace_nth(nth: expr.Nth, /, col_names: FrozenColumns) -> expr.Column: - from narwhals._plan import expr - - return expr.Column(name=col_names[nth.index]) - - +# TODO @dangotbanned: Priority High def prepare_excluded( origin: ExprIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema, has_exclude: bool ) -> Excluded: raise NotImplementedError +# TODO @dangotbanned: Priority High def expand_columns( origin: ExprIR, /, @@ -274,6 +327,7 @@ def expand_columns( raise NotImplementedError +# TODO @dangotbanned: Priority Low def expand_dtypes( origin: ExprIR, /, @@ -286,6 +340,7 @@ def expand_dtypes( raise NotImplementedError +# TODO @dangotbanned: Priority Mid def expand_indices( origin: ExprIR, /, @@ -298,6 +353,7 @@ def expand_indices( raise NotImplementedError +# TODO @dangotbanned: Priority Mid def replace_wildcard( origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded ) -> Inplace: @@ -348,19 +404,75 @@ def dtypes_match(left: DType, right: DType | type[DType]) -> bool: return left == right +def into_pattern(obj: str | re.Pattern[str] | selectors.Matches, /) -> re.Pattern[str]: + import re + + from narwhals._plan import selectors + + if isinstance(obj, str): + return re.compile(obj) + elif isinstance(obj, selectors.Matches): + return obj.pattern + elif isinstance(obj, re.Pattern): + return obj + else: + msg = f"Cannot convert {type(obj).__name__!r} into a regular expression" + raise TypeError(msg) + + +def is_regex_projection(name: str) -> bool: + return name.startswith("^") and name.endswith("$") + + +# NOTE: Will likely be using `selectors.Matches` for this +# Doing a direct translation from `rust` *first*, to make replacing +# the deviations *later* not as daunting def replace_regex( + origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded +) -> Inplace: + regex: str | None = None + for name in origin.meta.root_names(): + if is_regex_projection(name): + if regex is None: + regex = name + expand_regex( + origin, + result, + into_pattern(name), + col_names=col_names, + exclude=exclude, + ) + elif regex != name: + msg = "an expression is not allowed to have different regexes" + raise ComputeError(msg) + if regex is None: + origin = rewrite_special_aliases(origin) + result.append(origin) + + +def expand_regex( origin: ExprIR, /, result: ResultIRs, - pattern: selectors.Matches, + pattern: re.Pattern[str], *, col_names: FrozenColumns, exclude: Excluded, ) -> Inplace: - raise NotImplementedError + for name in col_names: + if pattern.match(name) and name not in exclude: + expanded = remove_exclude(origin) + expanded = expanded.map_ir(_replace_regex(pattern, name)) + expanded = rewrite_special_aliases(expanded) + result.append(expanded) -def expand_regex( - origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded -) -> Inplace: - raise NotImplementedError +def _replace_regex(pattern: re.Pattern[str], name: str, /) -> Callable[[ExprIR], ExprIR]: + from narwhals._plan.meta import is_column + + pat = pattern.pattern + + def fn(ir: ExprIR, /) -> ExprIR: + return ir.with_name(name) if is_column(ir) and ir.name == pat else ir + + return fn diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 18461f5900..0f66560636 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -15,7 +15,10 @@ if TYPE_CHECKING: from typing import Iterator + from typing_extensions import TypeIs + from narwhals._plan.common import ExprIR + from narwhals._plan.expr import Column class IRMetaNamespace(IRNamespace): @@ 
-25,9 +28,7 @@ def has_multiple_outputs(self) -> bool: return any(_has_multiple_outputs(e) for e in self._ir.iter_left()) def is_column(self) -> bool: - from narwhals._plan.expr import Column - - return isinstance(self._ir, Column) + return is_column(self._ir) def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: return all( @@ -184,6 +185,12 @@ def has_expr_ir(ir: ExprIR, *matches: type[ExprIR]) -> bool: return any(isinstance(e, matches) for e in ir.iter_right()) +def is_column(ir: ExprIR) -> TypeIs[Column]: + from narwhals._plan.expr import Column + + return isinstance(ir, Column) + + def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan import expr from narwhals._plan.literal import ScalarLiteral From f85a827f3303b879116e476facbf5fef3aa3ce82 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 14:15:31 +0100 Subject: [PATCH 176/368] feat: Impl `prepare_excluded`, return from `replace_regex` --- narwhals/_plan/common.py | 4 ++ narwhals/_plan/expr.py | 14 ++++++- narwhals/_plan/expr_expansion.py | 64 ++++++++++++++++++++------------ 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 11398b7b19..7ab92b599b 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -350,6 +350,10 @@ def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | DummySeries]: return isinstance(obj, (str, bytes, DummySeries)) +def is_regex_projection(name: str) -> bool: + return name.startswith("^") and name.endswith("$") + + def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[DType]] = { diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index da3f12a373..5a7bb3499d 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -7,7 +7,13 @@ import typing as t from narwhals._plan.aggregation import Agg, OrderableAgg -from narwhals._plan.common import ExprIR, SelectorIR, _field_str, is_non_nested_literal +from narwhals._plan.common import ( + ExprIR, + SelectorIR, + _field_str, + is_non_nested_literal, + is_regex_projection, +) from narwhals._plan.exceptions import ( alias_duplicate_error, column_not_found_error, @@ -213,7 +219,11 @@ class Exclude(_ColumnSelection): @staticmethod def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: - return Exclude(expr=expr, names=tuple(flatten(names))) + flat = flatten(names) + if any(is_regex_projection(nm) for nm in flat): + msg = f"Using regex in `exclude(...)` is not yet supported.\nnames={flat!r}" + raise NotImplementedError(msg) + return Exclude(expr=expr, names=tuple(flat)) def __repr__(self) -> str: return f"{self.expr!r}.exclude({list(self.names)!r})" diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 4de7c38d48..b98bfea332 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -41,9 +41,9 @@ from collections import deque from copy import deepcopy from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence -from narwhals._plan.common import Immutable +from narwhals._plan.common import Immutable, is_regex_projection from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: @@ -196,10 +196,10 @@ def fn(child: ExprIR, /) -> ExprIR: 
def remove_exclude(origin: ExprIR, /) -> ExprIR: - from narwhals._plan import expr + from narwhals._plan.expr import Exclude def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, expr.Exclude): + if isinstance(child, Exclude): return child.expr return child @@ -252,29 +252,26 @@ def replace_and_add_to_results( ) if e := next(it, None): if isinstance(e, expr.Columns): - exclude = prepare_excluded( - origin, keys=(), schema=schema, has_exclude=flags.has_exclude - ) + exclude = prepare_excluded(origin, keys=(), has_exclude=flags.has_exclude) expand_columns( origin, result, e, col_names=_freeze_columns(schema), exclude=exclude ) else: exclude = prepare_excluded( - origin, keys=keys, schema=schema, has_exclude=flags.has_exclude + origin, keys=keys, has_exclude=flags.has_exclude ) expand_indices(origin, result, e, schema=schema, exclude=exclude) elif flags.has_wildcard: - exclude = prepare_excluded( - origin, keys=keys, schema=schema, has_exclude=flags.has_exclude - ) + exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) replace_wildcard( origin, result, col_names=_freeze_columns(schema), exclude=exclude ) else: - exclude = prepare_excluded( - origin, keys=keys, schema=schema, has_exclude=flags.has_exclude + exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) + # NOTE: First case transitioned to return result! + result = replace_regex( + origin, result, col_names=_freeze_columns(schema), exclude=exclude ) - replace_regex(origin, result, col_names=_freeze_columns(schema), exclude=exclude) def replace_selector( @@ -307,11 +304,32 @@ def replace_selector_inner( raise NotImplementedError -# TODO @dangotbanned: Priority High +def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]: + """Yield all excluded names in `origin`.""" + from narwhals._plan.expr import Exclude + + for e in origin.iter_left(): + if isinstance(e, Exclude): + yield from e.names + + def prepare_excluded( - origin: ExprIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema, has_exclude: bool + origin: ExprIR, /, keys: Seq[ExprIR], *, has_exclude: bool ) -> Excluded: - raise NotImplementedError + """Huge simplification of [`polars_plan::plans::conversion::expr_expansion::prepare_excluded`]. 
+ + - `DTypes` are not allowed + - regex in `exclude(...)` is not allowed + + [`polars_plan::plans::conversion::expr_expansion::prepare_excluded`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555 + """ + exclude: set[str] = set() + if has_exclude: + exclude.update(_iter_exclude_names(origin)) + for group_by_key in keys: + if name := group_by_key.meta.output_name(raise_if_undetermined=False): + exclude.add(name) + return frozenset(exclude) # TODO @dangotbanned: Priority High @@ -420,22 +438,18 @@ def into_pattern(obj: str | re.Pattern[str] | selectors.Matches, /) -> re.Patter raise TypeError(msg) -def is_regex_projection(name: str) -> bool: - return name.startswith("^") and name.endswith("$") - - # NOTE: Will likely be using `selectors.Matches` for this # Doing a direct translation from `rust` *first*, to make replacing # the deviations *later* not as daunting def replace_regex( origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded -) -> Inplace: +) -> ResultIRs: regex: str | None = None for name in origin.meta.root_names(): if is_regex_projection(name): if regex is None: regex = name - expand_regex( + result = expand_regex( origin, result, into_pattern(name), @@ -448,6 +462,7 @@ def replace_regex( if regex is None: origin = rewrite_special_aliases(origin) result.append(origin) + return result def expand_regex( @@ -458,13 +473,14 @@ def expand_regex( *, col_names: FrozenColumns, exclude: Excluded, -) -> Inplace: +) -> ResultIRs: for name in col_names: if pattern.match(name) and name not in exclude: expanded = remove_exclude(origin) expanded = expanded.map_ir(_replace_regex(pattern, name)) expanded = rewrite_special_aliases(expanded) result.append(expanded) + return result def _replace_regex(pattern: re.Pattern[str], name: str, /) -> Callable[[ExprIR], ExprIR]: From cfdacf24362a028040d6bce42bd92502e8a519a2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 15:04:59 +0100 Subject: [PATCH 177/368] feat: Impl `expand_columns` --- narwhals/_plan/expr_expansion.py | 47 ++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index b98bfea332..28f2573eab 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -206,6 +206,23 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) +def _replace_columns_exclude(origin: ExprIR, /, name: str) -> ExprIR: + """Based on the anonymous function in [`polars_plan::plans::conversion::expr_expansion::expand_columns`]. 
+ + [`polars_plan::plans::conversion::expr_expansion::expand_columns`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L187-L191 + """ + from narwhals._plan.expr import Column, Columns, Exclude + + def fn(child: ExprIR, /) -> ExprIR: + if isinstance(child, Columns): + return Column(name=name) + if isinstance(child, Exclude): + return child.expr + return child + + return origin.map_ir(fn) + + def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, @@ -253,7 +270,8 @@ def replace_and_add_to_results( if e := next(it, None): if isinstance(e, expr.Columns): exclude = prepare_excluded(origin, keys=(), has_exclude=flags.has_exclude) - expand_columns( + # NOTE: Transitioned to return result + result = expand_columns( origin, result, e, col_names=_freeze_columns(schema), exclude=exclude ) else: @@ -268,7 +286,7 @@ def replace_and_add_to_results( ) else: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - # NOTE: First case transitioned to return result! + # NOTE: Transitioned to return result result = replace_regex( origin, result, col_names=_freeze_columns(schema), exclude=exclude ) @@ -332,17 +350,30 @@ def prepare_excluded( return frozenset(exclude) -# TODO @dangotbanned: Priority High +def _all_columns_match(origin: ExprIR, /, columns: expr.Columns) -> bool: + from narwhals._plan.expr import Columns + + it = (e == columns if isinstance(e, Columns) else True for e in origin.iter_left()) + return all(it) + + def expand_columns( origin: ExprIR, /, result: ResultIRs, - columns: expr.Columns, # `polars` uses columns.names + columns: expr.Columns, *, col_names: FrozenColumns, exclude: Excluded, -) -> Inplace: - raise NotImplementedError +) -> ResultIRs: + if not _all_columns_match(origin, columns): + msg = "expanding more than one `col` is not allowed" + raise ComputeError(msg) + for name in columns.names: + if name not in exclude: + new_expr = _replace_columns_exclude(origin, name) + result = replace_regex(new_expr, result, col_names=col_names, exclude=exclude) + return result # TODO @dangotbanned: Priority Low @@ -358,7 +389,7 @@ def expand_dtypes( raise NotImplementedError -# TODO @dangotbanned: Priority Mid +# TODO @dangotbanned: Priority High def expand_indices( origin: ExprIR, /, @@ -371,7 +402,7 @@ def expand_indices( raise NotImplementedError -# TODO @dangotbanned: Priority Mid +# TODO @dangotbanned: Priority High def replace_wildcard( origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded ) -> Inplace: From 04427dbfd976fbec90f96b5084ebff069087fe35 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 15:51:37 +0100 Subject: [PATCH 178/368] feat: Impl `expand_indices`, `replace_dtype_or_index_with_column` --- narwhals/_plan/expr_expansion.py | 45 +++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 28f2573eab..9944b8749d 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -223,6 +223,25 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) +def replace_dtype_or_index_with_column( + origin: ExprIR, /, name: str, *, replace_dtype: bool = False +) -> ExprIR: + from narwhals._plan.expr import Column, Exclude, IndexColumns + + if replace_dtype: + msg = "We don't have a `Expr::DtypeColumn` node yet, may need to add 
one for the selectors lift?" + raise NotImplementedError(msg) + + def fn(child: ExprIR, /) -> ExprIR: + if isinstance(child, IndexColumns): + return Column(name=name) + if isinstance(child, Exclude): + return child.expr + return child + + return origin.map_ir(fn) + + def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, @@ -278,7 +297,8 @@ def replace_and_add_to_results( exclude = prepare_excluded( origin, keys=keys, has_exclude=flags.has_exclude ) - expand_indices(origin, result, e, schema=schema, exclude=exclude) + # NOTE: Transitioned to return result + result = expand_indices(origin, result, e, schema=schema, exclude=exclude) elif flags.has_wildcard: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) replace_wildcard( @@ -389,7 +409,6 @@ def expand_dtypes( raise NotImplementedError -# TODO @dangotbanned: Priority High def expand_indices( origin: ExprIR, /, @@ -398,8 +417,20 @@ def expand_indices( *, schema: FrozenSchema, exclude: Excluded, -) -> Inplace: - raise NotImplementedError +) -> ResultIRs: + n_fields = len(schema) + names = tuple(schema) + for index in indices.indices: + idx = index + n_fields if index < 0 else index + if idx < 0 or idx > n_fields: + msg = f"invalid column index {idx!r}" + raise ComputeError(msg) + name = names[idx] + if name not in exclude: + new_expr = replace_dtype_or_index_with_column(origin, name) + new_expr = rewrite_special_aliases(new_expr) + result.append(new_expr) + return result # TODO @dangotbanned: Priority High @@ -443,12 +474,6 @@ def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: return origin -def replace_dtype_or_index_with_column( - origin: ExprIR, /, column_name: str, *, replace_dtype: bool -) -> ExprIR: - raise NotImplementedError - - def dtypes_match(left: DType, right: DType | type[DType]) -> bool: return left == right From 7b3641bbf1481855d7e16f004ed4e81e0878814a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 16:03:00 +0100 Subject: [PATCH 179/368] feat: Impl `replace_wildcard`, `replace_wildcard_with_column` --- narwhals/_plan/expr_expansion.py | 37 +++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 9944b8749d..fa8d7af5ec 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -242,6 +242,20 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) +def replace_wildcard_with_column(origin: ExprIR, /, name: str) -> ExprIR: + """`expr.All` and `Exclude`.""" + from narwhals._plan.expr import All, Column, Exclude + + def fn(child: ExprIR, /) -> ExprIR: + if isinstance(child, All): + return Column(name=name) + if isinstance(child, Exclude): + return child.expr + return child + + return origin.map_ir(fn) + + def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, @@ -275,7 +289,7 @@ def replace_and_add_to_results( *, schema: FrozenSchema, flags: ExpansionFlags, -) -> Inplace: +) -> ResultIRs: from narwhals._plan import expr if flags.has_nth: @@ -289,7 +303,6 @@ def replace_and_add_to_results( if e := next(it, None): if isinstance(e, expr.Columns): exclude = prepare_excluded(origin, keys=(), has_exclude=flags.has_exclude) - # NOTE: Transitioned to return result result = expand_columns( origin, result, e, col_names=_freeze_columns(schema), exclude=exclude ) @@ -297,19 +310,18 @@ def replace_and_add_to_results( exclude = prepare_excluded( origin, keys=keys, 
has_exclude=flags.has_exclude ) - # NOTE: Transitioned to return result result = expand_indices(origin, result, e, schema=schema, exclude=exclude) elif flags.has_wildcard: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - replace_wildcard( + result = replace_wildcard( origin, result, col_names=_freeze_columns(schema), exclude=exclude ) else: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - # NOTE: Transitioned to return result result = replace_regex( origin, result, col_names=_freeze_columns(schema), exclude=exclude ) + return result def replace_selector( @@ -433,16 +445,15 @@ def expand_indices( return result -# TODO @dangotbanned: Priority High def replace_wildcard( origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded -) -> Inplace: - raise NotImplementedError - - -def replace_wildcard_with_column(origin: ExprIR, /, column_name: str) -> ExprIR: - """`expr.All` and `Exclude`.""" - raise NotImplementedError +) -> ResultIRs: + for name in col_names: + if name not in exclude: + new_expr = replace_wildcard_with_column(origin, name) + new_expr = rewrite_special_aliases(new_expr) + result.append(new_expr) + return result def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: From f76c9ddbafd367c05d99b572140caadfba1bd1f7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 16:10:08 +0100 Subject: [PATCH 180/368] docs(DRAFT): Add more notes on selectors todo Last? major hurdle before working on `map_ir`/vistor pattern impl --- narwhals/_plan/expr_expansion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index fa8d7af5ec..b9b8fc24e3 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -324,6 +324,7 @@ def replace_and_add_to_results( return result +# TODO @dangotbanned: Priority high (entry, called by `rewrite_projections`) def replace_selector( ir: ExprIR, # an element of `FunctionExpr.input` /, @@ -334,7 +335,7 @@ def replace_selector( raise NotImplementedError -# TODO @dangotbanned: Huge +# TODO @dangotbanned: Huge, called by `replace_selector` def expand_selector( s: expr.SelectorIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema ) -> Seq[str]: @@ -342,6 +343,7 @@ def expand_selector( raise NotImplementedError +# TODO @dangotbanned: Huge, called by `expand_selector` def replace_selector_inner( s: expr.SelectorIR, /, From f42e202886c6fa1f76f5b44359ffa17f0b23f1cd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 22:00:19 +0100 Subject: [PATCH 181/368] feat: Impl `replace_selector` --- narwhals/_plan/common.py | 10 ++++++ narwhals/_plan/expr.py | 11 ++++++ narwhals/_plan/expr_expansion.py | 58 ++++++++++++-------------------- narwhals/_plan/selectors.py | 35 +++++++++++++++++++ 4 files changed, 78 insertions(+), 36 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 7ab92b599b..c84ab834f5 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -253,6 +253,16 @@ def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: return dummy.DummySelector._from_ir(self) return dummy.DummySelectorV1._from_ir(self) + def matches_column(self, name: str, dtype: DType) -> bool: + """Return True if we can select this column. + + - Thinking that we could get more cache hits on an individual column basis. 
+ - May also be more efficient to not iterate over the schema for every selector + - Instead do one pass, evaluating every selector against a single column at a time + - Is that possible? + """ + raise NotImplementedError(type(self)) + class IRNamespace(Immutable): __slots__ = ("_ir",) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 5a7bb3499d..9e2850f360 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -571,6 +571,9 @@ class RootSelector(SelectorIR): def __repr__(self) -> str: return f"{self.selector!r}" + def matches_column(self, name: str, dtype: DType) -> bool: + return self.selector.matches_column(name, dtype) + class BinarySelector( _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], @@ -583,6 +586,11 @@ class BinarySelector( `left` and `right` may also nest other `BinarySelector`s. """ + def matches_column(self, name: str, dtype: DType) -> bool: + left = self.left.matches_column(name, dtype) + right = self.right.matches_column(name, dtype) + return bool(self.op(left, right)) + class InvertSelector(SelectorIR, t.Generic[SelectorT]): __slots__ = ("selector",) @@ -593,6 +601,9 @@ class InvertSelector(SelectorIR, t.Generic[SelectorT]): def __repr__(self) -> str: return f"~{self.selector!r}" + def matches_column(self, name: str, dtype: DType) -> bool: + return not self.selector.matches_column(name, dtype) + class Ternary(ExprIR): """When-Then-Otherwise.""" diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index b9b8fc24e3..a63a30f9e3 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -43,7 +43,7 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence -from narwhals._plan.common import Immutable, is_regex_projection +from narwhals._plan.common import ExprIR, Immutable, SelectorIR, is_regex_projection from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: @@ -52,7 +52,7 @@ from typing_extensions import TypeAlias from narwhals._plan import expr, selectors - from narwhals._plan.common import ExprIR, Seq + from narwhals._plan.common import Seq from narwhals._plan.dummy import DummyExpr from narwhals.dtypes import DType @@ -169,7 +169,6 @@ def prepare_projection( # - `exclude` is the return of `prepare_excluded` -# NOTE: The inner function is ready def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: from narwhals._plan import expr @@ -256,6 +255,25 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) +def replace_selector( + ir: ExprIR, + /, + keys: Seq[ExprIR], # noqa: ARG001 + *, + schema: FrozenSchema, +) -> ExprIR: + """Fully diverging from `polars`, we'll see how that goes.""" + from narwhals._plan import expr + + def fn(child: ExprIR, /) -> ExprIR: + if isinstance(child, SelectorIR): + cols = (k for k, v in schema.items() if child.matches_column(k, v)) + return expr.Columns(names=tuple(cols)) + return child + + return ir.map_ir(fn) + + def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, @@ -275,7 +293,7 @@ def rewrite_projections( expanded = replace_selector(expanded, keys, schema=schema) flags = flags.with_multiple_columns() # NOTE: `result` is what I'd want as a return, rather than inplace - replace_and_add_to_results( + result = replace_and_add_to_results( expanded, result, keys=keys, schema=schema, flags=flags ) return tuple(result) @@ -324,38 +342,6 @@ def replace_and_add_to_results( return result -# TODO @dangotbanned: 
Priority high (entry, called by `rewrite_projections`) -def replace_selector( - ir: ExprIR, # an element of `FunctionExpr.input` - /, - keys: Seq[ExprIR], - *, - schema: FrozenSchema, -) -> ExprIR: - raise NotImplementedError - - -# TODO @dangotbanned: Huge, called by `replace_selector` -def expand_selector( - s: expr.SelectorIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema -) -> Seq[str]: - """Converts into input of `Columns(...)`.""" - raise NotImplementedError - - -# TODO @dangotbanned: Huge, called by `expand_selector` -def replace_selector_inner( - s: expr.SelectorIR, - /, - keys: Seq[ExprIR], - members: Any, # mutable, insertion order preserving set `PlIndexSet` - scratch: Seq[ExprIR], # passed as `result` into `replace_and_add_to_results` - *, - schema: FrozenSchema, -) -> Inplace: - raise NotImplementedError - - def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]: """Yield all excluded names in `origin`.""" from narwhals._plan.expr import Exclude diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index e2e3936ba7..9aec0d5a19 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -9,6 +9,7 @@ import re from typing import TYPE_CHECKING, Iterable +from narwhals import dtypes from narwhals._plan.common import Immutable, is_iterable_reject from narwhals._utils import _parse_time_unit_and_time_zone @@ -30,11 +31,17 @@ def to_selector(self) -> RootSelector: return RootSelector(selector=self) + def matches_column(self, name: str, dtype: DType) -> bool: + raise NotImplementedError(type(self)) + class All(Selector): def __repr__(self) -> str: return "ncs.all()" + def matches_column(self, name: str, dtype: DType) -> bool: + return True + class ByDType(Selector): __slots__ = ("dtypes",) @@ -53,16 +60,25 @@ def __repr__(self) -> str: ) return f"ncs.by_dtype(dtypes=[{els}])" + def matches_column(self, name: str, dtype: DType) -> bool: + return dtype in self.dtypes + class Boolean(Selector): def __repr__(self) -> str: return "ncs.boolean()" + def matches_column(self, name: str, dtype: DType) -> bool: + return isinstance(dtype, dtypes.Boolean) + class Categorical(Selector): def __repr__(self) -> str: return "ncs.categorical()" + def matches_column(self, name: str, dtype: DType) -> bool: + return isinstance(dtype, dtypes.Categorical) + class Datetime(Selector): """Should swallow the [`utils` functions]. 
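The per-column `matches_column` predicates added in this patch are what allow `replace_selector` (earlier in the same patch) to expand a selector in a single pass over the schema, one column at a time, as the `SelectorIR.matches_column` docstring suggests. A rough standalone illustration of that single pass follows; the `schema` literal and the `is_numeric` predicate are invented for the example and are not narwhals API.

schema = {"a": "Int64", "b": "String", "c": "Float64"}


def is_numeric(name: str, dtype: str) -> bool:
    # stand-in for Selector.matches_column(name, dtype)
    return dtype in {"Int64", "Float64"}


# mirrors replace_selector: a SelectorIR child is rewritten to Columns(names=...)
names = tuple(name for name, dtype in schema.items() if is_numeric(name, dtype))
assert names == ("a", "c")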
@@ -89,6 +105,16 @@ def from_time_unit_and_time_zone( def __repr__(self) -> str: return f"ncs.datetime(time_unit={list(self.time_units)}, time_zone={list(self.time_zones)})" + def matches_column(self, name: str, dtype: DType) -> bool: + units, zones = self.time_units, self.time_zones + return ( + isinstance(dtype, dtypes.Datetime) + and (dtype.time_unit in units) + and ( + dtype.time_zone in zones or ("*" in zones and dtype.time_zone is not None) + ) + ) + class Matches(Selector): __slots__ = ("pattern",) @@ -109,16 +135,25 @@ def from_names(*names: str | Iterable[str]) -> Matches: def __repr__(self) -> str: return f"ncs.matches(pattern={self.pattern.pattern!r})" + def matches_column(self, name: str, dtype: DType) -> bool: + return bool(self.pattern.search(name)) + class Numeric(Selector): def __repr__(self) -> str: return "ncs.numeric()" + def matches_column(self, name: str, dtype: DType) -> bool: + return dtype.is_numeric() + class String(Selector): def __repr__(self) -> str: return "ncs.string()" + def matches_column(self, name: str, dtype: DType) -> bool: + return isinstance(dtype, dtypes.String) + def all() -> DummySelector: return All().to_selector().to_narwhals() From cfe3229d2cd985fc27fc1b94034beb999563885a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 22:05:54 +0100 Subject: [PATCH 182/368] revert: Remove unplanned dtypes stuff and comments --- narwhals/_plan/expr_expansion.py | 38 ++------------------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index a63a30f9e3..d6d6789004 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -159,16 +159,6 @@ def prepare_projection( return rewritten, frozen_schema -# NOTE: Parameters have been re-ordered, renamed, changed types -# - `origin` is the `Expr` that's being iterated over -# - `result` *haven't got to yet* -# - Couldn't this just be the return type? -# - Certainly less complicated in python -# - `` is the current child of `origin` -# - `col_names: FrozenColumns` is used when we don't need the dtypes -# - `exclude` is the return of `prepare_excluded` - - def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: from narwhals._plan import expr @@ -222,15 +212,9 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) -def replace_dtype_or_index_with_column( - origin: ExprIR, /, name: str, *, replace_dtype: bool = False -) -> ExprIR: +def replace_index_with_column(origin: ExprIR, /, name: str) -> ExprIR: from narwhals._plan.expr import Column, Exclude, IndexColumns - if replace_dtype: - msg = "We don't have a `Expr::DtypeColumn` node yet, may need to add one for the selectors lift?" 
- raise NotImplementedError(msg) - def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, IndexColumns): return Column(name=name) @@ -292,7 +276,6 @@ def rewrite_projections( if flags.has_selector: expanded = replace_selector(expanded, keys, schema=schema) flags = flags.with_multiple_columns() - # NOTE: `result` is what I'd want as a return, rather than inplace result = replace_and_add_to_results( expanded, result, keys=keys, schema=schema, flags=flags ) @@ -396,19 +379,6 @@ def expand_columns( return result -# TODO @dangotbanned: Priority Low -def expand_dtypes( - origin: ExprIR, - /, - result: ResultIRs, - dtypes: selectors.ByDType, # we haven't got `DtypeColumn` - *, - schema: FrozenSchema, - exclude: Excluded, -) -> Inplace: - raise NotImplementedError - - def expand_indices( origin: ExprIR, /, @@ -427,7 +397,7 @@ def expand_indices( raise ComputeError(msg) name = names[idx] if name not in exclude: - new_expr = replace_dtype_or_index_with_column(origin, name) + new_expr = replace_index_with_column(origin, name) new_expr = rewrite_special_aliases(new_expr) result.append(new_expr) return result @@ -473,10 +443,6 @@ def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: return origin -def dtypes_match(left: DType, right: DType | type[DType]) -> bool: - return left == right - - def into_pattern(obj: str | re.Pattern[str] | selectors.Matches, /) -> re.Pattern[str]: import re From 804ac3da4f25cd9218bc5dacc2257224057fc55c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 22:07:44 +0100 Subject: [PATCH 183/368] =?UTF-8?q?chore:=20Remove=20factored-out=20`Inpla?= =?UTF-8?q?ce`=20=F0=9F=A5=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- narwhals/_plan/expr_expansion.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index d6d6789004..8098f8ff2f 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -41,7 +41,7 @@ from collections import deque from copy import deepcopy from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Callable, Iterator, Mapping, Sequence from narwhals._plan.common import ExprIR, Immutable, SelectorIR, is_regex_projection from narwhals.exceptions import ComputeError, InvalidOperationError @@ -62,12 +62,6 @@ Excluded: TypeAlias = "frozenset[str]" """Internally use a `set`, then freeze before returning.""" -Inplace: TypeAlias = Any -"""Functions where `polars` does in-place mutations on `Expr`. - -Very likely that **we won't** do this in `narwhals`, instead return a new object. 
-""" - ResultIRs: TypeAlias = "deque[ExprIR]" From 353ef597189f46148cce04d1e3b527b86d1b60e5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 8 Jun 2025 22:10:16 +0100 Subject: [PATCH 184/368] chore: use `Version.dtypes` https://results.pre-commit.ci/run/github/760058710/1749416767.4pbB92p5RLmBWJdFLtx-yg --- narwhals/_plan/selectors.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 9aec0d5a19..0fa94bfec7 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -9,9 +9,8 @@ import re from typing import TYPE_CHECKING, Iterable -from narwhals import dtypes from narwhals._plan.common import Immutable, is_iterable_reject -from narwhals._utils import _parse_time_unit_and_time_zone +from narwhals._utils import Version, _parse_time_unit_and_time_zone if TYPE_CHECKING: from datetime import timezone @@ -24,6 +23,8 @@ T = TypeVar("T") +dtypes = Version.MAIN.dtypes + class Selector(Immutable): def to_selector(self) -> RootSelector: From 1d6332607ff179ba97168f558c7f9b34b3e52027 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 9 Jun 2025 15:21:54 +0100 Subject: [PATCH 185/368] feat: Impl `ExprIR.map_ir` for most nodes https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2134750336 TODO: - `FunctionExpr` - `RollingExpr` - `AnonymousExpr` - `WindowExpr` --- narwhals/_plan/aggregation.py | 12 ++++ narwhals/_plan/common.py | 6 +- narwhals/_plan/expr.py | 113 ++++++++++++++++++++++++++++++ narwhals/_plan/name.py | 35 +++++++++ narwhals/_plan/typing.py | 10 +++ tests/plan/expr_expansion_test.py | 37 +++++++--- 6 files changed, 200 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index b4cf4fb44f..e8013cc02b 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -8,6 +8,9 @@ if TYPE_CHECKING: from typing import Iterator + from typing_extensions import Self + + from narwhals._plan.typing import MapIR from narwhals.typing import RollingInterpolationMethod @@ -36,6 +39,15 @@ def iter_right(self) -> Iterator[ExprIR]: yield self yield from self.expr.iter_right() + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self.with_expr(self.expr.map_ir(function))) + + def with_expr(self, expr: ExprIR, /) -> Self: + if expr == self.expr: + return self + it = ((k, v) for k, v in self.__immutable_items__ if k != "expr") + return type(self)(expr=expr, **dict(it)) + def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: if expr.is_scalar: raise agg_scalar_error(self, expr) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index c84ab834f5..812de34122 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -4,7 +4,7 @@ from decimal import Decimal from typing import TYPE_CHECKING, Generic, TypeVar -from narwhals._plan.typing import ExprT, IRNamespaceT, Ns +from narwhals._plan.typing import ExprT, IRNamespaceT, MapIR, Ns from narwhals.utils import Version if TYPE_CHECKING: @@ -174,7 +174,7 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: def is_scalar(self) -> bool: return False - def map_ir(self, function: Callable[[ExprIR], ExprIR], /) -> ExprIR: + def map_ir(self, function: MapIR, /) -> ExprIR: """Apply `function` to each child node, returning a new `ExprIR`. See [`polars_plan::plans::iterator::Expr.map_expr`] and [`polars_plan::plans::visitor::visitors`]. 
@@ -182,7 +182,7 @@ def map_ir(self, function: Callable[[ExprIR], ExprIR], /) -> ExprIR: [`polars_plan::plans::iterator::Expr.map_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/iterator.rs#L152-L159 [`polars_plan::plans::visitor::visitors`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/visitor/visitors.rs """ - msg = "Need to handle recursive visiting first!" + msg = f"Need to handle recursive visiting first for {type(self).__qualname__!r}!\n\n{self!r}" raise NotImplementedError(msg) def iter_left(self) -> Iterator[ExprIR]: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 9e2850f360..70f78cdf6f 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -25,11 +25,14 @@ FunctionT, LeftSelectorT, LeftT, + LeftT2, LiteralT, + MapIR, Ns, OperatorT, RightSelectorT, RightT, + RightT2, RollingT, SelectorOperatorT, SelectorT, @@ -110,6 +113,12 @@ def __init__(self, *, expr: ExprIR, name: str) -> None: kwds = {"expr": expr, "name": name} super().__init__(**kwds) + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self.with_expr(self.expr.map_ir(function))) + + def with_expr(self, expr: ExprIR, /) -> Self: + return self if expr == self.expr else type(self)(expr=expr, name=self.name) + class Column(ExprIR): __slots__ = ("name",) @@ -125,6 +134,9 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: def with_name(self, name: str, /) -> Column: return self if name == self.name else Column(name=name) + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self) + def _col(name: str, /) -> Column: return Column(name=name) @@ -141,6 +153,9 @@ def expand_columns(self, schema: _Schema, /) -> Seq[Column]: """Transform selection in context of `schema` into simpler nodes.""" raise NotImplementedError + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self) + class Columns(_ColumnSelection): __slots__ = ("names",) @@ -234,6 +249,12 @@ def expand_columns(self, schema: _Schema) -> Seq[Column]: raise NotImplementedError(msg) return _cols(name for name in schema if name not in self.names) + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self.with_expr(self.expr.map_ir(function))) + + def with_expr(self, expr: ExprIR, /) -> Self: + return self if expr == self.expr else type(self)(expr=expr, names=self.names) + class Literal(ExprIR, t.Generic[LiteralT]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" @@ -266,6 +287,9 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: def unwrap(self) -> LiteralT: return self.value.unwrap() + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self) + class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): __slots__ = ("left", "op", "right") @@ -297,6 +321,23 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from self.right.iter_right() yield from self.left.iter_right() + def with_left(self, left: LeftT2, /) -> BinaryExpr[LeftT2, OperatorT, RightT]: + if left == self.left: + return t.cast("BinaryExpr[LeftT2, OperatorT, RightT]", self) + return BinaryExpr(left=left, op=self.op, right=self.right) + + def with_right(self, right: RightT2, /) -> BinaryExpr[LeftT, OperatorT, RightT2]: + if right == self.right: + return t.cast("BinaryExpr[LeftT, OperatorT, RightT2]", self) + return BinaryExpr(left=self.left, op=self.op, right=right) + + def map_ir(self, function: 
MapIR, /) -> ExprIR: + return function( + self.with_left(self.left.map_ir(function)).with_right( + self.right.map_ir(function) + ) + ) + class Cast(ExprIR): __slots__ = ("dtype", "expr") @@ -319,6 +360,12 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield self yield from self.expr.iter_right() + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self.with_expr(self.expr.map_ir(function))) + + def with_expr(self, expr: ExprIR, /) -> Self: + return self if expr == self.expr else type(self)(expr=expr, dtype=self.dtype) + class Sort(ExprIR): __slots__ = ("expr", "options") @@ -342,6 +389,12 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield self yield from self.expr.iter_right() + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self.with_expr(self.expr.map_ir(function))) + + def with_expr(self, expr: ExprIR, /) -> Self: + return self if expr == self.expr else type(self)(expr=expr, options=self.options) + class SortBy(ExprIR): """https://github.com/narwhals-dev/narwhals/issues/2534.""" @@ -371,7 +424,23 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from e.iter_right() yield from self.expr.iter_right() + def map_ir(self, function: MapIR, /) -> ExprIR: + by = (ir.map_ir(function) for ir in self.by) + return function(self.with_expr(self.expr.map_ir(function)).with_by(by)) + + def with_expr(self, expr: ExprIR, /) -> Self: + if expr == self.expr: + return self + return type(self)(expr=expr, by=self.by, options=self.options) + def with_by(self, by: t.Iterable[ExprIR], /) -> Self: + by = tuple(by) if not isinstance(by, tuple) else by + if by == self.by: + return self + return type(self)(expr=self.expr, by=by, options=self.options) + + +# TODO @dangotbanned: recursive `map_ir` scheme class FunctionExpr(ExprIR, t.Generic[FunctionT]): """**Representing `Expr::Function`**. @@ -451,6 +520,7 @@ class AnonymousExpr(FunctionExpr["MapBatches"]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" +# TODO @dangotbanned: add `DummyExpr.filter` class Filter(ExprIR): __slots__ = ("by", "expr") @@ -474,7 +544,15 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from self.by.iter_right() yield from self.expr.iter_right() + def map_ir(self, function: MapIR, /) -> ExprIR: + expr = self.expr.map_ir(function) + by = self.by.map_ir(function) + expr = self.expr if self.expr == expr else expr + by = self.by if self.by == by else by + return function(Filter(expr=expr, by=by)) + +# TODO @dangotbanned: recursive `map_ir` scheme class WindowExpr(ExprIR): """A fully specified `.over()`, that occurred after another expression. 
@@ -559,6 +637,9 @@ def name(self) -> str: def __repr__(self) -> str: return "len()" + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self) + class RootSelector(SelectorIR): """A single selector expression.""" @@ -574,7 +655,12 @@ def __repr__(self) -> str: def matches_column(self, name: str, dtype: DType) -> bool: return self.selector.matches_column(name, dtype) + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self) + +# NOTE: selectors don't make sense to have recusrive mapping *for now* `(Binary|Invert)Selector` +# If a function replaces the inner type with a non-selector, the other methods will break class BinarySelector( _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], SelectorIR, @@ -591,6 +677,9 @@ def matches_column(self, name: str, dtype: DType) -> bool: right = self.right.matches_column(name, dtype) return bool(self.op(left, right)) + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self) + class InvertSelector(SelectorIR, t.Generic[SelectorT]): __slots__ = ("selector",) @@ -604,6 +693,9 @@ def __repr__(self) -> str: def matches_column(self, name: str, dtype: DType) -> bool: return not self.selector.matches_column(name, dtype) + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self) + class Ternary(ExprIR): """When-Then-Otherwise.""" @@ -627,3 +719,24 @@ def __repr__(self) -> str: return ( f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" ) + + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.truthy.iter_left() + yield from self.falsy.iter_left() + yield from self.predicate.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + yield from self.predicate.iter_right() + yield from self.falsy.iter_right() + yield from self.truthy.iter_right() + + def map_ir(self, function: MapIR, /) -> ExprIR: + predicate = self.predicate.map_ir(function) + truthy = self.truthy.map_ir(function) + falsy = self.falsy.map_ir(function) + predicate = self.predicate if self.predicate == predicate else predicate + truthy = self.truthy if self.truthy == truthy else truthy + falsy = self.falsy if self.falsy == falsy else falsy + return function(Ternary(predicate=predicate, truthy=truthy, falsy=falsy)) diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index d382c49292..86a54d6fbe 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -5,8 +5,13 @@ from narwhals._plan.common import ExprIR, ExprNamespace, Immutable, IRNamespace if TYPE_CHECKING: + from typing import Iterator + + from typing_extensions import Self + from narwhals._compliant.typing import AliasName from narwhals._plan.dummy import DummyExpr + from narwhals._plan.typing import MapIR class KeepName(ExprIR): @@ -19,6 +24,20 @@ class KeepName(ExprIR): def __repr__(self) -> str: return f"{self.expr!r}.name.keep()" + def iter_left(self) -> Iterator[ExprIR]: + yield from self.expr.iter_left() + yield self + + def iter_right(self) -> Iterator[ExprIR]: + yield self + yield from self.expr.iter_right() + + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self.with_expr(self.expr.map_ir(function))) + + def with_expr(self, expr: ExprIR, /) -> Self: + return self if expr == self.expr else type(self)(expr=expr) + class RenameAlias(ExprIR): __slots__ = ("expr", "function") @@ -29,6 +48,22 @@ class RenameAlias(ExprIR): def __repr__(self) -> str: return f".rename_alias({self.expr!r})" + def iter_left(self) -> Iterator[ExprIR]: + yield from self.expr.iter_left() + yield 
self + + def iter_right(self) -> Iterator[ExprIR]: + yield self + yield from self.expr.iter_right() + + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self.with_expr(self.expr.map_ir(function))) + + def with_expr(self, expr: ExprIR, /) -> Self: + return ( + self if expr == self.expr else type(self)(expr=expr, function=self.function) + ) + class Prefix(Immutable): __slots__ = ("prefix",) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 1297974bea..dee99325ee 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -17,8 +17,14 @@ __all__ = [ "FunctionT", + "LeftSelectorT", "LeftT", + "LiteralT", + "MapIR", + "NonNestedLiteralT", + "OperatorFn", "OperatorT", + "RightSelectorT", "RightT", "RollingT", "SelectorOperatorT", @@ -29,8 +35,10 @@ FunctionT = TypeVar("FunctionT", bound="Function") RollingT = TypeVar("RollingT", bound="RollingWindow") LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") +LeftT2 = TypeVar("LeftT2", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") +RightT2 = TypeVar("RightT2", bound="ExprIR", default="ExprIR") OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" SelectorT = TypeVar("SelectorT", bound="SelectorIR", default="SelectorIR") @@ -45,6 +53,8 @@ "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" ) LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | DummySeries", default=t.Any) +MapIR: TypeAlias = "t.Callable[[ExprIR], ExprIR]" +"""A function to apply to all nodes in this tree.""" # NOTE: Shorter aliases of `_compliant.typing` # - Aiming to try and preserve the types as much as possible diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index cba1f35366..2cd0506a1a 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Sequence import pytest @@ -11,16 +11,14 @@ from narwhals.exceptions import ColumnNotFoundError, ComputeError if TYPE_CHECKING: - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import TypeIs from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyExpr + from narwhals._plan.typing import MapIR from narwhals.dtypes import DType -MapIR: TypeAlias = "Callable[[ExprIR], ExprIR]" - - @pytest.fixture def schema_1() -> dict[str, DType]: return { @@ -234,9 +232,14 @@ def fn(ir: ExprIR) -> ExprIR: return fn -@pytest.mark.xfail( - reason="Not implemented `ExprIR.map_ir` yet", raises=NotImplementedError +xfail_window_expr_map_ir = pytest.mark.xfail( + reason="Not implemented `WindowExpr.map_ir` yet", raises=NotImplementedError +) +xfail_function_expr_map_ir = pytest.mark.xfail( + reason="Not implemented `FunctionExpr.map_ir` yet", raises=NotImplementedError ) + + @pytest.mark.parametrize( ("expr", "function", "into_expected"), [ @@ -244,15 +247,29 @@ def fn(ir: ExprIR) -> ExprIR: (nwd.col("a"), alias_replace_unguarded("never"), nwd.col("a")), (nwd.col("a").alias("b"), alias_replace_guarded("c"), nwd.col("a").alias("c")), (nwd.col("a").alias("b"), alias_replace_unguarded("c"), nwd.col("a").alias("c")), - ( + pytest.param( nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), alias_replace_guarded("d"), nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), + 
marks=xfail_window_expr_map_ir, ), - ( + pytest.param( nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), alias_replace_unguarded("d"), nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), + marks=xfail_window_expr_map_ir, + ), + pytest.param( + nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), + alias_replace_guarded("e"), + nwd.col("a").alias("e").abs().alias("e").sort().alias("e"), + marks=xfail_function_expr_map_ir, + ), + pytest.param( + nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), + alias_replace_unguarded("e"), + nwd.col("a").alias("e").abs().alias("e").sort().alias("e"), + marks=xfail_function_expr_map_ir, ), ], ) @@ -262,4 +279,4 @@ def test_map_ir_recursive( ir = expr._ir expected = into_expected._ir actual = ir.map_ir(function) - assert actual == expected # pragma: no cover + assert actual == expected From d3ea987b3dbb8c34f376d41e2ee8b1ae4aeef76a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 9 Jun 2025 15:25:03 +0100 Subject: [PATCH 186/368] fix: typo https://results.pre-commit.ci/run/github/760058710/1749478993.R3U9GbAWTQmNenQtoD2ELw --- narwhals/_plan/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 70f78cdf6f..0961ef2e49 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -659,7 +659,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self) -# NOTE: selectors don't make sense to have recusrive mapping *for now* `(Binary|Invert)Selector` +# NOTE: selectors don't make sense to have recursive mapping *for now* `(Binary|Invert)Selector` # If a function replaces the inner type with a non-selector, the other methods will break class BinarySelector( _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], From 4604d9aecb1e6eddd195c379e82e5fa35d6b7643 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 9 Jun 2025 16:06:27 +0100 Subject: [PATCH 187/368] test: add `assert_expr_ir_equal` Need to do this a lot for the selectors tests --- tests/plan/expr_expansion_test.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 2cd0506a1a..457f5dd3b7 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -6,6 +6,7 @@ import narwhals as nw import narwhals._plan.demo as nwd +from narwhals._plan.common import is_expr from narwhals._plan.expr import Alias, Column, _ColumnSelection from narwhals._plan.expr_expansion import rewrite_special_aliases from narwhals.exceptions import ColumnNotFoundError, ComputeError @@ -45,6 +46,12 @@ def schema_1() -> dict[str, DType]: } +def assert_expr_ir_equal(left: DummyExpr | ExprIR, right: DummyExpr | ExprIR) -> None: + lhs = left._ir if is_expr(left) else left + rhs = right._ir if is_expr(right) else right + assert lhs == rhs + + # NOTE: The meta check doesn't provide typing and describes a superset of `_ColumnSelection` def is_column_selection(obj: ExprIR) -> TypeIs[_ColumnSelection]: return obj.meta.is_column_selection(allow_aliasing=False) and isinstance( @@ -276,7 +283,5 @@ def fn(ir: ExprIR) -> ExprIR: def test_map_ir_recursive( expr: DummyExpr, function: MapIR, into_expected: DummyExpr ) -> None: - ir = expr._ir - expected = into_expected._ir - actual = ir.map_ir(function) - assert actual == expected + actual = expr._ir.map_ir(function) + 
assert_expr_ir_equal(actual, into_expected) From 360caec89d70f952a317da06b98406be06ce63a8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 9 Jun 2025 16:40:21 +0100 Subject: [PATCH 188/368] test: Add `test_replace_selector` Ensures selectors are expanded into `col(...)`, whilst keeping all other parts intact @MarcoGorelli --- tests/plan/expr_expansion_test.py | 60 ++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 457f5dd3b7..1bb96c3758 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -5,17 +5,21 @@ import pytest import narwhals as nw -import narwhals._plan.demo as nwd +from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.common import is_expr -from narwhals._plan.expr import Alias, Column, _ColumnSelection -from narwhals._plan.expr_expansion import rewrite_special_aliases +from narwhals._plan.expr import Alias, Column, Columns, _ColumnSelection +from narwhals._plan.expr_expansion import ( + FrozenSchema, + replace_selector, + rewrite_special_aliases, +) from narwhals.exceptions import ColumnNotFoundError, ComputeError if TYPE_CHECKING: from typing_extensions import TypeIs from narwhals._plan.common import ExprIR - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import DummyExpr, DummySelector from narwhals._plan.typing import MapIR from narwhals.dtypes import DType @@ -248,7 +252,7 @@ def fn(ir: ExprIR) -> ExprIR: @pytest.mark.parametrize( - ("expr", "function", "into_expected"), + ("expr", "function", "expected"), [ (nwd.col("a"), alias_replace_guarded("never"), nwd.col("a")), (nwd.col("a"), alias_replace_unguarded("never"), nwd.col("a")), @@ -280,8 +284,46 @@ def fn(ir: ExprIR) -> ExprIR: ), ], ) -def test_map_ir_recursive( - expr: DummyExpr, function: MapIR, into_expected: DummyExpr -) -> None: +def test_map_ir_recursive(expr: DummyExpr, function: MapIR, expected: DummyExpr) -> None: actual = expr._ir.map_ir(function) - assert_expr_ir_equal(actual, into_expected) + assert_expr_ir_equal(actual, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwd.col("a"), nwd.col("a")), + (nwd.col("a").max().alias("z"), nwd.col("a").max().alias("z")), + (ndcs.string(), Columns(names=("k",))), + ( + ndcs.by_dtype(nw.Datetime("ms"), nw.Date, nw.List(nw.String)), + nwd.col("n", "s"), + ), + (ndcs.string() | ndcs.boolean(), nwd.col("k", "m")), + ( + ~(ndcs.numeric() | ndcs.string()), + nwd.col("l", "m", "n", "o", "p", "q", "r", "s", "u"), + ), + ( + ( + ndcs.all() + - (ndcs.categorical() | ndcs.by_name("a", "b") | ndcs.matches("[fqohim]")) + ^ ndcs.by_name("u", "a", "b", "d", "e", "f", "g") + ).name.suffix("_after"), + nwd.col("a", "b", "c", "f", "j", "k", "l", "n", "r", "s").name.suffix( + "_after" + ), + ), + ( + (ndcs.matches("[a-m]") & ~ndcs.numeric()).sort(nulls_last=True).first() + != nwd.lit(None), + nwd.col("k", "l", "m").sort(nulls_last=True).first() != nwd.lit(None), + ), + ], +) +def test_replace_selector( + expr: DummySelector | DummyExpr, expected: DummyExpr | ExprIR, schema_1: FrozenSchema +) -> None: + group_by_keys = () + actual = replace_selector(expr._ir, group_by_keys, schema=schema_1) + assert_expr_ir_equal(actual, expected) From 25501036e8ad531225643841bb5d75c355f0c500 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 9 Jun 2025 18:33:51 +0100 
Subject: [PATCH 189/368] feat: Impl `WindowExpr.map_ir` Included a test using selectors in `order_by` --- narwhals/_plan/expr.py | 64 ++++++++++++++++++++++++++++++- narwhals/_plan/options.py | 4 ++ tests/plan/expr_expansion_test.py | 21 ++++++---- 3 files changed, 80 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 0961ef2e49..36ce025cb9 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -20,6 +20,7 @@ function_expr_invalid_operation_error, ) from narwhals._plan.name import KeepName, RenameAlias +from narwhals._plan.options import SortOptions from narwhals._plan.typing import ( ExprT, FunctionT, @@ -45,7 +46,7 @@ from narwhals._plan.common import Seq from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue - from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions + from narwhals._plan.options import FunctionOptions, SortMultipleOptions from narwhals._plan.selectors import Selector from narwhals._plan.window import Window from narwhals.dtypes import DType @@ -552,7 +553,8 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(Filter(expr=expr, by=by)) -# TODO @dangotbanned: recursive `map_ir` scheme +# NOTE: Probably need to split out `order_by` +# Really frustrating to handle the `None` case everywhere class WindowExpr(ExprIR): """A fully specified `.over()`, that occurred after another expression. @@ -591,6 +593,13 @@ class WindowExpr(ExprIR): Expr::Window { options: WindowType::Rolling(RollingGroupOptions) } """ + @property + def sort_options(self) -> SortOptions: + if self.order_by: + _, opt = self.order_by + return opt + return SortOptions.default() + def __repr__(self) -> str: if self.order_by is None: return f"{self.expr!r}.over({list(self.partition_by)!r})" @@ -624,6 +633,57 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from e.iter_right() yield from self.expr.iter_right() + def map_ir(self, function: MapIR, /) -> ExprIR: + over = self.with_expr(self.expr.map_ir(function)).with_partition_by( + ir.map_ir(function) for ir in self.partition_by + ) + if self.order_by: + by, _ = self.order_by + over = over.with_order_by(ir.map_ir(function) for ir in by) + return function(over) + + def with_expr(self, expr: ExprIR, /) -> Self: + if expr == self.expr: + return self + return type(self)( + expr=expr, + partition_by=self.partition_by, + order_by=self.order_by, + options=self.options, + ) + + def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: + by = tuple(partition_by) if not isinstance(partition_by, tuple) else partition_by + if by == self.partition_by: + return self + return type(self)( + expr=self.expr, partition_by=by, order_by=self.order_by, options=self.options + ) + + def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: + # NOTE: Not thrilled about this but there's complexity to solve + next_order_by: tuple[Seq[ExprIR], SortOptions] | None + if by := (tuple(order_by) if not isinstance(order_by, tuple) else order_by): + if prev := self.order_by: + prev_by, prev_sort = prev + # NOTE: Very hidden check for no-op possibility + if by == prev_by: + return self + next_order_by = by, prev_sort + else: + next_order_by = by, self.sort_options + elif prev := self.order_by: + # NOTE: Unsure if we'd ever want to do this, but need to be exhaustive + next_order_by = None + else: + return self + return type(self)( + expr=self.expr, + partition_by=self.partition_by, + order_by=next_order_by, + 
options=self.options, + ) + class Len(ExprIR): @property diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 6283c764d2..9008ec1979 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -137,6 +137,10 @@ def __repr__(self) -> str: args = f"descending={self.descending!r}, nulls_last={self.nulls_last!r}" return f"{type(self).__name__}({args})" + @staticmethod + def default() -> SortOptions: + return SortOptions(descending=False, nulls_last=False) + class SortMultipleOptions(Immutable): __slots__ = ("descending", "nulls_last") diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 1bb96c3758..370689e2b9 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -243,9 +243,6 @@ def fn(ir: ExprIR) -> ExprIR: return fn -xfail_window_expr_map_ir = pytest.mark.xfail( - reason="Not implemented `WindowExpr.map_ir` yet", raises=NotImplementedError -) xfail_function_expr_map_ir = pytest.mark.xfail( reason="Not implemented `FunctionExpr.map_ir` yet", raises=NotImplementedError ) @@ -258,17 +255,15 @@ def fn(ir: ExprIR) -> ExprIR: (nwd.col("a"), alias_replace_unguarded("never"), nwd.col("a")), (nwd.col("a").alias("b"), alias_replace_guarded("c"), nwd.col("a").alias("c")), (nwd.col("a").alias("b"), alias_replace_unguarded("c"), nwd.col("a").alias("c")), - pytest.param( + ( nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), alias_replace_guarded("d"), nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), - marks=xfail_window_expr_map_ir, ), - pytest.param( + ( nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), alias_replace_unguarded("d"), nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), - marks=xfail_window_expr_map_ir, ), pytest.param( nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), @@ -319,6 +314,18 @@ def test_map_ir_recursive(expr: DummyExpr, function: MapIR, expected: DummyExpr) != nwd.lit(None), nwd.col("k", "l", "m").sort(nulls_last=True).first() != nwd.lit(None), ), + ( + ( + ndcs.numeric() + .mean() + .over("k", order_by=ndcs.by_dtype(nw.Date()) | ndcs.boolean()) + ), + ( + nwd.col("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + .mean() + .over(nwd.col("k"), order_by=nwd.col("m", "n")) + ), + ), ], ) def test_replace_selector( From 7dd7092a318bae80090a92d54de0e5387374e178 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 9 Jun 2025 19:27:38 +0100 Subject: [PATCH 190/368] feat: Impl `FunctionExpr.map_ir` https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2134750336 --- narwhals/_plan/expr.py | 6 +++++- tests/plan/expr_expansion_test.py | 29 ++++++++++++++++++++--------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 36ce025cb9..684dbf9d06 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -441,7 +441,6 @@ def with_by(self, by: t.Iterable[ExprIR], /) -> Self: return type(self)(expr=self.expr, by=by, options=self.options) -# TODO @dangotbanned: recursive `map_ir` scheme class FunctionExpr(ExprIR, t.Generic[FunctionT]): """**Representing `Expr::Function`**. 
@@ -479,8 +478,13 @@ def with_options(self, options: FunctionOptions, /) -> Self: def with_input(self, input: t.Iterable[ExprIR], /) -> Self: # noqa: A002 if not isinstance(input, tuple): input = tuple(input) + if input == self.input: + return self return type(self)(input=input, function=self.function, options=self.options) + def map_ir(self, function: MapIR, /) -> ExprIR: + return function(self.with_input(ir.map_ir(function) for ir in self.input)) + def __repr__(self) -> str: if self.input: first = self.input[0] diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 370689e2b9..f219c91a92 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -243,11 +243,6 @@ def fn(ir: ExprIR) -> ExprIR: return fn -xfail_function_expr_map_ir = pytest.mark.xfail( - reason="Not implemented `FunctionExpr.map_ir` yet", raises=NotImplementedError -) - - @pytest.mark.parametrize( ("expr", "function", "expected"), [ @@ -265,17 +260,15 @@ def fn(ir: ExprIR) -> ExprIR: alias_replace_unguarded("d"), nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), ), - pytest.param( + ( nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), alias_replace_guarded("e"), nwd.col("a").alias("e").abs().alias("e").sort().alias("e"), - marks=xfail_function_expr_map_ir, ), - pytest.param( + ( nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), alias_replace_unguarded("e"), nwd.col("a").alias("e").abs().alias("e").sort().alias("e"), - marks=xfail_function_expr_map_ir, ), ], ) @@ -326,6 +319,24 @@ def test_map_ir_recursive(expr: DummyExpr, function: MapIR, expected: DummyExpr) .over(nwd.col("k"), order_by=nwd.col("m", "n")) ), ), + ( + ( + ndcs.datetime() + .dt.timestamp() + .min() + .over(ndcs.string() | ndcs.boolean()) + .last() + .name.to_uppercase() + ), + ( + nwd.col("l", "o") + .dt.timestamp("us") + .min() + .over(nwd.col("k", "m")) + .last() + .name.to_uppercase() + ), + ), ], ) def test_replace_selector( From a3b96ee2415f7a23b9f2d1cb83622a6f8657d92a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 9 Jun 2025 19:39:10 +0100 Subject: [PATCH 191/368] chore: Tidy up notes --- narwhals/_plan/common.py | 1 - narwhals/_plan/expr.py | 24 +++++++----------------- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 812de34122..6560fe99af 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -259,7 +259,6 @@ def matches_column(self, name: str, dtype: DType) -> bool: - Thinking that we could get more cache hits on an individual column basis. - May also be more efficient to not iterate over the schema for every selector - Instead do one pass, evaluating every selector against a single column at a time - - Is that possible? """ raise NotImplementedError(type(self)) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 684dbf9d06..de93b67ddb 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -445,7 +445,6 @@ class FunctionExpr(ExprIR, t.Generic[FunctionT]): """**Representing `Expr::Function`**. 
https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 """ @@ -453,19 +452,17 @@ class FunctionExpr(ExprIR, t.Generic[FunctionT]): input: Seq[ExprIR] function: FunctionT - """Enum type is named `FunctionExpr` in `polars`. + """Operation applied to each element of `input`. - Mirroring *exactly* doesn't make much sense in OOP. + Notes: + [Upstream enum type] is named `FunctionExpr` in `rust`. + Mirroring *exactly* doesn't make much sense in OOP. - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 + [Upstream enum type]: https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 """ options: FunctionOptions - """Assuming this is **either**: - - 1. `function.function_options` - 2. The union of (1) and any `FunctionOptions` in `inputs` - """ + """Combined flags from chained operations.""" @property def is_scalar(self) -> bool: @@ -586,12 +583,7 @@ class WindowExpr(ExprIR): """ options: Window - """Little confused on the nesting. - - - We don't allow choosing `WindowMapping` kinds - - Haven't ventured into rolling much yet - - Turns out this is for `Expr.rolling` (not `Expr.rolling_`) - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L879-L888 + """Currently **always** represents over. Expr::Window { options: WindowType::Over(WindowMapping) } Expr::Window { options: WindowType::Rolling(RollingGroupOptions) } @@ -723,8 +715,6 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self) -# NOTE: selectors don't make sense to have recursive mapping *for now* `(Binary|Invert)Selector` -# If a function replaces the inner type with a non-selector, the other methods will break class BinarySelector( _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], SelectorIR, From b569b57c5cc037b5e29d648087102fb2dabd31f6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 13:15:20 +0100 Subject: [PATCH 192/368] test: Add `test_prepare_projection` Found a few bugs so far, but a surprising amount is working correctly --- tests/plan/expr_expansion_test.py | 178 +++++++++++++++++++++++++++++- 1 file changed, 177 insertions(+), 1 deletion(-) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index f219c91a92..ce4c7a994f 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -6,13 +6,15 @@ import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs -from narwhals._plan.common import is_expr +from narwhals._plan.common import IntoExpr, is_expr from narwhals._plan.expr import Alias, Column, Columns, _ColumnSelection from narwhals._plan.expr_expansion import ( FrozenSchema, + prepare_projection, replace_selector, rewrite_special_aliases, ) +from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals.exceptions import ColumnNotFoundError, ComputeError if TYPE_CHECKING: @@ -345,3 +347,177 @@ def test_replace_selector( group_by_keys = () actual = replace_selector(expr._ir, group_by_keys, schema=schema_1) assert_expr_ir_equal(actual, expected) + + +BIG_EXCLUDE = ("k", "l", "m", "n", "o", "p", "s", "u", "r", "a", "b", 
"e", "q") + + +@pytest.mark.parametrize( + ("into_exprs", "expected"), + [ + ("a", [nwd.col("a")]), + (nwd.col("b", "c", "d"), [nwd.col("b"), nwd.col("c"), nwd.col("d")]), + (nwd.nth(6), [nwd.col("g")]), + (nwd.nth(9, 8, -5), [nwd.col("j"), nwd.col("i"), nwd.col("p")]), + ( + [nwd.nth(2).alias("c again"), nwd.nth(-1, -2).name.to_uppercase()], + [ + nwd.col("c").alias("c again"), + nwd.col("u").alias("U"), + nwd.col("s").alias("S"), + ], + ), + ( + nwd.all(), + [ + nwd.col("a"), + nwd.col("b"), + nwd.col("c"), + nwd.col("d"), + nwd.col("e"), + nwd.col("f"), + nwd.col("g"), + nwd.col("h"), + nwd.col("i"), + nwd.col("j"), + nwd.col("k"), + nwd.col("l"), + nwd.col("m"), + nwd.col("n"), + nwd.col("o"), + nwd.col("p"), + nwd.col("q"), + nwd.col("r"), + nwd.col("s"), + nwd.col("u"), + ], + ), + ( + (ndcs.numeric() - ndcs.by_dtype(nw.Float32(), nw.Float64())) + .cast(nw.Int64()) + .mean() + .name.suffix("_mean"), + [ + nwd.col("a").cast(nw.Int64()).mean().alias("a_mean"), + nwd.col("b").cast(nw.Int64()).mean().alias("b_mean"), + nwd.col("c").cast(nw.Int64()).mean().alias("c_mean"), + nwd.col("d").cast(nw.Int64()).mean().alias("d_mean"), + nwd.col("e").cast(nw.Int64()).mean().alias("e_mean"), + nwd.col("f").cast(nw.Int64()).mean().alias("f_mean"), + nwd.col("g").cast(nw.Int64()).mean().alias("g_mean"), + nwd.col("h").cast(nw.Int64()).mean().alias("h_mean"), + ], + ), + ( + nwd.col("u").alias("1").alias("2").alias("3").alias("4").name.keep(), + # NOTE: Would be nice to rewrite with less intermediate steps + # but retrieving the root name is enough for now + [nwd.col("u").alias("1").alias("2").alias("3").alias("4").alias("u")], + ), + ( + ( + (ndcs.numeric() ^ (ndcs.matches(r"[abcdg]") | ndcs.by_name("i", "f"))) + * 100 + ).name.suffix("_mult_100"), + [ + (nwd.col("e") * nwd.lit(100)).alias("e_mult_100"), + (nwd.col("h") * nwd.lit(100)).alias("h_mult_100"), + (nwd.col("j") * nwd.lit(100)).alias("j_mult_100"), + ], + ), + ( + ndcs.by_dtype(nw.Duration()) + .dt.total_minutes() + .name.map(lambda nm: f"total_mins: {nm!r} ?"), + [nwd.col("q").dt.total_minutes().alias("total_mins: 'q' ?")], + ), + ( + nwd.col("f", "g") + .cast(nw.String()) + .str.starts_with("1") + .all() + .name.suffix("_all_starts_with_1"), + [ + nwd.col("f") + .cast(nw.String()) + .str.starts_with("1") + .all() + .alias("f_all_starts_with_1"), + nwd.col("g") + .cast(nw.String()) + .str.starts_with("1") + .all() + .alias("g_all_starts_with_1"), + ], + ), + ( + nwd.col("a", "b") + .first() + .over("c", "e", order_by="d") + .name.suffix("_first_over_part_order_1"), + [ + nwd.col("a") + .first() + .over(nwd.col("c"), nwd.col("e"), order_by=[nwd.col("d")]) + .alias("a_first_over_part_order_1"), + nwd.col("b") + .first() + .over(nwd.col("c"), nwd.col("e"), order_by=[nwd.col("d")]) + .alias("b_first_over_part_order_1"), + ], + ), + pytest.param( + nwd.exclude(BIG_EXCLUDE), + [ + nwd.col("c"), + nwd.col("d"), + nwd.col("f"), + nwd.col("g"), + nwd.col("h"), + nwd.col("i"), + nwd.col("j"), + ], + marks=pytest.mark.xfail(reason="Exclude seems to be skipping expansion"), + ), + pytest.param( + nwd.exclude(BIG_EXCLUDE).name.suffix("_2"), + [ + nwd.col("c").alias("c_2"), + nwd.col("d").alias("d_2"), + nwd.col("f").alias("f_2"), + nwd.col("g").alias("g_2"), + nwd.col("h").alias("h_2"), + nwd.col("i").alias("i_2"), + nwd.col("j").alias("j_2"), + ], + marks=pytest.mark.xfail( + reason="Probably the same issue as bare `exclude(...)`, but is showing the effect after chaining:\n" + "'unable to find a single leaf column in expr' ", + raises=ComputeError, + 
), + ), + pytest.param( + nwd.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), + [ + nwd.col("c") + .alias("c_min_over_order_by") + .min() + .over(order_by=[nwd.col("k")]) + ], + marks=pytest.mark.xfail( + reason="BUG: `order_by` wasn't visited when collecting flags and failed to expand.\n" + "This slipped through as it *does* get visited if `partition_by` or `expr` contain selectors **as well**." + ), + ), + ], +) +def test_prepare_projection( + into_exprs: IntoExpr | Sequence[IntoExpr], + expected: Sequence[DummyExpr], + schema_1: FrozenSchema, +) -> None: + irs_in = parse_into_seq_of_expr_ir(into_exprs) + actual, _ = prepare_projection(irs_in, schema_1) + assert len(actual) == len(expected) + for lhs, rhs in zip(actual, expected): + assert_expr_ir_equal(lhs, rhs) From 534c902308fbfb33e25e6613319c3a5d1f031d8a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 13:25:13 +0100 Subject: [PATCH 193/368] test: Add repro for horizontal alias bug --- tests/plan/expr_parsing_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index b7e71ad86a..2df7ef557d 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -321,3 +321,19 @@ def test_invalid_alias(expr: DummyExpr) -> None: pattern = re.compile(r"alias.+dupe.+multi\-output") with pytest.raises(DuplicateError, match=pattern): expr.alias("dupe") + + +@pytest.mark.xfail( + reason="BUG: Giving a false positive for horizontal reductions:\n" + "'Cannot apply alias 'abc' to multi-output expression'", + raises=DuplicateError, +) +def test_alias_horizontal() -> None: # pragma: no cover + assert nwd.sum_horizontal("a", "b", "c").alias("abc") + assert nwd.sum_horizontal(["a", "b", "c"]).alias("abc") + assert nwd.sum_horizontal("a", "b", nwd.col("c")).alias("abc") + assert nwd.sum_horizontal(nwd.col("a"), "b", "c").alias("abc") + # NOTE: Fails starting here, but all the others should be equivalent + assert nwd.sum_horizontal(nwd.col("a", "b"), "c").alias("abc") + assert nwd.sum_horizontal("a", nwd.col("b", "c")).alias("abc") + assert nwd.sum_horizontal(nwd.col("a", "b", "c")).alias("abc") From 23045dc35b5ebf05c25a738f683046e1303e7d0d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 16:32:14 +0100 Subject: [PATCH 194/368] fix: Add missing `Exclude` iterators Fixes the two updated tests --- narwhals/_plan/expr.py | 8 ++++++++ tests/plan/expr_expansion_test.py | 10 ++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index de93b67ddb..3501f036d2 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -250,6 +250,14 @@ def expand_columns(self, schema: _Schema) -> Seq[Column]: raise NotImplementedError(msg) return _cols(name for name in schema if name not in self.names) + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + yield from self.expr.iter_right() + def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index ce4c7a994f..2bd1162778 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -466,7 +466,7 @@ def test_replace_selector( 
.alias("b_first_over_part_order_1"), ], ), - pytest.param( + ( nwd.exclude(BIG_EXCLUDE), [ nwd.col("c"), @@ -477,9 +477,8 @@ def test_replace_selector( nwd.col("i"), nwd.col("j"), ], - marks=pytest.mark.xfail(reason="Exclude seems to be skipping expansion"), ), - pytest.param( + ( nwd.exclude(BIG_EXCLUDE).name.suffix("_2"), [ nwd.col("c").alias("c_2"), @@ -490,11 +489,6 @@ def test_replace_selector( nwd.col("i").alias("i_2"), nwd.col("j").alias("j_2"), ], - marks=pytest.mark.xfail( - reason="Probably the same issue as bare `exclude(...)`, but is showing the effect after chaining:\n" - "'unable to find a single leaf column in expr' ", - raises=ComputeError, - ), ), pytest.param( nwd.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), From 4e65ebcb545a2f6e77b598d70bc019aa089cd047 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 17:29:51 +0100 Subject: [PATCH 195/368] refactor(DRAFT): Start splitting out `WindowExpr` Not fixing the issue yet and still have loots of repetition --- narwhals/_plan/demo.py | 9 +-- narwhals/_plan/dummy.py | 10 +-- narwhals/_plan/exceptions.py | 9 ++- narwhals/_plan/expr.py | 131 ++++++++++++++++++++++------------- narwhals/_plan/window.py | 56 +++++++++++---- 5 files changed, 136 insertions(+), 79 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 290a358b3c..7d6d066d71 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -29,7 +29,7 @@ from typing_extensions import TypeIs from narwhals._plan.dummy import DummyExpr - from narwhals._plan.expr import SortBy, WindowExpr + from narwhals._plan.expr import SortBy from narwhals.typing import NonNestedLiteral @@ -180,13 +180,6 @@ def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]: return isinstance(obj, allowed) -def _is_order_enforcing_next(obj: t.Any) -> TypeIs[WindowExpr]: - """Not sure how this one would work.""" - from narwhals._plan.expr import WindowExpr - - return isinstance(obj, WindowExpr) and obj.order_by is not None - - def _order_dependent_error(node: agg.OrderableAgg) -> OrderDependentExprError: previous = node.expr method = repr(node).removeprefix(f"{previous!r}.") diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 96fb7f25c4..4686d8a919 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -23,7 +23,7 @@ SortOptions, ) from narwhals._plan.selectors import by_name -from narwhals._plan.window import Over +from narwhals._plan.window import OrderedOver, Over from narwhals._utils import Version, _hasattr_static from narwhals.dtypes import DType from narwhals.exceptions import ComputeError @@ -140,8 +140,8 @@ def over( descending: bool = False, nulls_last: bool = False, ) -> Self: + node: expr.WindowExpr | expr.OrderedWindowExpr partition: Seq[ExprIR] = () - order: tuple[Seq[ExprIR], SortOptions] | None = None if not (partition_by) and order_by is None: msg = "At least one of `partition_by` or `order_by` must be specified." 
raise TypeError(msg) @@ -150,8 +150,10 @@ def over( if order_by is not None: by = parse.parse_into_seq_of_expr_ir(order_by) options = SortOptions(descending=descending, nulls_last=nulls_last) - order = by, options - return self._from_ir(Over().to_window_expr(self._ir, partition, order)) + node = OrderedOver().to_ordered_window_expr(self._ir, partition, by, options) + else: + node = Over().to_window_expr(self._ir, partition) + return self._from_ir(node) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: options = SortOptions(descending=descending, nulls_last=nulls_last) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 5c165e89d9..9253e23fed 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -101,7 +101,8 @@ def binary_expr_length_changing_error( def over_nested_error( expr: WindowExpr, # noqa: ARG001 partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: tuple[Seq[ExprIR], SortOptions] | None, # noqa: ARG001 + order_by: Seq[ExprIR] = (), # noqa: ARG001 + sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = "Cannot nest `over` statements." return InvalidOperationError(msg) @@ -111,7 +112,8 @@ def over_nested_error( def over_elementwise_error( expr: FunctionExpr[Function], partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: tuple[Seq[ExprIR], SortOptions] | None, # noqa: ARG001 + order_by: Seq[ExprIR] = (), # noqa: ARG001 + sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}" return InvalidOperationError(msg) @@ -121,7 +123,8 @@ def over_elementwise_error( def over_row_separable_error( expr: FunctionExpr[Function], partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: tuple[Seq[ExprIR], SortOptions] | None, # noqa: ARG001 + order_by: Seq[ExprIR] = (), # noqa: ARG001 + sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = f"Cannot use `over` on expressions which change length.\n{expr!r}" return InvalidOperationError(msg) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 3501f036d2..300a574a88 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -20,7 +20,6 @@ function_expr_invalid_operation_error, ) from narwhals._plan.name import KeepName, RenameAlias -from narwhals._plan.options import SortOptions from narwhals._plan.typing import ( ExprT, FunctionT, @@ -46,7 +45,7 @@ from narwhals._plan.common import Seq from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue - from narwhals._plan.options import FunctionOptions, SortMultipleOptions + from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions from narwhals._plan.selectors import Selector from narwhals._plan.window import Window from narwhals.dtypes import DType @@ -562,7 +561,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(Filter(expr=expr, by=by)) -# NOTE: Probably need to split out `order_by` +# TODO @dangotbanned: 100% split out `order_by` to a subclass # Really frustrating to handle the `None` case everywhere class WindowExpr(ExprIR): """A fully specified `.over()`, that occurred after another expression. 
@@ -575,7 +574,7 @@ class WindowExpr(ExprIR): - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L840-L876 """ - __slots__ = ("expr", "options", "order_by", "partition_by") + __slots__ = ("expr", "options", "partition_by") expr: ExprIR """Renamed from `function`. @@ -584,11 +583,6 @@ class WindowExpr(ExprIR): """ partition_by: Seq[ExprIR] - order_by: tuple[Seq[ExprIR], SortOptions] | None - """Deviates from the `polars` version. - - - `order_by` starts the same as here, but `polars` reduces into a struct - becoming a single (nested) node. - """ options: Window """Currently **always** represents over. @@ -597,17 +591,62 @@ class WindowExpr(ExprIR): Expr::Window { options: WindowType::Rolling(RollingGroupOptions) } """ - @property - def sort_options(self) -> SortOptions: - if self.order_by: - _, opt = self.order_by - return opt - return SortOptions.default() + def __repr__(self) -> str: + return f"{self.expr!r}.over({list(self.partition_by)!r})" + + def __str__(self) -> str: + args = ( + f"expr={self.expr}, partition_by={self.partition_by}, options={self.options}" + ) + return f"{type(self).__name__}({args})" + + def iter_left(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_left() + for e in self.partition_by: + yield from e.iter_left() + yield self + + def iter_right(self) -> t.Iterator[ExprIR]: + yield self + for e in reversed(self.partition_by): + yield from e.iter_right() + yield from self.expr.iter_right() + + def map_ir(self, function: MapIR, /) -> ExprIR: + over = self.with_expr(self.expr.map_ir(function)).with_partition_by( + ir.map_ir(function) for ir in self.partition_by + ) + return function(over) + + def with_expr(self, expr: ExprIR, /) -> Self: + if expr == self.expr: + return self + return type(self)(expr=expr, partition_by=self.partition_by, options=self.options) + + def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: + by = tuple(partition_by) if not isinstance(partition_by, tuple) else partition_by + if by == self.partition_by: + return self + return type(self)(expr=self.expr, partition_by=by, options=self.options) + + +class OrderedWindowExpr(WindowExpr): + # `order_by` is required, only stores the `Seq[ExprIR]` + # `sort_options` is an attribute, not a property + __slots__ = ("expr", "options", "order_by", "partition_by", "sort_options") + + expr: ExprIR + partition_by: Seq[ExprIR] + order_by: Seq[ExprIR] + """Deviates from the `polars` version. + + - `order_by` starts the same as here, but `polars` reduces into a struct - becoming a single (nested) node. 
+ """ + sort_options: SortOptions + options: Window def __repr__(self) -> str: - if self.order_by is None: - return f"{self.expr!r}.over({list(self.partition_by)!r})" - order, _ = self.order_by + order = self.order_by if not self.partition_by: args = f"order_by={list(order)!r}" else: @@ -615,11 +654,7 @@ def __repr__(self) -> str: return f"{self.expr!r}.over({args})" def __str__(self) -> str: - if self.order_by is None: - order_by = "None" - else: - order, opts = self.order_by - order_by = f"({order}, {opts})" + order_by = f"({self.order_by}, {self.sort_options})" args = f"expr={self.expr}, partition_by={self.partition_by}, order_by={order_by}, options={self.options}" return f"{type(self).__name__}({args})" @@ -641,11 +676,28 @@ def map_ir(self, function: MapIR, /) -> ExprIR: over = self.with_expr(self.expr.map_ir(function)).with_partition_by( ir.map_ir(function) for ir in self.partition_by ) - if self.order_by: - by, _ = self.order_by - over = over.with_order_by(ir.map_ir(function) for ir in by) + over = over.with_order_by(ir.map_ir(function) for ir in self.order_by) return function(over) + def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: + # NOTE: Not thrilled about this but there's complexity to solve + if by := (tuple(order_by) if not isinstance(order_by, tuple) else order_by): + if by == self.order_by: + return self + next_order_by = by + elif not self.order_by: + return self + else: + # NOTE: Unsure if we'd ever want to do this, but need to be exhaustive + next_order_by = () + return type(self)( + expr=self.expr, + partition_by=self.partition_by, + order_by=next_order_by, + sort_options=self.sort_options, + options=self.options, + ) + def with_expr(self, expr: ExprIR, /) -> Self: if expr == self.expr: return self @@ -653,6 +705,7 @@ def with_expr(self, expr: ExprIR, /) -> Self: expr=expr, partition_by=self.partition_by, order_by=self.order_by, + sort_options=self.sort_options, options=self.options, ) @@ -660,31 +713,11 @@ def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: by = tuple(partition_by) if not isinstance(partition_by, tuple) else partition_by if by == self.partition_by: return self - return type(self)( - expr=self.expr, partition_by=by, order_by=self.order_by, options=self.options - ) - - def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: - # NOTE: Not thrilled about this but there's complexity to solve - next_order_by: tuple[Seq[ExprIR], SortOptions] | None - if by := (tuple(order_by) if not isinstance(order_by, tuple) else order_by): - if prev := self.order_by: - prev_by, prev_sort = prev - # NOTE: Very hidden check for no-op possibility - if by == prev_by: - return self - next_order_by = by, prev_sort - else: - next_order_by = by, self.sort_options - elif prev := self.order_by: - # NOTE: Unsure if we'd ever want to do this, but need to be exhaustive - next_order_by = None - else: - return self return type(self)( expr=self.expr, - partition_by=self.partition_by, - order_by=next_order_by, + partition_by=by, + order_by=self.order_by, + sort_options=self.sort_options, options=self.options, ) diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index b742c0e3e3..64f8984ab4 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -11,8 +11,9 @@ if TYPE_CHECKING: from narwhals._plan.common import ExprIR, Seq - from narwhals._plan.expr import WindowExpr + from narwhals._plan.expr import OrderedWindowExpr, WindowExpr from narwhals._plan.options import SortOptions + from narwhals.exceptions import 
InvalidOperationError class Window(Immutable): @@ -22,28 +23,53 @@ class Window(Immutable): """ -# TODO @dangotbanned: What are all the variants we have code paths for? -# - Over has *at least* (partition_by,), (order_by,), (partition_by, order_by), + options -# - `_plan.expr.WindowExpr` has: -# - expr (last node) -# - partition_by, optional order_by, `options` which is one of these classes? class Over(Window): - def to_window_expr( - self, + @staticmethod + def _validate_over( expr: ExprIR, partition_by: Seq[ExprIR], - order_by: tuple[Seq[ExprIR], SortOptions] | None, + order_by: Seq[ExprIR] = (), + sort_options: SortOptions | None = None, /, - ) -> WindowExpr: + ) -> InvalidOperationError | None: from narwhals._plan.expr import FunctionExpr, WindowExpr if isinstance(expr, WindowExpr): - raise over_nested_error(expr, partition_by, order_by) + return over_nested_error(expr, partition_by, order_by, sort_options) if isinstance(expr, FunctionExpr): if expr.options.is_elementwise(): - raise over_elementwise_error(expr, partition_by, order_by) + return over_elementwise_error(expr, partition_by, order_by, sort_options) if expr.options.is_row_separable(): - raise over_row_separable_error(expr, partition_by, order_by) - return WindowExpr( - expr=expr, partition_by=partition_by, order_by=order_by, options=self + return over_row_separable_error( + expr, partition_by, order_by, sort_options + ) + return None + + def to_window_expr(self, expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowExpr: + from narwhals._plan.expr import WindowExpr + + if err := self._validate_over(expr, partition_by): + raise err + return WindowExpr(expr=expr, partition_by=partition_by, options=self) + + +class OrderedOver(Over): + def to_ordered_window_expr( + self, + expr: ExprIR, + partition_by: Seq[ExprIR], + order_by: Seq[ExprIR], + sort_options: SortOptions, + /, + ) -> OrderedWindowExpr: + from narwhals._plan.expr import OrderedWindowExpr + + if err := self._validate_over(expr, partition_by, order_by, sort_options): + raise err + return OrderedWindowExpr( + expr=expr, + partition_by=partition_by, + order_by=order_by, + sort_options=sort_options, + options=self, ) From 99cb01abdc75715bca9b8bd0683e0c6dd23a7e2f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 18:05:03 +0100 Subject: [PATCH 196/368] refactor: Use a single `Over` with two builder methods Need the two methods to avoid an incompatible override, so no point in subclassing --- narwhals/_plan/dummy.py | 4 ++-- narwhals/_plan/window.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 4686d8a919..287c12027f 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -23,7 +23,7 @@ SortOptions, ) from narwhals._plan.selectors import by_name -from narwhals._plan.window import OrderedOver, Over +from narwhals._plan.window import Over from narwhals._utils import Version, _hasattr_static from narwhals.dtypes import DType from narwhals.exceptions import ComputeError @@ -150,7 +150,7 @@ def over( if order_by is not None: by = parse.parse_into_seq_of_expr_ir(order_by) options = SortOptions(descending=descending, nulls_last=nulls_last) - node = OrderedOver().to_ordered_window_expr(self._ir, partition, by, options) + node = Over().to_ordered_window_expr(self._ir, partition, by, options) else: node = Over().to_window_expr(self._ir, partition) return self._from_ir(node) diff --git a/narwhals/_plan/window.py 
b/narwhals/_plan/window.py index 64f8984ab4..c2484a544a 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -52,8 +52,6 @@ def to_window_expr(self, expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowEx raise err return WindowExpr(expr=expr, partition_by=partition_by, options=self) - -class OrderedOver(Over): def to_ordered_window_expr( self, expr: ExprIR, From 422bbc744e218ae46a1a03d3d95df433921fe438 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 18:52:49 +0100 Subject: [PATCH 197/368] fix: Expand exprs/selectors in `over(order_by=...)` --- narwhals/_plan/common.py | 8 ++++++++ narwhals/_plan/expr.py | 13 +++++++++++-- narwhals/_plan/meta.py | 2 +- tests/plan/expr_expansion_test.py | 6 +----- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 6560fe99af..5b5d4287df 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -238,6 +238,14 @@ def iter_right(self) -> Iterator[ExprIR]: """ yield self + def iter_root_names(self) -> Iterator[ExprIR]: + """Override for different iteration behavior in `ExprIR.meta.root_names`. + + Note: + Identical to `iter_left` by default. + """ + yield from self.iter_left() + @property def meta(self) -> IRMetaNamespace: from narwhals._plan.meta import IRMetaNamespace diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 300a574a88..e20d9d0316 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -659,19 +659,28 @@ def __str__(self) -> str: return f"{type(self).__name__}({args})" def iter_left(self) -> t.Iterator[ExprIR]: - # NOTE: `order_by` is never considered in `polars` - # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 yield from self.expr.iter_left() for e in self.partition_by: yield from e.iter_left() + for e in self.order_by: + yield from e.iter_left() yield self def iter_right(self) -> t.Iterator[ExprIR]: yield self + for e in reversed(self.order_by): + yield from e.iter_right() for e in reversed(self.partition_by): yield from e.iter_right() yield from self.expr.iter_right() + def iter_root_names(self) -> t.Iterator[ExprIR]: + # NOTE: `order_by` is never considered in `polars` + # To match that behavior for `root_names` - but still expand in all other cases + # - this little escape hatch exists + # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 + yield from super().iter_left() + def map_ir(self, function: MapIR, /) -> ExprIR: over = self.with_expr(self.expr.map_ir(function)).with_partition_by( ir.map_ir(function) for ir in self.partition_by diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 0f66560636..fd14e1d793 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -103,7 +103,7 @@ def _expr_to_leaf_column_names_iter(ir: ExprIR) -> Iterator[str]: def _expr_to_leaf_column_exprs_iter(ir: ExprIR) -> Iterator[ExprIR]: from narwhals._plan import expr - for outer in ir.iter_left(): + for outer in ir.iter_root_names(): if isinstance(outer, (expr.Column, expr.All)): yield outer diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 2bd1162778..cf776f5df1 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -490,7 +490,7 @@ def test_replace_selector( nwd.col("j").alias("j_2"), ], ), - pytest.param( + ( 
nwd.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), [ nwd.col("c") @@ -498,10 +498,6 @@ def test_replace_selector( .min() .over(order_by=[nwd.col("k")]) ], - marks=pytest.mark.xfail( - reason="BUG: `order_by` wasn't visited when collecting flags and failed to expand.\n" - "This slipped through as it *does* get visited if `partition_by` or `expr` contain selectors **as well**." - ), ), ], ) From ee1bdb82818795dcd3b3f080947f4c3ef6888cc7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 18:56:11 +0100 Subject: [PATCH 198/368] chore: Update comments --- narwhals/_plan/expr.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index e20d9d0316..5c374282f4 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -561,8 +561,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(Filter(expr=expr, by=by)) -# TODO @dangotbanned: 100% split out `order_by` to a subclass -# Really frustrating to handle the `None` case everywhere +# TODO @dangotbanned: Clean up docs/notes class WindowExpr(ExprIR): """A fully specified `.over()`, that occurred after another expression. @@ -630,9 +629,8 @@ def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: return type(self)(expr=self.expr, partition_by=by, options=self.options) +# TODO @dangotbanned: Reduce repetition from `WindowExpr` class OrderedWindowExpr(WindowExpr): - # `order_by` is required, only stores the `Seq[ExprIR]` - # `sort_options` is an attribute, not a property __slots__ = ("expr", "options", "order_by", "partition_by", "sort_options") expr: ExprIR From 41f4070709c21a92d39748c748f7635bc58a7ace Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 19:00:03 +0100 Subject: [PATCH 199/368] refactor: Simplify `with_order_by` --- narwhals/_plan/expr.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 5c374282f4..799d271b4f 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -687,20 +687,13 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(over) def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: - # NOTE: Not thrilled about this but there's complexity to solve - if by := (tuple(order_by) if not isinstance(order_by, tuple) else order_by): - if by == self.order_by: - return self - next_order_by = by - elif not self.order_by: + by = tuple(order_by) if not isinstance(order_by, tuple) else order_by + if by == self.order_by: return self - else: - # NOTE: Unsure if we'd ever want to do this, but need to be exhaustive - next_order_by = () return type(self)( expr=self.expr, partition_by=self.partition_by, - order_by=next_order_by, + order_by=by, sort_options=self.sort_options, options=self.options, ) From 7d4543e072f5646a990b06ffed70fe6c973e9d2e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 19:10:25 +0100 Subject: [PATCH 200/368] refactor: Factor out tuple boilerplate --- narwhals/_plan/common.py | 7 ++++++- narwhals/_plan/expr.py | 12 ++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 5b5d4287df..4fb1269e60 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -8,7 +8,7 @@ from narwhals.utils import 
Version if TYPE_CHECKING: - from typing import Any, Callable, Iterator, Literal + from typing import Any, Callable, Iterable, Iterator, Literal from typing_extensions import Never, Self, TypeAlias, TypeIs, dataclass_transform @@ -387,3 +387,8 @@ def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) type(None): dtypes.Unknown, } return mapping.get(type(obj), dtypes.Unknown)() + + +def collect(iterable: Seq[T] | Iterable[T], /) -> Seq[T]: + """Collect `iterable` into a `tuple`, *iff* it is not one already.""" + return iterable if isinstance(iterable, tuple) else tuple(iterable) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 799d271b4f..afcdfe09c7 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -11,6 +11,7 @@ ExprIR, SelectorIR, _field_str, + collect, is_non_nested_literal, is_regex_projection, ) @@ -442,7 +443,7 @@ def with_expr(self, expr: ExprIR, /) -> Self: return type(self)(expr=expr, by=self.by, options=self.options) def with_by(self, by: t.Iterable[ExprIR], /) -> Self: - by = tuple(by) if not isinstance(by, tuple) else by + by = collect(by) if by == self.by: return self return type(self)(expr=self.expr, by=by, options=self.options) @@ -480,8 +481,7 @@ def with_options(self, options: FunctionOptions, /) -> Self: return type(self)(input=self.input, function=self.function, options=options) def with_input(self, input: t.Iterable[ExprIR], /) -> Self: # noqa: A002 - if not isinstance(input, tuple): - input = tuple(input) + input = collect(input) if input == self.input: return self return type(self)(input=input, function=self.function, options=self.options) @@ -623,7 +623,7 @@ def with_expr(self, expr: ExprIR, /) -> Self: return type(self)(expr=expr, partition_by=self.partition_by, options=self.options) def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: - by = tuple(partition_by) if not isinstance(partition_by, tuple) else partition_by + by = collect(partition_by) if by == self.partition_by: return self return type(self)(expr=self.expr, partition_by=by, options=self.options) @@ -687,7 +687,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(over) def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: - by = tuple(order_by) if not isinstance(order_by, tuple) else order_by + by = collect(order_by) if by == self.order_by: return self return type(self)( @@ -710,7 +710,7 @@ def with_expr(self, expr: ExprIR, /) -> Self: ) def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: - by = tuple(partition_by) if not isinstance(partition_by, tuple) else partition_by + by = collect(partition_by) if by == self.partition_by: return self return type(self)( From 93033385a910424b5aecebabc1c9a9d08b45fef7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 10 Jun 2025 22:51:55 +0100 Subject: [PATCH 201/368] perf: Prepare `FrozenSchema` for caching Once `result: ResultIRs` is made immutable (or mutability stays within function boundaries) - most of `expr_expansion` will be safe to cache --- narwhals/_plan/expr_expansion.py | 118 +++++++++++++++++++++++++----- tests/plan/expr_expansion_test.py | 10 ++- 2 files changed, 105 insertions(+), 23 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 8098f8ff2f..20f64fe0d8 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -39,15 +39,31 @@ from __future__ import annotations from collections import deque 
-from copy import deepcopy +from functools import lru_cache from types import MappingProxyType -from typing import TYPE_CHECKING, Callable, Iterator, Mapping, Sequence - -from narwhals._plan.common import ExprIR, Immutable, SelectorIR, is_regex_projection +from typing import TYPE_CHECKING, TypeVar, overload + +from narwhals._plan.common import ( + _IMMUTABLE_HASH_NAME, + ExprIR, + Immutable, + SelectorIR, + is_regex_projection, +) +from narwhals.dtypes import DType from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: import re + from typing import ( + Callable, + ItemsView, + Iterator, + KeysView, + Mapping, + Sequence, + ValuesView, + ) from typing_extensions import TypeAlias @@ -57,22 +73,88 @@ from narwhals.dtypes import DType -FrozenSchema: TypeAlias = "MappingProxyType[str, DType]" FrozenColumns: TypeAlias = "Seq[str]" Excluded: TypeAlias = "frozenset[str]" """Internally use a `set`, then freeze before returning.""" ResultIRs: TypeAlias = "deque[ExprIR]" +_FrozenSchemaHash: TypeAlias = "Seq[tuple[str, DType]]" +_T2 = TypeVar("_T2") # NOTE: Both `_freeze` functions will probably want to be cached # In the traversal/expand/replacement functions, their returns will be hashable -> safe to cache those as well -def _freeze_schema(**schema: DType) -> FrozenSchema: - copied = deepcopy(schema) - return MappingProxyType(copied) +class FrozenSchema(Immutable): + """Use `freeze_schema(...)` constructor to trigger caching!""" + + __slots__ = ("_mapping",) + _mapping: MappingProxyType[str, DType] + + @property + def __immutable_hash__(self) -> int: + if hasattr(self, _IMMUTABLE_HASH_NAME): + return self.__immutable_hash_value__ + hash_value = hash((self.__class__, *tuple(self._mapping.items()))) + object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) + return self.__immutable_hash_value__ + + @property + def names(self) -> FrozenColumns: + """Get the column names of the schema.""" + return freeze_columns(self) + + @staticmethod + def _from_mapping(mapping: MappingProxyType[str, DType], /) -> FrozenSchema: + return FrozenSchema(_mapping=mapping) + + @staticmethod + def _from_hash_safe(items: _FrozenSchemaHash, /) -> FrozenSchema: + clone = MappingProxyType(dict(items)) + return FrozenSchema._from_mapping(clone) + + def items(self) -> ItemsView[str, DType]: + return self._mapping.items() + + def keys(self) -> KeysView[str]: + return self._mapping.keys() + def values(self) -> ValuesView[DType]: + return self._mapping.values() -def _freeze_columns(schema: FrozenSchema, /) -> FrozenColumns: + @overload + def get(self, key: str, /) -> DType | None: ... + @overload + def get(self, key: str, default: DType | _T2, /) -> DType | _T2: ... 
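    # A minimal usage sketch of the cached constructors defined further below
    # (illustrative only, not part of this diff; assumes simple narwhals dtypes
    # such as `nw.Int64()`, which compare and hash by value):
    #
    #     s1 = freeze_schema(a=nw.Int64(), b=nw.String())
    #     s2 = freeze_schema(a=nw.Int64(), b=nw.String())
    #     assert s1 is s2                 # `_freeze_schema_cache` hit -> same instance
    #     assert s1.names == ("a", "b")   # `freeze_columns` is `lru_cache`-d on the schema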
+ def get(self, key: str, default: DType | _T2 | None = None, /) -> DType | _T2 | None: + if default is not None: + return self._mapping.get(key, default) + return self._mapping.get(key) + + def __iter__(self) -> Iterator[str]: + yield from self._mapping + + def __contains__(self, key: object) -> bool: + return self._mapping.__contains__(key) + + def __getitem__(self, key: str, /) -> DType: + return self._mapping.__getitem__(key) + + def __len__(self) -> int: + return self._mapping.__len__() + + +def freeze_schema(**schema: DType) -> FrozenSchema: + schema_hash = tuple(schema.items()) + return _freeze_schema_cache(schema_hash) + + +@lru_cache(maxsize=100) +def _freeze_schema_cache(schema: _FrozenSchemaHash, /) -> FrozenSchema: + return FrozenSchema._from_hash_safe(schema) + + +@lru_cache(maxsize=100) +def freeze_columns(schema: FrozenSchema, /) -> FrozenColumns: return tuple(schema) @@ -144,9 +226,11 @@ def with_multiple_columns(self) -> ExpansionFlags: def prepare_projection( - exprs: Sequence[ExprIR], schema: Mapping[str, DType] + exprs: Sequence[ExprIR], schema: Mapping[str, DType] | FrozenSchema ) -> tuple[Seq[ExprIR], FrozenSchema]: - frozen_schema = _freeze_schema(**schema) + frozen_schema = ( + schema if isinstance(schema, FrozenSchema) else freeze_schema(**schema) + ) rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) # NOTE: There's an `expressions_to_schema` step that I'm skipping for now # seems too big of a rabbit hole to go down @@ -172,7 +256,7 @@ def replace_nth(origin: ExprIR, /, schema: FrozenSchema) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, expr.Nth): - return expr.Column(name=_freeze_columns(schema)[child.index]) + return expr.Column(name=schema.names[child.index]) return child return origin.map_ir(fn) @@ -299,7 +383,7 @@ def replace_and_add_to_results( if isinstance(e, expr.Columns): exclude = prepare_excluded(origin, keys=(), has_exclude=flags.has_exclude) result = expand_columns( - origin, result, e, col_names=_freeze_columns(schema), exclude=exclude + origin, result, e, col_names=schema.names, exclude=exclude ) else: exclude = prepare_excluded( @@ -308,14 +392,10 @@ def replace_and_add_to_results( result = expand_indices(origin, result, e, schema=schema, exclude=exclude) elif flags.has_wildcard: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - result = replace_wildcard( - origin, result, col_names=_freeze_columns(schema), exclude=exclude - ) + result = replace_wildcard(origin, result, col_names=schema.names, exclude=exclude) else: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - result = replace_regex( - origin, result, col_names=_freeze_columns(schema), exclude=exclude - ) + result = replace_regex(origin, result, col_names=schema.names, exclude=exclude) return result diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index cf776f5df1..39037be63e 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -9,7 +9,7 @@ from narwhals._plan.common import IntoExpr, is_expr from narwhals._plan.expr import Alias, Column, Columns, _ColumnSelection from narwhals._plan.expr_expansion import ( - FrozenSchema, + freeze_schema, prepare_projection, replace_selector, rewrite_special_aliases, @@ -342,10 +342,12 @@ def test_map_ir_recursive(expr: DummyExpr, function: MapIR, expected: DummyExpr) ], ) def test_replace_selector( - expr: DummySelector | DummyExpr, expected: DummyExpr | ExprIR, schema_1: FrozenSchema 
+ expr: DummySelector | DummyExpr, + expected: DummyExpr | ExprIR, + schema_1: dict[str, DType], ) -> None: group_by_keys = () - actual = replace_selector(expr._ir, group_by_keys, schema=schema_1) + actual = replace_selector(expr._ir, group_by_keys, schema=freeze_schema(**schema_1)) assert_expr_ir_equal(actual, expected) @@ -504,7 +506,7 @@ def test_replace_selector( def test_prepare_projection( into_exprs: IntoExpr | Sequence[IntoExpr], expected: Sequence[DummyExpr], - schema_1: FrozenSchema, + schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) actual, _ = prepare_projection(irs_in, schema_1) From 090330c521e038da54bb517ce5ad0810004c2951 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 11 Jun 2025 21:06:47 +0100 Subject: [PATCH 202/368] feat: Validate expressions with schema Resolves - Parts of (https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139984656) - All of (https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139965411) --- narwhals/_plan/common.py | 6 + narwhals/_plan/exceptions.py | 20 ++- narwhals/_plan/expr.py | 7 -- narwhals/_plan/expr_expansion.py | 53 ++++++-- tests/plan/expr_expansion_test.py | 195 +++++++++++++++++++++++------- tests/plan/expr_parsing_test.py | 38 ------ 6 files changed, 216 insertions(+), 103 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 4fb1269e60..ee2eea0efa 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -371,6 +371,12 @@ def is_regex_projection(name: str) -> bool: return name.startswith("^") and name.endswith("$") +def is_horizontal_reduction(obj: FunctionExpr[Any] | Any) -> TypeIs[FunctionExpr[Any]]: + from narwhals._plan.expr import FunctionExpr + + return isinstance(obj, FunctionExpr) and obj.options.is_input_wildcard_expansion() + + def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[DType]] = { diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 9253e23fed..2237370b90 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -2,6 +2,8 @@ from __future__ import annotations +from collections import Counter +from itertools import groupby from typing import TYPE_CHECKING from narwhals.exceptions import ( @@ -164,11 +166,25 @@ def is_iterable_polars_error( return TypeError(msg) -def alias_duplicate_error(expr: ExprIR, name: str) -> DuplicateError: - msg = f"Cannot apply alias {name!r} to multi-output expression:\n{expr!r}" +def duplicate_error(exprs: Seq[ExprIR]) -> DuplicateError: + INDENT = "\n " # noqa: N806 + names = [_output_name(expr) for expr in exprs] + duplicates = {k for k, v in Counter(names).items() if v > 1} + group_by_name = groupby(exprs, _output_name) + name_exprs = { + k: INDENT.join(f"{el!r}" for el in it) + for k, it in group_by_name + if k in duplicates + } + msg = "\n".join(f"[{name!r}]{INDENT}{e}" for name, e in name_exprs.items()) + msg = f"Expected unique column names, but found duplicates:\n\n{msg}" return DuplicateError(msg) +def _output_name(expr: ExprIR) -> str: + return expr.meta.output_name() + + def column_not_found_error( subset: Iterable[str], /, available: Iterable[str] ) -> ColumnNotFoundError: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index afcdfe09c7..f5ff58d5e3 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -16,7 +16,6 @@ is_regex_projection, ) 
from narwhals._plan.exceptions import ( - alias_duplicate_error, column_not_found_error, function_expr_invalid_operation_error, ) @@ -108,12 +107,6 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield self yield from self.expr.iter_right() - def __init__(self, *, expr: ExprIR, name: str) -> None: - if expr.meta.has_multiple_outputs(): - raise alias_duplicate_error(expr, name) - kwds = {"expr": expr, "name": name} - super().__init__(**kwds) - def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 20f64fe0d8..83ba96f70c 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -40,6 +40,7 @@ from collections import deque from functools import lru_cache +from itertools import chain from types import MappingProxyType from typing import TYPE_CHECKING, TypeVar, overload @@ -48,10 +49,17 @@ ExprIR, Immutable, SelectorIR, + is_horizontal_reduction, is_regex_projection, ) +from narwhals._plan.exceptions import column_not_found_error, duplicate_error from narwhals.dtypes import DType -from narwhals.exceptions import ComputeError, InvalidOperationError +from narwhals.exceptions import ( + ColumnNotFoundError, + ComputeError, + DuplicateError, + InvalidOperationError, +) if TYPE_CHECKING: import re @@ -232,21 +240,44 @@ def prepare_projection( schema if isinstance(schema, FrozenSchema) else freeze_schema(**schema) ) rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) - # NOTE: There's an `expressions_to_schema` step that I'm skipping for now - # seems too big of a rabbit hole to go down + if err := ensure_valid_exprs(rewritten, frozen_schema): + raise err return rewritten, frozen_schema -def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: - from narwhals._plan import expr +def ensure_valid_exprs( + exprs: Seq[ExprIR], schema: FrozenSchema +) -> ColumnNotFoundError | DuplicateError | None: + """Return an appropriate error if we can't materialize.""" + if err := _ensure_column_names_unique(exprs): + return err + root_names = _root_names_unique(exprs) + if not (set(schema.names).issuperset(root_names)): + return column_not_found_error(root_names, schema) + return None + + +def _ensure_column_names_unique(exprs: Seq[ExprIR]) -> DuplicateError | None: + names = [e.meta.output_name() for e in exprs] + if len(names) != len(set(names)): + return duplicate_error(exprs) + return None + +def _root_names_unique(exprs: Seq[ExprIR]) -> set[str]: + from narwhals._plan.meta import _expr_to_leaf_column_names_iter + + it = chain.from_iterable(_expr_to_leaf_column_names_iter(expr) for expr in exprs) + return set(it) + + +def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: - if not ( - isinstance(child, expr.FunctionExpr) - and child.options.is_input_wildcard_expansion() - ): - return child - return child.with_input(rewrite_projections(child.input, keys=(), schema=schema)) + if is_horizontal_reduction(child): + return child.with_input( + rewrite_projections(child.input, keys=(), schema=schema) + ) + return child return origin.map_ir(fn) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 39037be63e..0688d919e9 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +import re 
+from typing import TYPE_CHECKING, Callable, Iterable, Sequence import pytest @@ -15,7 +16,7 @@ rewrite_special_aliases, ) from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir -from narwhals.exceptions import ColumnNotFoundError, ComputeError +from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError if TYPE_CHECKING: from typing_extensions import TypeIs @@ -52,6 +53,18 @@ def schema_1() -> dict[str, DType]: } +MULTI_OUTPUT_EXPRS = ( + pytest.param(nwd.col("a", "b", "c")), + pytest.param(ndcs.numeric() - ndcs.matches("[d-j]")), + pytest.param(nwd.nth(0, 1, 2)), + pytest.param(ndcs.by_dtype(nw.Int64, nw.Int32, nw.Int16)), + pytest.param(ndcs.by_name("a", "b", "c")), +) +"""All of these resolve to `["a", "b", "c"]`.""" + +BIG_EXCLUDE = ("k", "l", "m", "n", "o", "p", "s", "u", "r", "a", "b", "e", "q") + + def assert_expr_ir_equal(left: DummyExpr | ExprIR, right: DummyExpr | ExprIR) -> None: lhs = left._ir if is_expr(left) else left rhs = right._ir if is_expr(right) else right @@ -119,46 +132,6 @@ def test_expand_columns_root( assert actual == expected -@pytest.mark.parametrize( - "expr", - [ - nwd.col("y", "z"), - nwd.col("a", "b", "z"), - nwd.col("x", "b", "a"), - nwd.col( - [ - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "FIVE", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "u", - ] - ), - ], -) -def test_invalid_expand_columns(expr: DummyExpr, schema_1: dict[str, DType]) -> None: - selection = expr._ir - assert is_column_selection(selection) - with pytest.raises(ColumnNotFoundError): - selection.expand_columns(schema_1) - - def udf_name_map(name: str) -> str: original = name upper = name.upper() @@ -351,9 +324,6 @@ def test_replace_selector( assert_expr_ir_equal(actual, expected) -BIG_EXCLUDE = ("k", "l", "m", "n", "o", "p", "s", "u", "r", "a", "b", "e", "q") - - @pytest.mark.parametrize( ("into_exprs", "expected"), [ @@ -513,3 +483,138 @@ def test_prepare_projection( assert len(actual) == len(expected) for lhs, rhs in zip(actual, expected): assert_expr_ir_equal(lhs, rhs) + + +@pytest.mark.parametrize( + "expr", + [ + nwd.all(), + nwd.nth(1, 2, 3), + nwd.col("a", "b", "c"), + ndcs.boolean() | ndcs.categorical(), + (ndcs.by_name("a", "b") | ndcs.string()), + (nwd.col("b", "c") & nwd.col("a")), + nwd.col("a", "b").min().over("c", order_by="e"), + (~ndcs.by_dtype(nw.Int64()) - ndcs.datetime()), + nwd.nth(6, 2).abs().cast(nw.Int32()) + 10, + *MULTI_OUTPUT_EXPRS, + ], +) +def test_prepare_projection_duplicate( + expr: DummyExpr, schema_1: dict[str, DType] +) -> None: + irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) + pattern = re.compile(r"\.alias\(.dupe.\)") + with pytest.raises(DuplicateError, match=pattern): + prepare_projection(irs, schema_1) + + +@pytest.mark.parametrize( + ("into_exprs", "missing"), + [ + ([nwd.col("y", "z")], ["y", "z"]), + ([nwd.col("a", "b", "z")], ["z"]), + ([nwd.col("x", "b", "a")], ["x"]), + ( + [ + nwd.col( + [ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "FIVE", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "u", + ] + ) + ], + ["FIVE"], + ), + ( + [nwd.col("a").min().over("c").alias("y"), nwd.col("one").alias("b").last()], + ["one"], + ), + ([nwd.col("a").sort_by("b", "who").alias("f")], ["who"]), + ( + [ + nwd.nth(0, 5) + .cast(nw.Int64()) + .abs() + .cum_sum() + .over("X", "O", "h", "m", "r", "zee"), + nwd.col("d", "j"), + "n", + ], + ["O", "X", "zee"], + ), + ], +) +def test_prepare_projection_column_not_found( + into_exprs: 
IntoExpr | Sequence[IntoExpr], + missing: Sequence[str], + schema_1: dict[str, DType], +) -> None: + pattern = re.compile(rf"not found: {re.escape(repr(missing))}") + irs = parse_into_seq_of_expr_ir(into_exprs) + with pytest.raises(ColumnNotFoundError, match=pattern): + prepare_projection(irs, schema_1) + + +@pytest.mark.parametrize( + "into_exprs", + [ + ("a", "b", "c"), + (["a", "b", "c"]), + ("a", "b", nwd.col("c")), + (nwd.col("a"), "b", "c"), + (nwd.col("a", "b"), "c"), + ("a", nwd.col("b", "c")), + ((nwd.nth(0), nwd.nth(1, 2))), + *MULTI_OUTPUT_EXPRS, + ], +) +@pytest.mark.parametrize( + "function", + [ + nwd.all_horizontal, + nwd.any_horizontal, + nwd.sum_horizontal, + nwd.min_horizontal, + nwd.max_horizontal, + nwd.mean_horizontal, + nwd.concat_str, + ], +) +def test_prepare_projection_horizontal_alias( + into_exprs: IntoExpr | Iterable[IntoExpr], + function: Callable[..., DummyExpr], + schema_1: dict[str, DType], +) -> None: + # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139965411 + expr = function(into_exprs) + alias_1 = expr.alias("alias(x1)") + irs = parse_into_seq_of_expr_ir(alias_1) + out_irs, _ = prepare_projection(irs, schema_1) + assert len(out_irs) == 1 + assert out_irs[0] == function("a", "b", "c").alias("alias(x1)")._ir + + alias_2 = alias_1.alias("alias(x2)") + irs = parse_into_seq_of_expr_ir(alias_2) + out_irs, _ = prepare_projection(irs, schema_1) + assert len(out_irs) == 1 + assert out_irs[0] == function("a", "b", "c").alias("alias(x1)").alias("alias(x2)")._ir diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 2df7ef557d..4fa567386e 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -11,13 +11,11 @@ from narwhals._plan import ( boolean, functions as F, # noqa: N812 - selectors as ndcs, ) from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import BinaryExpr, FunctionExpr from narwhals.exceptions import ( - DuplicateError, InvalidOperationError, LengthChangingExprError, MultiOutputExpressionError, @@ -301,39 +299,3 @@ def test_is_in_series() -> None: def test_invalid_is_in(other: Any, context: ContextManager[Any]) -> None: with context: nwd.col("a").is_in(other) - - -@pytest.mark.parametrize( - "expr", - [ - nwd.all(), - nwd.nth(1, 2, 3), - nwd.col("a", "b", "c"), - ndcs.boolean(), - (ndcs.by_name("a", "b") | ndcs.string()), - (nwd.col("b", "c") & nwd.col("a")), - nwd.col("a", "b").min().over("c", order_by="e"), - (~ndcs.by_dtype(nw.Int64()) - ndcs.datetime()), - nwd.nth(6, 2).abs().cast(nw.Int32()) + 10, - ], -) -def test_invalid_alias(expr: DummyExpr) -> None: - pattern = re.compile(r"alias.+dupe.+multi\-output") - with pytest.raises(DuplicateError, match=pattern): - expr.alias("dupe") - - -@pytest.mark.xfail( - reason="BUG: Giving a false positive for horizontal reductions:\n" - "'Cannot apply alias 'abc' to multi-output expression'", - raises=DuplicateError, -) -def test_alias_horizontal() -> None: # pragma: no cover - assert nwd.sum_horizontal("a", "b", "c").alias("abc") - assert nwd.sum_horizontal(["a", "b", "c"]).alias("abc") - assert nwd.sum_horizontal("a", "b", nwd.col("c")).alias("abc") - assert nwd.sum_horizontal(nwd.col("a"), "b", "c").alias("abc") - # NOTE: Fails starting here, but all the others should be equivalent - assert nwd.sum_horizontal(nwd.col("a", "b"), "c").alias("abc") - assert nwd.sum_horizontal("a", nwd.col("b", "c")).alias("abc") - assert 
nwd.sum_horizontal(nwd.col("a", "b", "c")).alias("abc") From d8dcfa4398948288da1a19d57809baead63cd899 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 11 Jun 2025 22:02:30 +0100 Subject: [PATCH 203/368] revert: Remove superseded `_ColumnSelection.expand_columns` Resolves (https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139984656) --- narwhals/_plan/expr.py | 43 +------------------- tests/plan/expr_expansion_test.py | 65 +------------------------------ 2 files changed, 3 insertions(+), 105 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index f5ff58d5e3..da2012a739 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -15,10 +15,7 @@ is_non_nested_literal, is_regex_projection, ) -from narwhals._plan.exceptions import ( - column_not_found_error, - function_expr_invalid_operation_error, -) +from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( ExprT, @@ -40,7 +37,7 @@ from narwhals._utils import flatten if t.TYPE_CHECKING: - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.common import Seq from narwhals._plan.functions import MapBatches # noqa: F401 @@ -79,12 +76,6 @@ "WindowExpr", ] -_Schema: TypeAlias = "t.Mapping[str, DType]" -"""Equivalent to `expr_expansion.FrozenSchema`. - -Using temporarily before adding caching into the mix. -""" - class Alias(ExprIR): __slots__ = ("expr", "name") @@ -136,17 +127,9 @@ def _col(name: str, /) -> Column: return Column(name=name) -def _cols(names: t.Iterable[str], /) -> Seq[Column]: - return tuple(_col(name) for name in names) - - class _ColumnSelection(ExprIR): """Nodes which can resolve to `Column`(s) with a `Schema`.""" - def expand_columns(self, schema: _Schema, /) -> Seq[Column]: - """Transform selection in context of `schema` into simpler nodes.""" - raise NotImplementedError - def map_ir(self, function: MapIR, /) -> ExprIR: return function(self) @@ -162,11 +145,6 @@ def __repr__(self) -> str: def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: return plx.col(*self.names) - def expand_columns(self, schema: _Schema) -> Seq[Column]: - if set(schema).issuperset(self.names): - return _cols(self.names) - raise column_not_found_error(self.names, schema) - class Nth(_ColumnSelection): __slots__ = ("index",) @@ -176,10 +154,6 @@ class Nth(_ColumnSelection): def __repr__(self) -> str: return f"nth({self.index})" - def expand_columns(self, schema: _Schema) -> Seq[Column]: - name = tuple(schema)[self.index] - return (_col(name),) - class IndexColumns(_ColumnSelection): """Renamed from `IndexColumn`. @@ -196,10 +170,6 @@ class IndexColumns(_ColumnSelection): def __repr__(self) -> str: return f"index_columns({self.indices!r})" - def expand_columns(self, schema: _Schema) -> Seq[Column]: - names = tuple(schema) - return _cols(names[index] for index in self.indices) - class All(_ColumnSelection): """Aka Wildcard (`pl.all()` or `pl.col("*")`). 
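# Where the removed `expand_columns` responsibility now lives, roughly -- an
# illustrative sketch (not part of this diff) using the schema-aware pass in
# `expr_expansion`, with `nwd` being `narwhals._plan.demo` as in the tests:
#
#     irs = parse_into_seq_of_expr_ir(nwd.col("a", "c").min())
#     out, frozen = prepare_projection(irs, {"a": nw.Int64(), "c": nw.Int64()})
#     assert len(out) == 2                     # Columns(["a", "c"]) -> one node per column
#     assert out[0] == nwd.col("a").min()._ir  # each keeps the surrounding expression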
@@ -210,9 +180,6 @@ class All(_ColumnSelection): def __repr__(self) -> str: return "all()" - def expand_columns(self, schema: _Schema) -> Seq[Column]: - return _cols(schema) - class Exclude(_ColumnSelection): __slots__ = ("expr", "names") @@ -237,12 +204,6 @@ def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: def __repr__(self) -> str: return f"{self.expr!r}.exclude({list(self.names)!r})" - def expand_columns(self, schema: _Schema) -> Seq[Column]: - if not isinstance(self.expr, All): - msg = f"Only {All()!r} is currently supported with `exclude()`" - raise NotImplementedError(msg) - return _cols(name for name in schema if name not in self.names) - def iter_left(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_left() yield self diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 0688d919e9..ad16700b53 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -8,7 +8,7 @@ import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.common import IntoExpr, is_expr -from narwhals._plan.expr import Alias, Column, Columns, _ColumnSelection +from narwhals._plan.expr import Alias, Columns from narwhals._plan.expr_expansion import ( freeze_schema, prepare_projection, @@ -19,8 +19,6 @@ from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError if TYPE_CHECKING: - from typing_extensions import TypeIs - from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyExpr, DummySelector from narwhals._plan.typing import MapIR @@ -71,67 +69,6 @@ def assert_expr_ir_equal(left: DummyExpr | ExprIR, right: DummyExpr | ExprIR) -> assert lhs == rhs -# NOTE: The meta check doesn't provide typing and describes a superset of `_ColumnSelection` -def is_column_selection(obj: ExprIR) -> TypeIs[_ColumnSelection]: - return obj.meta.is_column_selection(allow_aliasing=False) and isinstance( - obj, _ColumnSelection - ) - - -def seq_column_from_names(names: Sequence[str]) -> tuple[Column, ...]: - return tuple(Column(name=name) for name in names) - - -@pytest.mark.parametrize( - ("expr", "into_expected"), - [ - (nwd.col("a", "c"), ["a", "c"]), - (nwd.col("o", "k", "b"), ["o", "k", "b"]), - (nwd.nth(5), ["f"]), - (nwd.nth(0, 1, 2, 3, 4), ["a", "b", "c", "d", "e"]), - (nwd.nth(-1), ["u"]), - (nwd.nth([-2, -3, -4]), ["s", "r", "q"]), - ( - nwd.all(), - [ - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "u", - ], - ), - ( - nwd.exclude("a", "c", "e", "l", "q"), - ["b", "d", "f", "g", "h", "i", "j", "k", "m", "n", "o", "p", "r", "s", "u"], - ), - ], -) -def test_expand_columns_root( - expr: DummyExpr, into_expected: Sequence[str], schema_1: dict[str, DType] -) -> None: - expected = seq_column_from_names(into_expected) - selection = expr._ir - assert is_column_selection(selection) - actual = selection.expand_columns(schema_1) - assert actual == expected - - def udf_name_map(name: str) -> str: original = name upper = name.upper() From 49a6fc9108502de299ba03b7a77e66188b7b8e01 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 11 Jun 2025 22:12:02 +0100 Subject: [PATCH 204/368] refactor: Repurpose `col` --- narwhals/_plan/expr.py | 12 +++++++----- narwhals/_plan/expr_expansion.py | 15 ++++++++------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 
da2012a739..ec25ad9e34 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -74,9 +74,15 @@ "SortBy", "Ternary", "WindowExpr", + "col", ] +def col(name: str, /) -> Column: + """Sugar for a **single** column selection node.""" + return Column(name=name) + + class Alias(ExprIR): __slots__ = ("expr", "name") @@ -117,16 +123,12 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: return plx.col(self.name) def with_name(self, name: str, /) -> Column: - return self if name == self.name else Column(name=name) + return self if name == self.name else col(name) def map_ir(self, function: MapIR, /) -> ExprIR: return function(self) -def _col(name: str, /) -> Column: - return Column(name=name) - - class _ColumnSelection(ExprIR): """Nodes which can resolve to `Column`(s) with a `Schema`.""" diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 83ba96f70c..c59ac2c638 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -53,6 +53,7 @@ is_regex_projection, ) from narwhals._plan.exceptions import column_not_found_error, duplicate_error +from narwhals._plan.expr import col from narwhals.dtypes import DType from narwhals.exceptions import ( ColumnNotFoundError, @@ -287,7 +288,7 @@ def replace_nth(origin: ExprIR, /, schema: FrozenSchema) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, expr.Nth): - return expr.Column(name=schema.names[child.index]) + return col(schema.names[child.index]) return child return origin.map_ir(fn) @@ -309,11 +310,11 @@ def _replace_columns_exclude(origin: ExprIR, /, name: str) -> ExprIR: [`polars_plan::plans::conversion::expr_expansion::expand_columns`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L187-L191 """ - from narwhals._plan.expr import Column, Columns, Exclude + from narwhals._plan.expr import Columns, Exclude def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, Columns): - return Column(name=name) + return col(name) if isinstance(child, Exclude): return child.expr return child @@ -322,11 +323,11 @@ def fn(child: ExprIR, /) -> ExprIR: def replace_index_with_column(origin: ExprIR, /, name: str) -> ExprIR: - from narwhals._plan.expr import Column, Exclude, IndexColumns + from narwhals._plan.expr import Exclude, IndexColumns def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, IndexColumns): - return Column(name=name) + return col(name) if isinstance(child, Exclude): return child.expr return child @@ -336,11 +337,11 @@ def fn(child: ExprIR, /) -> ExprIR: def replace_wildcard_with_column(origin: ExprIR, /, name: str) -> ExprIR: """`expr.All` and `Exclude`.""" - from narwhals._plan.expr import All, Column, Exclude + from narwhals._plan.expr import All, Exclude def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, All): - return Column(name=name) + return col(name) if isinstance(child, Exclude): return child.expr return child From b244b344d958d1b93cea101b1ef166eab276f7c7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 11 Jun 2025 22:21:10 +0100 Subject: [PATCH 205/368] refactor: Make `expr` a dependency of `expr_expansion` --- narwhals/_plan/expr_expansion.py | 71 ++++++++++++++------------------ 1 file changed, 30 insertions(+), 41 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index c59ac2c638..cff92133bc 100644 --- a/narwhals/_plan/expr_expansion.py +++ 
b/narwhals/_plan/expr_expansion.py @@ -53,7 +53,17 @@ is_regex_projection, ) from narwhals._plan.exceptions import column_not_found_error, duplicate_error -from narwhals._plan.expr import col +from narwhals._plan.expr import ( + Alias, + All, + Columns, + Exclude, + IndexColumns, + KeepName, + Nth, + RenameAlias, + col, +) from narwhals.dtypes import DType from narwhals.exceptions import ( ColumnNotFoundError, @@ -76,7 +86,7 @@ from typing_extensions import TypeAlias - from narwhals._plan import expr, selectors + from narwhals._plan import selectors from narwhals._plan.common import Seq from narwhals._plan.dummy import DummyExpr from narwhals.dtypes import DType @@ -194,23 +204,21 @@ def from_ir(ir: ExprIR, /) -> ExpansionFlags: [`find_flags`]: https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L607-L660 """ - from narwhals._plan import expr - multiple_columns: bool = False has_nth: bool = False has_wildcard: bool = False has_selector: bool = False has_exclude: bool = False for e in ir.iter_left(): - if isinstance(e, (expr.Columns, expr.IndexColumns)): + if isinstance(e, (Columns, IndexColumns)): multiple_columns = True - elif isinstance(e, expr.Nth): + elif isinstance(e, Nth): has_nth = True - elif isinstance(e, expr.All): + elif isinstance(e, All): has_wildcard = True - elif isinstance(e, expr.SelectorIR): + elif isinstance(e, SelectorIR): has_selector = True - elif isinstance(e, expr.Exclude): + elif isinstance(e, Exclude): has_exclude = True return ExpansionFlags( multiple_columns=multiple_columns, @@ -284,10 +292,8 @@ def fn(child: ExprIR, /) -> ExprIR: def replace_nth(origin: ExprIR, /, schema: FrozenSchema) -> ExprIR: - from narwhals._plan import expr - def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, expr.Nth): + if isinstance(child, Nth): return col(schema.names[child.index]) return child @@ -295,8 +301,6 @@ def fn(child: ExprIR, /) -> ExprIR: def remove_exclude(origin: ExprIR, /) -> ExprIR: - from narwhals._plan.expr import Exclude - def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, Exclude): return child.expr @@ -310,7 +314,6 @@ def _replace_columns_exclude(origin: ExprIR, /, name: str) -> ExprIR: [`polars_plan::plans::conversion::expr_expansion::expand_columns`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L187-L191 """ - from narwhals._plan.expr import Columns, Exclude def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, Columns): @@ -323,8 +326,6 @@ def fn(child: ExprIR, /) -> ExprIR: def replace_index_with_column(origin: ExprIR, /, name: str) -> ExprIR: - from narwhals._plan.expr import Exclude, IndexColumns - def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, IndexColumns): return col(name) @@ -337,7 +338,6 @@ def fn(child: ExprIR, /) -> ExprIR: def replace_wildcard_with_column(origin: ExprIR, /, name: str) -> ExprIR: """`expr.All` and `Exclude`.""" - from narwhals._plan.expr import All, Exclude def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, All): @@ -357,12 +357,11 @@ def replace_selector( schema: FrozenSchema, ) -> ExprIR: """Fully diverging from `polars`, we'll see how that goes.""" - from narwhals._plan import expr def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, SelectorIR): cols = (k for k, v in schema.items() if child.matches_column(k, v)) - return expr.Columns(names=tuple(cols)) + return Columns(names=tuple(cols)) return child return 
ir.map_ir(fn) @@ -401,18 +400,12 @@ def replace_and_add_to_results( schema: FrozenSchema, flags: ExpansionFlags, ) -> ResultIRs: - from narwhals._plan import expr - if flags.has_nth: origin = replace_nth(origin, schema) if flags.expands: - it = ( - e - for e in origin.iter_left() - if isinstance(e, (expr.Columns, expr.IndexColumns)) - ) + it = (e for e in origin.iter_left() if isinstance(e, (Columns, IndexColumns))) if e := next(it, None): - if isinstance(e, expr.Columns): + if isinstance(e, Columns): exclude = prepare_excluded(origin, keys=(), has_exclude=flags.has_exclude) result = expand_columns( origin, result, e, col_names=schema.names, exclude=exclude @@ -433,8 +426,6 @@ def replace_and_add_to_results( def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]: """Yield all excluded names in `origin`.""" - from narwhals._plan.expr import Exclude - for e in origin.iter_left(): if isinstance(e, Exclude): yield from e.names @@ -459,9 +450,7 @@ def prepare_excluded( return frozenset(exclude) -def _all_columns_match(origin: ExprIR, /, columns: expr.Columns) -> bool: - from narwhals._plan.expr import Columns - +def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool: it = (e == columns if isinstance(e, Columns) else True for e in origin.iter_left()) return all(it) @@ -470,7 +459,7 @@ def expand_columns( origin: ExprIR, /, result: ResultIRs, - columns: expr.Columns, + columns: Columns, *, col_names: FrozenColumns, exclude: Excluded, @@ -489,7 +478,7 @@ def expand_indices( origin: ExprIR, /, result: ResultIRs, - indices: expr.IndexColumns, + indices: IndexColumns, *, schema: FrozenSchema, exclude: Excluded, @@ -528,21 +517,21 @@ def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: - Expanding all selections into `Column` - Dealing with `FunctionExpr.input` """ - from narwhals._plan import expr, meta + from narwhals._plan import meta - if meta.has_expr_ir(origin, expr.KeepName, expr.RenameAlias): - if isinstance(origin, expr.KeepName): + if meta.has_expr_ir(origin, KeepName, RenameAlias): + if isinstance(origin, KeepName): parent = origin.expr roots = parent.meta.root_names() alias = next(iter(roots)) - return expr.Alias(expr=parent, name=alias) - elif isinstance(origin, expr.RenameAlias): + return Alias(expr=parent, name=alias) + elif isinstance(origin, RenameAlias): parent = origin.expr leaf_name_or_err = meta.get_single_leaf_name(parent) if not isinstance(leaf_name_or_err, str): raise leaf_name_or_err alias = origin.function(leaf_name_or_err) - return expr.Alias(expr=parent, name=alias) + return Alias(expr=parent, name=alias) else: msg = "`keep`, `suffix`, `prefix` should be last expression" raise InvalidOperationError(msg) From 09c01fe22c25d03bc46d14d91ea8a5158bfbacde Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 11 Jun 2025 22:38:24 +0100 Subject: [PATCH 206/368] revert: Remove unused `regex` expansion stuff We don't use it in `col`, and selectors are resolved in an entirely different way --- narwhals/_plan/expr_expansion.py | 103 +++---------------------------- 1 file changed, 7 insertions(+), 96 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index cff92133bc..910aa3bc35 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -50,7 +50,6 @@ Immutable, SelectorIR, is_horizontal_reduction, - is_regex_projection, ) from narwhals._plan.exceptions import column_not_found_error, duplicate_error from narwhals._plan.expr import ( @@ -73,20 
+72,10 @@ ) if TYPE_CHECKING: - import re - from typing import ( - Callable, - ItemsView, - Iterator, - KeysView, - Mapping, - Sequence, - ValuesView, - ) + from typing import ItemsView, Iterator, KeysView, Mapping, Sequence, ValuesView from typing_extensions import TypeAlias - from narwhals._plan import selectors from narwhals._plan.common import Seq from narwhals._plan.dummy import DummyExpr from narwhals.dtypes import DType @@ -407,9 +396,7 @@ def replace_and_add_to_results( if e := next(it, None): if isinstance(e, Columns): exclude = prepare_excluded(origin, keys=(), has_exclude=flags.has_exclude) - result = expand_columns( - origin, result, e, col_names=schema.names, exclude=exclude - ) + result = expand_columns(origin, result, e, exclude=exclude) else: exclude = prepare_excluded( origin, keys=keys, has_exclude=flags.has_exclude @@ -420,7 +407,8 @@ def replace_and_add_to_results( result = replace_wildcard(origin, result, col_names=schema.names, exclude=exclude) else: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - result = replace_regex(origin, result, col_names=schema.names, exclude=exclude) + origin = rewrite_special_aliases(origin) + result.append(origin) return result @@ -456,13 +444,7 @@ def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool: def expand_columns( - origin: ExprIR, - /, - result: ResultIRs, - columns: Columns, - *, - col_names: FrozenColumns, - exclude: Excluded, + origin: ExprIR, /, result: ResultIRs, columns: Columns, *, exclude: Excluded ) -> ResultIRs: if not _all_columns_match(origin, columns): msg = "expanding more than one `col` is not allowed" @@ -470,7 +452,8 @@ def expand_columns( for name in columns.names: if name not in exclude: new_expr = _replace_columns_exclude(origin, name) - result = replace_regex(new_expr, result, col_names=col_names, exclude=exclude) + new_expr = rewrite_special_aliases(new_expr) + result.append(new_expr) return result @@ -536,75 +519,3 @@ def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: msg = "`keep`, `suffix`, `prefix` should be last expression" raise InvalidOperationError(msg) return origin - - -def into_pattern(obj: str | re.Pattern[str] | selectors.Matches, /) -> re.Pattern[str]: - import re - - from narwhals._plan import selectors - - if isinstance(obj, str): - return re.compile(obj) - elif isinstance(obj, selectors.Matches): - return obj.pattern - elif isinstance(obj, re.Pattern): - return obj - else: - msg = f"Cannot convert {type(obj).__name__!r} into a regular expression" - raise TypeError(msg) - - -# NOTE: Will likely be using `selectors.Matches` for this -# Doing a direct translation from `rust` *first*, to make replacing -# the deviations *later* not as daunting -def replace_regex( - origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded -) -> ResultIRs: - regex: str | None = None - for name in origin.meta.root_names(): - if is_regex_projection(name): - if regex is None: - regex = name - result = expand_regex( - origin, - result, - into_pattern(name), - col_names=col_names, - exclude=exclude, - ) - elif regex != name: - msg = "an expression is not allowed to have different regexes" - raise ComputeError(msg) - if regex is None: - origin = rewrite_special_aliases(origin) - result.append(origin) - return result - - -def expand_regex( - origin: ExprIR, - /, - result: ResultIRs, - pattern: re.Pattern[str], - *, - col_names: FrozenColumns, - exclude: Excluded, -) -> ResultIRs: - for name in col_names: - if pattern.match(name) and name not in 
exclude: - expanded = remove_exclude(origin) - expanded = expanded.map_ir(_replace_regex(pattern, name)) - expanded = rewrite_special_aliases(expanded) - result.append(expanded) - return result - - -def _replace_regex(pattern: re.Pattern[str], name: str, /) -> Callable[[ExprIR], ExprIR]: - from narwhals._plan.meta import is_column - - pat = pattern.pattern - - def fn(ir: ExprIR, /) -> ExprIR: - return ir.with_name(name) if is_column(ir) and ir.name == pat else ir - - return fn From bd588677fe0b5ef5037a180a0b5ce21f57b9c83a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 11 Jun 2025 22:42:57 +0100 Subject: [PATCH 207/368] refactor: Remove/rename things inherited from `rust` --- narwhals/_plan/expr_expansion.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 910aa3bc35..cb9564e2f3 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -365,9 +365,7 @@ def rewrite_projections( *, # NOTE: Represents group_by keys schema: FrozenSchema, ) -> Seq[ExprIR]: - # NOTE: This is where the mutable `result` is initialized - result_length = len(input) + len(schema) - result: deque[ExprIR] = deque(maxlen=result_length) + result: deque[ExprIR] = deque() for expr in input: expanded = expand_function_inputs(expr, schema=schema) flags = ExpansionFlags.from_ir(expanded) @@ -407,8 +405,8 @@ def replace_and_add_to_results( result = replace_wildcard(origin, result, col_names=schema.names, exclude=exclude) else: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - origin = rewrite_special_aliases(origin) - result.append(origin) + expanded = rewrite_special_aliases(origin) + result.append(expanded) return result @@ -451,9 +449,9 @@ def expand_columns( raise ComputeError(msg) for name in columns.names: if name not in exclude: - new_expr = _replace_columns_exclude(origin, name) - new_expr = rewrite_special_aliases(new_expr) - result.append(new_expr) + expanded = _replace_columns_exclude(origin, name) + expanded = rewrite_special_aliases(expanded) + result.append(expanded) return result @@ -475,9 +473,9 @@ def expand_indices( raise ComputeError(msg) name = names[idx] if name not in exclude: - new_expr = replace_index_with_column(origin, name) - new_expr = rewrite_special_aliases(new_expr) - result.append(new_expr) + expanded = replace_index_with_column(origin, name) + expanded = rewrite_special_aliases(expanded) + result.append(expanded) return result @@ -486,9 +484,9 @@ def replace_wildcard( ) -> ResultIRs: for name in col_names: if name not in exclude: - new_expr = replace_wildcard_with_column(origin, name) - new_expr = rewrite_special_aliases(new_expr) - result.append(new_expr) + expanded = replace_wildcard_with_column(origin, name) + expanded = rewrite_special_aliases(expanded) + result.append(expanded) return result From 00c2cc19034fb215bbdd9318bd4f7a9b47593317 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 11 Jun 2025 22:51:10 +0100 Subject: [PATCH 208/368] fix: Add some missing `is_scalar` props --- narwhals/_plan/expr.py | 4 ++++ narwhals/_plan/name.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index ec25ad9e34..27217caf80 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -757,6 +757,10 @@ class Ternary(ExprIR): truthy: ExprIR falsy: ExprIR 
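    # Sketch of the rule the property below encodes: a ternary stays scalar only
    # when *all* of `predicate`/`truthy`/`falsy` are scalar (illustrative, assuming
    # the `when`/`then`/`otherwise` builders from `when_then`):
    #
    #     when(lit(True)).then(lit(1)).otherwise(lit(0))     # is_scalar -> True
    #     when(col("a") > 0).then(lit(1)).otherwise(lit(0))  # is_scalar -> False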
+ @property + def is_scalar(self) -> bool: + return self.predicate.is_scalar and self.truthy.is_scalar and self.falsy.is_scalar + def __str__(self) -> str: # NOTE: Default slot ordering made it difficult to read fields = ( diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 86a54d6fbe..2e8d6a2233 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -21,6 +21,10 @@ class KeepName(ExprIR): expr: ExprIR + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + def __repr__(self) -> str: return f"{self.expr!r}.name.keep()" @@ -45,6 +49,10 @@ class RenameAlias(ExprIR): expr: ExprIR function: AliasName + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + def __repr__(self) -> str: return f".rename_alias({self.expr!r})" From 48e9f259ba43cffb143fe7ab0acbebf9827fe528 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 11:31:04 +0100 Subject: [PATCH 209/368] feat: Add `functions.Log` Added after this PR started in (#2549) --- narwhals/_plan/dummy.py | 4 ++++ narwhals/_plan/functions.py | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 287c12027f..423605983f 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math import typing as t from typing import TYPE_CHECKING @@ -196,6 +197,9 @@ def hist( node = F.HistBinCount(include_breakpoint=include_breakpoint) return self._from_ir(node.to_function_expr(self._ir)) + def log(self, base: float = math.e) -> Self: + return self._from_ir(F.Log(base=base).to_function_expr(self._ir)) + def null_count(self) -> Self: return self._from_ir(F.NullCount().to_function_expr(self._ir)) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 1a74a58285..70c00e37ee 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -79,6 +79,19 @@ def __repr__(self) -> str: return "null_count" +class Log(Function): + __slots__ = ("base",) + + base: float + + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() + + def __repr__(self) -> str: + return "log" + + class Pow(Function): @property def function_options(self) -> FunctionOptions: From fb3f407e06546c090f3e3b8a016ef3f59c65daae Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 13:56:33 +0100 Subject: [PATCH 210/368] feat: Add `Expr.filter` The parsing will be reused by `when`, `BaseFrame.filter`, etc --- narwhals/_plan/dummy.py | 8 +++++ narwhals/_plan/exceptions.py | 3 +- narwhals/_plan/expr.py | 1 - narwhals/_plan/expr_parsing.py | 38 +++++++++++++++++++- tests/plan/expr_expansion_test.py | 10 ++---- tests/plan/expr_parsing_test.py | 60 ++++++++++++++++++++++++++++++- tests/plan/utils.py | 15 ++++++++ 7 files changed, 123 insertions(+), 12 deletions(-) create mode 100644 tests/plan/utils.py diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 423605983f..efe1c049a6 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -173,6 +173,14 @@ def sort_by( options = SortMultipleOptions(descending=desc, nulls_last=nulls) return self._from_ir(expr.SortBy(expr=self._ir, by=sort_by, options=options)) + def filter( + self, + *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], + **constraints: t.Any, + ) -> Self: + by = parse.parse_predicates_constraints_into_expr_ir(*predicates, 
**constraints) + return self._from_ir(expr.Filter(expr=self._ir, by=by)) + def abs(self) -> Self: return self._from_ir(F.Abs().to_function_expr(self._ir)) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 2237370b90..b1c275d8b6 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -138,10 +138,11 @@ def invalid_into_expr_error( named_inputs: dict[str, IntoExpr], /, ) -> InvalidIntoExprError: + named = f"\n{named_inputs!r}" if named_inputs else "" msg = ( f"Passing both iterable and positional inputs is not supported.\n" f"Hint:\nInstead try collecting all arguments into a {type(first_input).__name__!r}\n" - f"{first_input!r}\n{more_inputs!r}\n{named_inputs!r}" + f"{first_input!r}\n{more_inputs!r}{named}" ) return InvalidIntoExprError(msg) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 27217caf80..e584ae79d3 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -485,7 +485,6 @@ class AnonymousExpr(FunctionExpr["MapBatches"]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" -# TODO @dangotbanned: add `DummyExpr.filter` class Filter(ExprIR): __slots__ = ("by", "expr") diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index 8af7db1f55..93abfc5888 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -1,9 +1,10 @@ from __future__ import annotations # ruff: noqa: A002 +from itertools import chain from typing import TYPE_CHECKING, Iterable, Sequence, TypeVar -from narwhals._plan.common import is_expr, is_iterable_reject +from narwhals._plan.common import IntoExprColumn, is_expr, is_iterable_reject from narwhals._plan.exceptions import ( invalid_into_expr_error, is_iterable_pandas_error, @@ -103,6 +104,22 @@ def parse_into_seq_of_expr_ir( return tuple(_parse_into_iter_expr_ir(first_input, *more_inputs, **named_inputs)) +def parse_predicates_constraints_into_expr_ir( + first_predicate: IntoExprColumn | Iterable[IntoExprColumn] = (), + *more_predicates: IntoExprColumn | _RaisesInvalidIntoExprError, + **constraints: IntoExpr, +) -> ExprIR: + """Parse variadic predicates and constraints into an `ExprIR` node. + + The result is an AND-reduction of all inputs. 
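    For example (a sketch of the intent): `col("a").filter(col("b") > 1, c=5)`
    is expected to produce the same `ExprIR` as
    `col("a").filter(all_horizontal(col("b") > 1, col("c") == lit(5)))`,
    since two or more inputs are folded through `AllHorizontal`.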
+ """ + all_predicates = _parse_into_iter_expr_ir(first_predicate, *more_predicates) + if constraints: + chained = chain(all_predicates, _parse_constraints(constraints)) + return _combine_predicates(chained) + return _combine_predicates(all_predicates) + + def _parse_into_iter_expr_ir( first_input: IntoExpr | Iterable[IntoExpr], *more_inputs: IntoExpr, @@ -139,6 +156,25 @@ def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR yield Alias(expr=parse_into_expr_ir(input), name=name) +def _parse_constraints(constraints: dict[str, IntoExpr], /) -> Iterator[ExprIR]: + from narwhals._plan import demo as nwd + + for name, value in constraints.items(): + yield (nwd.col(name) == value)._ir + + +def _combine_predicates(predicates: Iterator[ExprIR], /) -> ExprIR: + from narwhals._plan.boolean import AllHorizontal + + first = next(predicates, None) + if not first: + msg = "at least one predicate or constraint must be provided" + raise TypeError(msg) + if second := next(predicates, None): + return AllHorizontal().to_function_expr(first, second, *predicates) + return first + + def _is_iterable(obj: Iterable[T] | Any) -> TypeIs[Iterable[T]]: if is_pandas_dataframe(obj) or is_pandas_series(obj): raise is_iterable_pandas_error(obj) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index ad16700b53..8d8d9966ea 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -7,7 +7,6 @@ import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs -from narwhals._plan.common import IntoExpr, is_expr from narwhals._plan.expr import Alias, Columns from narwhals._plan.expr_expansion import ( freeze_schema, @@ -17,9 +16,10 @@ ) from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError +from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: - from narwhals._plan.common import ExprIR + from narwhals._plan.common import ExprIR, IntoExpr from narwhals._plan.dummy import DummyExpr, DummySelector from narwhals._plan.typing import MapIR from narwhals.dtypes import DType @@ -63,12 +63,6 @@ def schema_1() -> dict[str, DType]: BIG_EXCLUDE = ("k", "l", "m", "n", "o", "p", "s", "u", "r", "a", "b", "e", "q") -def assert_expr_ir_equal(left: DummyExpr | ExprIR, right: DummyExpr | ExprIR) -> None: - lhs = left._ir if is_expr(left) else left - rhs = right._ir if is_expr(right) else right - assert lhs == rhs - - def udf_name_map(name: str) -> str: original = name upper = name.upper() diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 4fa567386e..6264a07bb0 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -2,6 +2,7 @@ import re from collections import deque +from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import pytest @@ -12,15 +13,17 @@ boolean, functions as F, # noqa: N812 ) -from narwhals._plan.common import ExprIR, Function +from narwhals._plan.common import ExprIR, Function, IntoExprColumn from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import BinaryExpr, FunctionExpr from narwhals.exceptions import ( + InvalidIntoExprError, InvalidOperationError, LengthChangingExprError, MultiOutputExpressionError, ShapeError, ) +from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: from typing import ContextManager @@ -299,3 +302,58 @@ def test_is_in_series() -> None: 
def test_invalid_is_in(other: Any, context: ContextManager[Any]) -> None: with context: nwd.col("a").is_in(other) + + +def test_filter_full_spellings() -> None: + a = nwd.col("a") + b = nwd.col("b") + c = nwd.col("c") + d = nwd.col("d") + expected = a.filter(b != b.max(), c < nwd.lit(2), d == nwd.lit(5)) + expr_1 = a.filter([b != b.max(), c < nwd.lit(2), d == nwd.lit(5)]) + expr_2 = a.filter([b != b.max(), c < nwd.lit(2)], d=nwd.lit(5)) + expr_3 = a.filter([b != b.max(), c < nwd.lit(2)], d=5) + expr_4 = a.filter(b != b.max(), c < nwd.lit(2), d=5) + expr_5 = a.filter(b != b.max(), c < 2, d=5) + expr_6 = a.filter((b != b.max(), c < 2), d=5) + assert_expr_ir_equal(expected, expr_1) + assert_expr_ir_equal(expected, expr_2) + assert_expr_ir_equal(expected, expr_3) + assert_expr_ir_equal(expected, expr_4) + assert_expr_ir_equal(expected, expr_5) + assert_expr_ir_equal(expected, expr_6) + + +@pytest.mark.parametrize( + ("predicates", "constraints", "context"), + [ + ([nwd.col("b").is_last_distinct()], {}, nullcontext()), + ((), {"b": 10}, nullcontext()), + ((), {"b": nwd.lit(10)}, nullcontext()), + ( + (), + {}, + pytest.raises( + TypeError, match=re.compile(r"at least one predicate", re.IGNORECASE) + ), + ), + ((nwd.col("b") > 1, nwd.col("c").is_null()), {}, nullcontext()), + ( + ([nwd.col("b") > 1], nwd.col("c").is_null()), + {}, + pytest.raises( + InvalidIntoExprError, + match=re.compile( + r"both iterable.+positional.+not supported", re.IGNORECASE + ), + ), + ), + ], +) +def test_filter_partial_spellings( + predicates: Iterable[IntoExprColumn], + constraints: dict[str, Any], + context: ContextManager[Any], +) -> None: + with context: + assert nwd.col("a").filter(*predicates, **constraints) diff --git a/tests/plan/utils.py b/tests/plan/utils.py new file mode 100644 index 0000000000..641ccddb19 --- /dev/null +++ b/tests/plan/utils.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan.common import is_expr + +if TYPE_CHECKING: + from narwhals._plan.common import ExprIR + from narwhals._plan.dummy import DummyExpr + + +def assert_expr_ir_equal(left: DummyExpr | ExprIR, right: DummyExpr | ExprIR) -> None: + lhs = left._ir if is_expr(left) else left + rhs = right._ir if is_expr(right) else right + assert lhs == rhs From 52f7975b5cf045cb8c5dfce2e3cfb3f9d0d7ec8d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 14:32:34 +0100 Subject: [PATCH 211/368] ci: Update `name-tests-test` exclude pattern I'm putting utils here to simplify keeping this branch in sync --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0e3f4636f1..ae7a3707e9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -95,7 +95,7 @@ repos: rev: v5.0.0 hooks: - id: name-tests-test - exclude: ^tests/utils\.py + exclude: ^(tests/utils\.py|tests/plan/utils\.py) - id: no-commit-to-branch - id: end-of-file-fixer exclude: .svg$ From 0ff24fee05c21b3c4a45b05d986535f67ce35a5e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 14:49:22 +0100 Subject: [PATCH 212/368] refactor: Move types from `common` to `typing` --- narwhals/_plan/boolean.py | 3 ++- narwhals/_plan/common.py | 16 ++-------------- narwhals/_plan/demo.py | 2 +- narwhals/_plan/dummy.py | 4 ++-- narwhals/_plan/exceptions.py | 3 ++- narwhals/_plan/expr.py | 2 +- 
narwhals/_plan/expr_expansion.py | 2 +- narwhals/_plan/expr_parsing.py | 5 +++-- narwhals/_plan/functions.py | 3 ++- narwhals/_plan/options.py | 2 +- narwhals/_plan/typing.py | 21 ++++++++++++++++++++- narwhals/_plan/when_then.py | 3 ++- narwhals/_plan/window.py | 3 ++- tests/plan/expr_expansion_test.py | 4 ++-- tests/plan/expr_parsing_test.py | 4 ++-- 15 files changed, 45 insertions(+), 32 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index e8ccdbc54b..fda6733ed5 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -9,9 +9,10 @@ from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: - from narwhals._plan.common import ExprIR, Seq # noqa: F401 + from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import Literal # noqa: F401 + from narwhals._plan.typing import Seq # noqa: F401 from narwhals.typing import ClosedInterval OtherT = TypeVar("OtherT") diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ee2eea0efa..89de921662 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -4,13 +4,13 @@ from decimal import Decimal from typing import TYPE_CHECKING, Generic, TypeVar -from narwhals._plan.typing import ExprT, IRNamespaceT, MapIR, Ns +from narwhals._plan.typing import ExprT, IRNamespaceT, MapIR, Ns, Seq from narwhals.utils import Version if TYPE_CHECKING: from typing import Any, Callable, Iterable, Iterator, Literal - from typing_extensions import Never, Self, TypeAlias, TypeIs, dataclass_transform + from typing_extensions import Never, Self, TypeIs, dataclass_transform from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries from narwhals._plan.expr import FunctionExpr @@ -49,18 +49,6 @@ def decorator(cls_or_fn: T) -> T: T = TypeVar("T") -Seq: TypeAlias = "tuple[T,...]" -"""Immutable Sequence. - -Using instead of `Sequence`, as a `list` can be passed there (can't break immutability promise). 
-""" - -Udf: TypeAlias = "Callable[[Any], Any]" -"""Placeholder for `map_batches(function=...)`.""" - -IntoExprColumn: TypeAlias = "DummyExpr | DummySeries | str" -IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" - _IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 7d6d066d71..a5e28c67f0 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -11,7 +11,6 @@ ) from narwhals._plan.common import ( ExprIR, - IntoExpr, is_expr, is_non_nested_literal, py_to_narwhals_dtype, @@ -30,6 +29,7 @@ from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import SortBy + from narwhals._plan.typing import IntoExpr from narwhals.typing import NonNestedLiteral diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index efe1c049a6..97b6671560 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -33,14 +33,14 @@ from typing_extensions import Never, Self from narwhals._plan.categorical import ExprCatNamespace - from narwhals._plan.common import ExprIR, IntoExpr, IntoExprColumn, Seq, Udf + from narwhals._plan.common import ExprIR from narwhals._plan.lists import ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace from narwhals._plan.temporal import ExprDateTimeNamespace - from narwhals._plan.typing import ExprT, Ns + from narwhals._plan.typing import ExprT, IntoExpr, IntoExprColumn, Ns, Seq, Udf from narwhals.typing import ( ClosedInterval, FillNullStrategy, diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index b1c275d8b6..0e4dbe0b4b 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -24,10 +24,11 @@ import polars as pl from narwhals._plan.aggregation import Agg - from narwhals._plan.common import ExprIR, Function, IntoExpr, Seq + from narwhals._plan.common import ExprIR, Function from narwhals._plan.expr import FunctionExpr, WindowExpr from narwhals._plan.operators import Operator from narwhals._plan.options import SortOptions + from narwhals._plan.typing import IntoExpr, Seq # NOTE: Using verbose names to start with diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index e584ae79d3..2cce4f5d8b 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -33,13 +33,13 @@ RollingT, SelectorOperatorT, SelectorT, + Seq, ) from narwhals._utils import flatten if t.TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.common import Seq from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index cb9564e2f3..f0d28acc43 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -76,8 +76,8 @@ from typing_extensions import TypeAlias - from narwhals._plan.common import Seq from narwhals._plan.dummy import DummyExpr + from narwhals._plan.typing import Seq from narwhals.dtypes import DType diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index 93abfc5888..1e8b1139bf 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -4,7 +4,7 @@ from itertools import chain from typing import TYPE_CHECKING, Iterable, 
Sequence, TypeVar -from narwhals._plan.common import IntoExprColumn, is_expr, is_iterable_reject +from narwhals._plan.common import is_expr, is_iterable_reject from narwhals._plan.exceptions import ( invalid_into_expr_error, is_iterable_pandas_error, @@ -18,7 +18,8 @@ import polars as pl from typing_extensions import TypeAlias, TypeIs - from narwhals._plan.common import ExprIR, IntoExpr, Seq + from narwhals._plan.common import ExprIR + from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq from narwhals.dtypes import DType T = TypeVar("T") diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 70c00e37ee..4c5d7cd910 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -13,9 +13,10 @@ from typing_extensions import Self - from narwhals._plan.common import ExprIR, Seq, Udf + from narwhals._plan.common import ExprIR from narwhals._plan.expr import AnonymousExpr, RollingExpr from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow + from narwhals._plan.typing import Seq, Udf from narwhals.dtypes import DType from narwhals.typing import FillNullStrategy diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 9008ec1979..1193993987 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -6,7 +6,7 @@ from narwhals._plan.common import Immutable if TYPE_CHECKING: - from narwhals._plan.common import Seq + from narwhals._plan.typing import Seq from narwhals.typing import RankMethod diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index dee99325ee..107dd188e5 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -11,12 +11,14 @@ from narwhals._compliant.typing import CompliantExprAny from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function, IRNamespace, SelectorIR - from narwhals._plan.dummy import DummySeries + from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.functions import RollingWindow from narwhals.typing import NonNestedLiteral __all__ = [ "FunctionT", + "IntoExpr", + "IntoExprColumn", "LeftSelectorT", "LeftT", "LiteralT", @@ -29,6 +31,8 @@ "RollingT", "SelectorOperatorT", "SelectorT", + "Seq", + "Udf", ] @@ -63,3 +67,18 @@ ExprT = TypeVar("ExprT", bound="Expr") Ns: TypeAlias = "Namespace[t.Any, ExprT]" """A `CompliantNamespace`, ignoring the `Frame` type.""" + + +T = TypeVar("T") + +Seq: TypeAlias = "tuple[T,...]" +"""Immutable Sequence. + +Using instead of `Sequence`, as a `list` can be passed there (can't break immutability promise). 
+""" + +Udf: TypeAlias = "t.Callable[[t.Any], t.Any]" +"""Placeholder for `map_batches(function=...)`.""" + +IntoExprColumn: TypeAlias = "DummyExpr | DummySeries | str" +IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 40eaae9971..9fe4b87687 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -7,8 +7,9 @@ from narwhals._plan.expr_parsing import parse_into_expr_ir if TYPE_CHECKING: - from narwhals._plan.common import ExprIR, IntoExpr, Seq + from narwhals._plan.common import ExprIR from narwhals._plan.expr import Ternary + from narwhals._plan.typing import IntoExpr, Seq class When(Immutable): diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index c2484a544a..3d16caa96b 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -10,9 +10,10 @@ ) if TYPE_CHECKING: - from narwhals._plan.common import ExprIR, Seq + from narwhals._plan.common import ExprIR from narwhals._plan.expr import OrderedWindowExpr, WindowExpr from narwhals._plan.options import SortOptions + from narwhals._plan.typing import Seq from narwhals.exceptions import InvalidOperationError diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 8d8d9966ea..7364226c0e 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -19,9 +19,9 @@ from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: - from narwhals._plan.common import ExprIR, IntoExpr + from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyExpr, DummySelector - from narwhals._plan.typing import MapIR + from narwhals._plan.typing import IntoExpr, MapIR from narwhals.dtypes import DType diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 6264a07bb0..8f62a0d121 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -13,7 +13,7 @@ boolean, functions as F, # noqa: N812 ) -from narwhals._plan.common import ExprIR, Function, IntoExprColumn +from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import BinaryExpr, FunctionExpr from narwhals.exceptions import ( @@ -30,7 +30,7 @@ from typing_extensions import TypeAlias - from narwhals._plan.common import IntoExpr, Seq + from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq IntoIterable: TypeAlias = Callable[[Sequence[Any]], Iterable[Any]] From 13caf8d2641b2616e6fb1d686669fdf6a304ce12 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 15:08:48 +0100 Subject: [PATCH 213/368] feat: Utilize `IntoDType` Just added in #2654 --- narwhals/_plan/common.py | 12 +++++++++--- narwhals/_plan/demo.py | 10 ++++++---- narwhals/_plan/dummy.py | 5 +++-- narwhals/_plan/expr_parsing.py | 4 ++-- narwhals/_plan/functions.py | 7 +++---- tests/plan/expr_parsing_test.py | 2 +- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 89de921662..c00be99779 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar from narwhals._plan.typing import ExprT, IRNamespaceT, MapIR, Ns, Seq +from narwhals.dtypes import DType from narwhals.utils import Version if TYPE_CHECKING: @@ -16,8 +17,7 @@ from narwhals._plan.expr import FunctionExpr from narwhals._plan.meta import 
IRMetaNamespace from narwhals._plan.options import FunctionOptions - from narwhals.dtypes import DType - from narwhals.typing import NonNestedLiteral + from narwhals.typing import IntoDType, NonNestedDType, NonNestedLiteral else: # NOTE: This isn't important to the proposal, just wanted IDE support @@ -367,7 +367,7 @@ def is_horizontal_reduction(obj: FunctionExpr[Any] | Any) -> TypeIs[FunctionExpr def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes - mapping: dict[type[NonNestedLiteral], type[DType]] = { + mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { int: dtypes.Int64, float: dtypes.Float64, str: dtypes.String, @@ -383,6 +383,12 @@ def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) return mapping.get(type(obj), dtypes.Unknown)() +def into_dtype(dtype: IntoDType, /) -> DType: + if isinstance(dtype, type) and issubclass(dtype, DType): + return dtype() + return dtype + + def collect(iterable: Seq[T] | Iterable[T], /) -> Seq[T]: """Collect `iterable` into a `tuple`, *iff* it is not one already.""" return iterable if isinstance(iterable, tuple) else tuple(iterable) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index a5e28c67f0..14876fa666 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -11,6 +11,7 @@ ) from narwhals._plan.common import ( ExprIR, + into_dtype, is_expr, is_non_nested_literal, py_to_narwhals_dtype, @@ -21,7 +22,6 @@ from narwhals._plan.strings import ConcatHorizontal from narwhals._plan.when_then import When from narwhals._utils import Version, flatten -from narwhals.dtypes import DType from narwhals.exceptions import OrderDependentExprError if t.TYPE_CHECKING: @@ -30,7 +30,7 @@ from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import SortBy from narwhals._plan.typing import IntoExpr - from narwhals.typing import NonNestedLiteral + from narwhals.typing import IntoDType, NonNestedLiteral def col(*names: str | t.Iterable[str]) -> DummyExpr: @@ -54,15 +54,17 @@ def nth(*indices: int | t.Sequence[int]) -> DummyExpr: def lit( - value: NonNestedLiteral | DummySeries, dtype: DType | type[DType] | None = None + value: NonNestedLiteral | DummySeries, dtype: IntoDType | None = None ) -> DummyExpr: if isinstance(value, DummySeries): return SeriesLiteral(value=value).to_literal().to_narwhals() if not is_non_nested_literal(value): msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." 
raise TypeError(msg) - if dtype is None or not isinstance(dtype, DType): + if dtype is None: dtype = py_to_narwhals_dtype(value, Version.MAIN) + else: + dtype = into_dtype(dtype) return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 97b6671560..bffe012fcf 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -44,6 +44,7 @@ from narwhals.typing import ( ClosedInterval, FillNullStrategy, + IntoDType, NativeSeries, NumericLiteral, RankMethod, @@ -367,7 +368,7 @@ def replace_strict( old: t.Sequence[t.Any] | t.Mapping[t.Any, t.Any], new: t.Sequence[t.Any] | None = None, *, - return_dtype: DType | type[DType] | None = None, + return_dtype: IntoDType | None = None, ) -> Self: before: Seq[t.Any] after: Seq[t.Any] @@ -395,7 +396,7 @@ def gather_every(self, n: int, offset: int = 0) -> Self: def map_batches( self, function: Udf, - return_dtype: DType | None = None, + return_dtype: IntoDType | None = None, *, is_elementwise: bool = False, returns_scalar: bool = False, diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index 1e8b1139bf..0c31a6e157 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -20,7 +20,7 @@ from narwhals._plan.common import ExprIR from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq - from narwhals.dtypes import DType + from narwhals.typing import IntoDType T = TypeVar("T") @@ -82,7 +82,7 @@ def parse_into_expr_ir( - input: IntoExpr, *, str_as_lit: bool = False, dtype: DType | None = None + input: IntoExpr, *, str_as_lit: bool = False, dtype: IntoDType | None = None ) -> ExprIR: """Parse a single input into an `ExprIR` node.""" from narwhals._plan import demo as nwd diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 4c5d7cd910..923f02bbff 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -17,8 +17,7 @@ from narwhals._plan.expr import AnonymousExpr, RollingExpr from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow from narwhals._plan.typing import Seq, Udf - from narwhals.dtypes import DType - from narwhals.typing import FillNullStrategy + from narwhals.typing import FillNullStrategy, IntoDType class Abs(Function): @@ -372,7 +371,7 @@ class ReplaceStrict(Function): old: Seq[Any] new: Seq[Any] - return_dtype: DType | type[DType] | None + return_dtype: IntoDType | None @property def function_options(self) -> FunctionOptions: @@ -400,7 +399,7 @@ class MapBatches(Function): __slots__ = ("function", "is_elementwise", "return_dtype", "returns_scalar") function: Udf - return_dtype: DType | None + return_dtype: IntoDType | None is_elementwise: bool returns_scalar: bool diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 8f62a0d121..f8cec7232d 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -77,7 +77,7 @@ def test_parsing( (["a", "b", "c"]), (nwd.col("d", "e", "f"), nwd.col("g"), "q", nwd.nth(9)), ((nwd.lit(1),)), - ([nwd.lit(1), nwd.lit(2), nwd.lit(3)]), + ([nwd.lit(1), nwd.lit(2, nw.Int64), nwd.lit(3, nw.Int64())]), ], ) def test_function_expr_horizontal( From 8c86feaf75536932186869a23a471b2fb881af2b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 15:31:57 +0100 Subject: [PATCH 214/368] perf: Add a two-level cache for selectors expansion --- narwhals/_plan/expr_expansion.py | 21 
+++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py
index f0d28acc43..f22fe91d9d 100644
--- a/narwhals/_plan/expr_expansion.py
+++ b/narwhals/_plan/expr_expansion.py
@@ -349,13 +349,30 @@ def replace_selector(
 
     def fn(child: ExprIR, /) -> ExprIR:
         if isinstance(child, SelectorIR):
-            cols = (k for k, v in schema.items() if child.matches_column(k, v))
-            return Columns(names=tuple(cols))
+            return expand_selector(child, schema=schema)
         return child
 
     return ir.map_ir(fn)
 
 
+@lru_cache(maxsize=100)
+def selector_matches_column(selector: SelectorIR, name: str, dtype: DType, /) -> bool:
+    """Cached version of `SelectorIR.matches_column`.
+
+    Allows evaluation results to be shared across:
+    - Instances of `SelectorIR`
+    - Multiple schemas
+    """
+    return selector.matches_column(name, dtype)
+
+
+@lru_cache(maxsize=100)
+def expand_selector(selector: SelectorIR, *, schema: FrozenSchema) -> Columns:
+    """Expand `selector` into `Columns`, within the context of `schema`."""
+    cols = (k for k, v in schema.items() if selector_matches_column(selector, k, v))
+    return Columns(names=tuple(cols))
+
+
 def rewrite_projections(
     input: Seq[ExprIR],  # `FunctionExpr.input`
     /,

From 6414142765b8645f57888db71c7773145162e034 Mon Sep 17 00:00:00 2001
From: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
Date: Thu, 12 Jun 2025 15:48:39 +0100
Subject: [PATCH 215/368] refactor: Replace 3x `replace_*` functions with 1

I imagine the `rust` impl benefits from individual functions, but without a
compiler we won't see the same benefit here.

---
 narwhals/_plan/expr_expansion.py | 41 ++++++++------------------------
 1 file changed, 10 insertions(+), 31 deletions(-)

diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py
index f22fe91d9d..23e4756121 100644
--- a/narwhals/_plan/expr_expansion.py
+++ b/narwhals/_plan/expr_expansion.py
@@ -298,38 +299,16 @@ def fn(child: ExprIR, /) -> ExprIR:
     return origin.map_ir(fn)
 
 
-def _replace_columns_exclude(origin: ExprIR, /, name: str) -> ExprIR:
-    """Based on the anonymous function in [`polars_plan::plans::conversion::expr_expansion::expand_columns`].
+def replace_with_column(
+    origin: ExprIR, tp: type[_ColumnSelection], /, name: str
+) -> ExprIR:
+    """Expand a single column within a multi-selection using `name`.
 
-    [`polars_plan::plans::conversion::expr_expansion::expand_columns`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L187-L191
+
+    For `Columns`, `IndexColumns`, `All`.
""" def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, Columns): - return col(name) - if isinstance(child, Exclude): - return child.expr - return child - - return origin.map_ir(fn) - - -def replace_index_with_column(origin: ExprIR, /, name: str) -> ExprIR: - def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, IndexColumns): - return col(name) - if isinstance(child, Exclude): - return child.expr - return child - - return origin.map_ir(fn) - - -def replace_wildcard_with_column(origin: ExprIR, /, name: str) -> ExprIR: - """`expr.All` and `Exclude`.""" - - def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, All): + if isinstance(child, tp): return col(name) if isinstance(child, Exclude): return child.expr @@ -466,7 +445,7 @@ def expand_columns( raise ComputeError(msg) for name in columns.names: if name not in exclude: - expanded = _replace_columns_exclude(origin, name) + expanded = replace_with_column(origin, Columns, name) expanded = rewrite_special_aliases(expanded) result.append(expanded) return result @@ -490,7 +469,7 @@ def expand_indices( raise ComputeError(msg) name = names[idx] if name not in exclude: - expanded = replace_index_with_column(origin, name) + expanded = replace_with_column(origin, IndexColumns, name) expanded = rewrite_special_aliases(expanded) result.append(expanded) return result @@ -501,7 +480,7 @@ def replace_wildcard( ) -> ResultIRs: for name in col_names: if name not in exclude: - expanded = replace_wildcard_with_column(origin, name) + expanded = replace_with_column(origin, All, name) expanded = rewrite_special_aliases(expanded) result.append(expanded) return result From 0d0d6a2e2272a92b00305b5031b1605904036aba Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 17:05:30 +0100 Subject: [PATCH 216/368] feat: Support `*args, **kwds` in `when` --- narwhals/_plan/demo.py | 16 ++++++++-------- narwhals/_plan/when_then.py | 29 ++++++++++++++++++++--------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 14876fa666..0bc829c19b 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -12,7 +12,6 @@ from narwhals._plan.common import ( ExprIR, into_dtype, - is_expr, is_non_nested_literal, py_to_narwhals_dtype, ) @@ -29,7 +28,7 @@ from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import SortBy - from narwhals._plan.typing import IntoExpr + from narwhals._plan.typing import IntoExpr, IntoExprColumn from narwhals.typing import IntoDType, NonNestedLiteral @@ -144,7 +143,9 @@ def concat_str( ) -def when(*predicates: IntoExpr | t.Iterable[IntoExpr]) -> When: +def when( + *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], **constraints: t.Any +) -> When: """Start a `when-then-otherwise` expression. 
Examples: @@ -167,11 +168,10 @@ def when(*predicates: IntoExpr | t.Iterable[IntoExpr]) -> When: Narwhals DummyExpr (main): .when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null)) """ - if builtins.len(predicates) == 1 and is_expr(predicates[0]): - expr = predicates[0] - else: - expr = all_horizontal(*predicates) - return When._from_expr(expr) + condition = parse.parse_predicates_constraints_into_expr_ir( + *predicates, **constraints + ) + return When._from_ir(condition) def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]: diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 9fe4b87687..f236572c55 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -1,15 +1,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable from narwhals._plan.common import Immutable, is_expr from narwhals._plan.dummy import DummyExpr -from narwhals._plan.expr_parsing import parse_into_expr_ir +from narwhals._plan.expr_parsing import ( + parse_into_expr_ir, + parse_predicates_constraints_into_expr_ir, +) if TYPE_CHECKING: from narwhals._plan.common import ExprIR from narwhals._plan.expr import Ternary - from narwhals._plan.typing import IntoExpr, Seq + from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq class When(Immutable): @@ -24,6 +27,10 @@ def then(self, expr: IntoExpr, /) -> Then: def _from_expr(expr: DummyExpr, /) -> When: return When(condition=expr._ir) + @staticmethod + def _from_ir(ir: ExprIR, /) -> When: + return When(condition=ir) + class Then(Immutable, DummyExpr): __slots__ = ("condition", "statement") @@ -31,10 +38,12 @@ class Then(Immutable, DummyExpr): condition: ExprIR statement: ExprIR - def when(self, condition: IntoExpr, /) -> ChainedWhen: + def when( + self, *predicates: IntoExprColumn | Iterable[IntoExprColumn], **constraints: Any + ) -> ChainedWhen: + condition = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return ChainedWhen( - conditions=(self.condition, parse_into_expr_ir(condition)), - statements=(self.statement,), + conditions=(self.condition, condition), statements=(self.statement,) ) def otherwise(self, statement: IntoExpr, /) -> DummyExpr: @@ -78,10 +87,12 @@ class ChainedThen(Immutable, DummyExpr): conditions: Seq[ExprIR] statements: Seq[ExprIR] - def when(self, condition: IntoExpr, /) -> ChainedWhen: + def when( + self, *predicates: IntoExprColumn | Iterable[IntoExprColumn], **constraints: Any + ) -> ChainedWhen: + condition = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return ChainedWhen( - conditions=(*self.conditions, parse_into_expr_ir(condition)), - statements=self.statements, + conditions=(*self.conditions, condition), statements=self.statements ) def otherwise(self, statement: IntoExpr, /) -> DummyExpr: From 534cf16ac9728c137ac7e6f78c50a6f61f5675b8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 17:31:08 +0100 Subject: [PATCH 217/368] feat: Add `expr`, `sqrt`, `kurtosis` Keeping things in sync with (#2556) --- narwhals/_plan/dummy.py | 11 +++++++++++ narwhals/_plan/functions.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index bffe012fcf..e0532e42be 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -209,6 +209,17 @@ def hist( def log(self, base: float = math.e) -> Self: return 
self._from_ir(F.Log(base=base).to_function_expr(self._ir)) + def exp(self) -> Self: + return self._from_ir(F.Exp().to_function_expr(self._ir)) + + def sqrt(self) -> Self: + return self._from_ir(F.Sqrt().to_function_expr(self._ir)) + + def kurtosis(self, *, fisher: bool = True, bias: bool = True) -> Self: + return self._from_ir( + F.Kurtosis(fisher=fisher, bias=bias).to_function_expr(self._ir) + ) + def null_count(self) -> Self: return self._from_ir(F.NullCount().to_function_expr(self._ir)) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 923f02bbff..1eeb8e6f50 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -92,6 +92,15 @@ def __repr__(self) -> str: return "log" +class Exp(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() + + def __repr__(self) -> str: + return "exp" + + class Pow(Function): @property def function_options(self) -> FunctionOptions: @@ -101,6 +110,29 @@ def __repr__(self) -> str: return "pow" +class Sqrt(Function): + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.elementwise() + + def __repr__(self) -> str: + return "sqrt" + + +class Kurtosis(Function): + __slots__ = ("bias", "fisher") + + fisher: bool + bias: bool + + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.aggregation() + + def __repr__(self) -> str: + return "kurtosis" + + class FillNull(Function): @property def function_options(self) -> FunctionOptions: From c9cb5965f47580ba44bfa2fefeffde4a5f35bbe2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 19:38:56 +0100 Subject: [PATCH 218/368] feat: Ensure mutability stays within function boundaries --- narwhals/_plan/expr_expansion.py | 51 +++++++++++++------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 23e4756121..a25cb0713e 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -86,7 +86,6 @@ Excluded: TypeAlias = "frozenset[str]" """Internally use a `set`, then freeze before returning.""" -ResultIRs: TypeAlias = "deque[ExprIR]" _FrozenSchemaHash: TypeAlias = "Seq[tuple[str, DType]]" _T2 = TypeVar("_T2") @@ -368,21 +367,16 @@ def rewrite_projections( if flags.has_selector: expanded = replace_selector(expanded, keys, schema=schema) flags = flags.with_multiple_columns() - result = replace_and_add_to_results( - expanded, result, keys=keys, schema=schema, flags=flags + result.extend( + replace_and_add_to_results(expanded, keys=keys, schema=schema, flags=flags) ) return tuple(result) def replace_and_add_to_results( - origin: ExprIR, - /, - result: ResultIRs, - keys: Seq[ExprIR], - *, - schema: FrozenSchema, - flags: ExpansionFlags, -) -> ResultIRs: + origin: ExprIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema, flags: ExpansionFlags +) -> Seq[ExprIR]: + result: deque[ExprIR] = deque() if flags.has_nth: origin = replace_nth(origin, schema) if flags.expands: @@ -390,20 +384,20 @@ def replace_and_add_to_results( if e := next(it, None): if isinstance(e, Columns): exclude = prepare_excluded(origin, keys=(), has_exclude=flags.has_exclude) - result = expand_columns(origin, result, e, exclude=exclude) + result.extend(expand_columns(origin, e, exclude=exclude)) else: exclude = prepare_excluded( origin, keys=keys, has_exclude=flags.has_exclude ) - result = expand_indices(origin, result, e, 
schema=schema, exclude=exclude) + result.extend(expand_indices(origin, e, schema=schema, exclude=exclude)) elif flags.has_wildcard: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - result = replace_wildcard(origin, result, col_names=schema.names, exclude=exclude) + result.extend(replace_wildcard(origin, col_names=schema.names, exclude=exclude)) else: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) expanded = rewrite_special_aliases(origin) result.append(expanded) - return result + return tuple(result) def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]: @@ -438,28 +432,24 @@ def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool: def expand_columns( - origin: ExprIR, /, result: ResultIRs, columns: Columns, *, exclude: Excluded -) -> ResultIRs: + origin: ExprIR, /, columns: Columns, *, exclude: Excluded +) -> Seq[ExprIR]: if not _all_columns_match(origin, columns): msg = "expanding more than one `col` is not allowed" raise ComputeError(msg) + result: deque[ExprIR] = deque() for name in columns.names: if name not in exclude: expanded = replace_with_column(origin, Columns, name) expanded = rewrite_special_aliases(expanded) result.append(expanded) - return result + return tuple(result) def expand_indices( - origin: ExprIR, - /, - result: ResultIRs, - indices: IndexColumns, - *, - schema: FrozenSchema, - exclude: Excluded, -) -> ResultIRs: + origin: ExprIR, /, indices: IndexColumns, *, schema: FrozenSchema, exclude: Excluded +) -> Seq[ExprIR]: + result: deque[ExprIR] = deque() n_fields = len(schema) names = tuple(schema) for index in indices.indices: @@ -472,18 +462,19 @@ def expand_indices( expanded = replace_with_column(origin, IndexColumns, name) expanded = rewrite_special_aliases(expanded) result.append(expanded) - return result + return tuple(result) def replace_wildcard( - origin: ExprIR, /, result: ResultIRs, *, col_names: FrozenColumns, exclude: Excluded -) -> ResultIRs: + origin: ExprIR, /, *, col_names: FrozenColumns, exclude: Excluded +) -> Seq[ExprIR]: + result: deque[ExprIR] = deque() for name in col_names: if name not in exclude: expanded = replace_with_column(origin, All, name) expanded = rewrite_special_aliases(expanded) result.append(expanded) - return result + return tuple(result) def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: From 03af47edd7472abcd8a3ffbe202f1c7a2a27f9a9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 12 Jun 2025 21:56:35 +0100 Subject: [PATCH 219/368] feat: more consistent index error --- narwhals/_plan/exceptions.py | 10 ++++++++++ narwhals/_plan/expr_expansion.py | 26 +++++++++++++++++++------- tests/plan/expr_expansion_test.py | 12 ++++++++++++ 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 0e4dbe0b4b..91436f3358 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -198,3 +198,13 @@ def column_not_found_error( available = tuple(available) missing = set(subset).difference(available) return ColumnNotFoundError.from_missing_and_available_column_names(missing, available) + + +def column_index_error( + index: int, schema_or_column_names: Iterable[str], / +) -> ComputeError: + # NOTE: If the original expression used a negative index, we should use that as well + n_names = len(tuple(schema_or_column_names)) + max_nth = f"`nth({n_names - 1})`" if index >= 0 else f"`nth(-{n_names})`" + msg = f"Invalid column 
index {index!r}\nHint: The schema's last column is {max_nth}" + return ComputeError(msg) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index a25cb0713e..b30525fc30 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -51,7 +51,11 @@ SelectorIR, is_horizontal_reduction, ) -from narwhals._plan.exceptions import column_not_found_error, duplicate_error +from narwhals._plan.exceptions import ( + column_index_error, + column_not_found_error, + duplicate_error, +) from narwhals._plan.expr import ( Alias, All, @@ -281,14 +285,24 @@ def fn(child: ExprIR, /) -> ExprIR: def replace_nth(origin: ExprIR, /, schema: FrozenSchema) -> ExprIR: + n_fields = len(schema) + names = tuple(schema) + def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, Nth): - return col(schema.names[child.index]) + if not is_index_in_range(child.index, n_fields): + raise column_index_error(child.index, names) + return col(names[child.index]) return child return origin.map_ir(fn) +def is_index_in_range(index: int, n_fields: int) -> bool: + idx = index + n_fields if index < 0 else index + return not (idx < 0 or idx >= n_fields) + + def remove_exclude(origin: ExprIR, /) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, Exclude): @@ -453,11 +467,9 @@ def expand_indices( n_fields = len(schema) names = tuple(schema) for index in indices.indices: - idx = index + n_fields if index < 0 else index - if idx < 0 or idx > n_fields: - msg = f"invalid column index {idx!r}" - raise ComputeError(msg) - name = names[idx] + if not is_index_in_range(index, n_fields): + raise column_index_error(index, names) + name = names[index] if name not in exclude: expanded = replace_with_column(origin, IndexColumns, name) expanded = rewrite_special_aliases(expanded) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 7364226c0e..2114c90cd3 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -549,3 +549,15 @@ def test_prepare_projection_horizontal_alias( out_irs, _ = prepare_projection(irs, schema_1) assert len(out_irs) == 1 assert out_irs[0] == function("a", "b", "c").alias("alias(x1)").alias("alias(x2)")._ir + + +@pytest.mark.parametrize( + "into_exprs", [nwd.nth(-21), nwd.nth(-1, 2, 54, 0), nwd.nth(20), nwd.nth([-10, -100])] +) +def test_prepare_projection_index_error( + into_exprs: IntoExpr | Iterable[IntoExpr], schema_1: dict[str, DType] +) -> None: + irs = parse_into_seq_of_expr_ir(into_exprs) + pattern = re.compile(r"invalid.+index.+nth", re.DOTALL | re.IGNORECASE) + with pytest.raises(ComputeError, match=pattern): + prepare_projection(irs, schema_1) From 708d6ac1d719d5c5437572568959832ec34b83d8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 13 Jun 2025 19:17:49 +0100 Subject: [PATCH 220/368] refactor: Reduce schema to columns where possible Resolves: - https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2145590008 - https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2145593529 --- narwhals/_plan/expr_expansion.py | 42 ++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index b30525fc30..7b174dec52 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -284,15 +284,14 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) -def replace_nth(origin: ExprIR, /, 
schema: FrozenSchema) -> ExprIR: - n_fields = len(schema) - names = tuple(schema) +def replace_nth(origin: ExprIR, /, col_names: FrozenColumns) -> ExprIR: + n_fields = len(col_names) def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, Nth): if not is_index_in_range(child.index, n_fields): - raise column_index_error(child.index, names) - return col(names[child.index]) + raise column_index_error(child.index, col_names) + return col(col_names[child.index]) return child return origin.map_ir(fn) @@ -382,17 +381,24 @@ def rewrite_projections( expanded = replace_selector(expanded, keys, schema=schema) flags = flags.with_multiple_columns() result.extend( - replace_and_add_to_results(expanded, keys=keys, schema=schema, flags=flags) + replace_and_add_to_results( + expanded, keys=keys, col_names=schema.names, flags=flags + ) ) return tuple(result) def replace_and_add_to_results( - origin: ExprIR, /, keys: Seq[ExprIR], *, schema: FrozenSchema, flags: ExpansionFlags + origin: ExprIR, + /, + keys: Seq[ExprIR], + *, + col_names: FrozenColumns, + flags: ExpansionFlags, ) -> Seq[ExprIR]: result: deque[ExprIR] = deque() if flags.has_nth: - origin = replace_nth(origin, schema) + origin = replace_nth(origin, col_names) if flags.expands: it = (e for e in origin.iter_left() if isinstance(e, (Columns, IndexColumns))) if e := next(it, None): @@ -403,10 +409,12 @@ def replace_and_add_to_results( exclude = prepare_excluded( origin, keys=keys, has_exclude=flags.has_exclude ) - result.extend(expand_indices(origin, e, schema=schema, exclude=exclude)) + result.extend( + expand_indices(origin, e, col_names=col_names, exclude=exclude) + ) elif flags.has_wildcard: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - result.extend(replace_wildcard(origin, col_names=schema.names, exclude=exclude)) + result.extend(replace_wildcard(origin, col_names=col_names, exclude=exclude)) else: exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) expanded = rewrite_special_aliases(origin) @@ -461,15 +469,19 @@ def expand_columns( def expand_indices( - origin: ExprIR, /, indices: IndexColumns, *, schema: FrozenSchema, exclude: Excluded + origin: ExprIR, + /, + indices: IndexColumns, + *, + col_names: FrozenColumns, + exclude: Excluded, ) -> Seq[ExprIR]: result: deque[ExprIR] = deque() - n_fields = len(schema) - names = tuple(schema) + n_fields = len(col_names) for index in indices.indices: if not is_index_in_range(index, n_fields): - raise column_index_error(index, names) - name = names[index] + raise column_index_error(index, col_names) + name = col_names[index] if name not in exclude: expanded = replace_with_column(origin, IndexColumns, name) expanded = rewrite_special_aliases(expanded) From 7dbc3800bf42691384a4d21fc3fc861cc32ebab2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 13 Jun 2025 19:25:14 +0100 Subject: [PATCH 221/368] typo --- narwhals/_plan/strings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index bb675e0896..bcf2bb4aee 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -118,7 +118,7 @@ class StartsWith(StringFunction): prefix: str def __repr__(self) -> str: - return "str.startswith" + return "str.starts_with" class StripChars(StringFunction): From e7e17a70a9ad8e4d59c9976c2afc6d98a09ec5e7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 14 Jun 2025 20:42:08 
+0100 Subject: [PATCH 222/368] feat: Add `_repr_html_` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Had no idea this was all we needed to do for the fancy notebook repr 😄 --- narwhals/_plan/common.py | 3 +++ narwhals/_plan/dummy.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index c00be99779..db9b70e47a 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -240,6 +240,9 @@ def meta(self) -> IRMetaNamespace: return IRMetaNamespace(_ir=self) + def _repr_html_(self) -> str: + return self.__repr__() + class SelectorIR(ExprIR): def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index e0532e42be..cb3d25ff50 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -66,6 +66,9 @@ def __str__(self) -> str: """Use `print(self)` for formatting.""" return f"Narwhals DummyExpr ({self.version.name.lower()}):\n{self._ir!s}" + def _repr_html_(self) -> str: + return self._ir._repr_html_() + @classmethod def _from_ir(cls, ir: ExprIR, /) -> Self: obj = cls.__new__(cls) From 12d7c965b10c5ca3689bceaedf0ce33cd163af08 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 15 Jun 2025 10:17:00 +0100 Subject: [PATCH 223/368] chore(ruff): Update for `3.9` typing See #2679 --- narwhals/_plan/aggregation.py | 2 +- narwhals/_plan/common.py | 3 ++- narwhals/_plan/exceptions.py | 3 ++- narwhals/_plan/expr_expansion.py | 9 ++++++++- narwhals/_plan/expr_parsing.py | 7 +++++-- narwhals/_plan/meta.py | 2 +- narwhals/_plan/name.py | 2 +- narwhals/_plan/selectors.py | 6 ++++-- narwhals/_plan/when_then.py | 4 +++- tests/plan/expr_expansion_test.py | 4 +++- tests/plan/expr_parsing_test.py | 9 +++++---- 11 files changed, 35 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index e8013cc02b..8e1cfc780d 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -6,7 +6,7 @@ from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: - from typing import Iterator + from collections.abc import Iterator from typing_extensions import Self diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index db9b70e47a..191add3451 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -9,7 +9,8 @@ from narwhals.utils import Version if TYPE_CHECKING: - from typing import Any, Callable, Iterable, Iterator, Literal + from collections.abc import Iterable, Iterator + from typing import Any, Callable, Literal from typing_extensions import Never, Self, TypeIs, dataclass_transform diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 91436f3358..8a7a76b02e 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -18,7 +18,8 @@ ) if TYPE_CHECKING: - from typing import Any, Iterable + from collections.abc import Iterable + from typing import Any import pandas as pd import polars as pl diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 7b174dec52..50dc094d1a 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -77,7 +77,14 @@ ) if TYPE_CHECKING: - from typing import ItemsView, Iterator, KeysView, Mapping, Sequence, ValuesView + from collections.abc import ( + ItemsView, + Iterator, + KeysView, + Mapping, + Sequence, + ValuesView, + ) from 
typing_extensions import TypeAlias diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index 0c31a6e157..fc480a0f66 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -1,8 +1,10 @@ from __future__ import annotations +from collections.abc import Iterable, Sequence + # ruff: noqa: A002 from itertools import chain -from typing import TYPE_CHECKING, Iterable, Sequence, TypeVar +from typing import TYPE_CHECKING, TypeVar from narwhals._plan.common import is_expr, is_iterable_reject from narwhals._plan.exceptions import ( @@ -13,7 +15,8 @@ from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series if TYPE_CHECKING: - from typing import Any, Iterator + from collections.abc import Iterator + from typing import Any import polars as pl from typing_extensions import TypeAlias, TypeIs diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index fd14e1d793..2399ecbaef 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -13,7 +13,7 @@ from narwhals.utils import Version if TYPE_CHECKING: - from typing import Iterator + from collections.abc import Iterator from typing_extensions import TypeIs diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 2e8d6a2233..a393b4b63b 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -5,7 +5,7 @@ from narwhals._plan.common import ExprIR, ExprNamespace, Immutable, IRNamespace if TYPE_CHECKING: - from typing import Iterator + from collections.abc import Iterator from typing_extensions import Self diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 0fa94bfec7..af3524180f 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -7,14 +7,16 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING from narwhals._plan.common import Immutable, is_iterable_reject from narwhals._utils import Version, _parse_time_unit_and_time_zone if TYPE_CHECKING: + from collections.abc import Iterator from datetime import timezone - from typing import Iterator, TypeVar + from typing import TypeVar from narwhals._plan.dummy import DummySelector from narwhals._plan.expr import RootSelector diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index f236572c55..ab9ae66933 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any from narwhals._plan.common import Immutable, is_expr from narwhals._plan.dummy import DummyExpr @@ -10,6 +10,8 @@ ) if TYPE_CHECKING: + from collections.abc import Iterable + from narwhals._plan.common import ExprIR from narwhals._plan.expr import Ternary from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 2114c90cd3..270181ea55 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Callable import pytest @@ -19,6 +19,8 @@ from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyExpr, 
DummySelector from narwhals._plan.typing import IntoExpr, MapIR diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index f8cec7232d..61e3c4d926 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -2,8 +2,9 @@ import re from collections import deque +from collections.abc import Iterable, Sequence from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable import pytest @@ -26,7 +27,7 @@ from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: - from typing import ContextManager + from contextlib import AbstractContextManager from typing_extensions import TypeAlias @@ -299,7 +300,7 @@ def test_is_in_series() -> None: ), ], ) -def test_invalid_is_in(other: Any, context: ContextManager[Any]) -> None: +def test_invalid_is_in(other: Any, context: AbstractContextManager[Any]) -> None: with context: nwd.col("a").is_in(other) @@ -353,7 +354,7 @@ def test_filter_full_spellings() -> None: def test_filter_partial_spellings( predicates: Iterable[IntoExprColumn], constraints: dict[str, Any], - context: ContextManager[Any], + context: AbstractContextManager[Any], ) -> None: with context: assert nwd.col("a").filter(*predicates, **constraints) From 1de65d2f82ace95a9bc72667067ffdfa9d28be6d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 18 Jun 2025 13:41:06 +0100 Subject: [PATCH 224/368] fix: More consistent `__str__` Noticed some weird bits when I wrote https://discord.com/channels/1235257048170762310/1383078215303696544/1384869107203047588 All of this is easy to avoid by ignoring the `ruff` lint that introduced the need for fixing manually https://docs.astral.sh/ruff/rules/unsorted-dunder-slots/ --- narwhals/_plan/expr.py | 33 ++++++--------------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 2cce4f5d8b..82fbed478d 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -10,7 +10,6 @@ from narwhals._plan.common import ( ExprIR, SelectorIR, - _field_str, collect, is_non_nested_literal, is_regex_projection, @@ -305,7 +304,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: class Cast(ExprIR): - __slots__ = ("dtype", "expr") + __slots__ = ("expr", "dtype") # noqa: RUF023 expr: ExprIR dtype: DType @@ -364,7 +363,7 @@ def with_expr(self, expr: ExprIR, /) -> Self: class SortBy(ExprIR): """https://github.com/narwhals-dev/narwhals/issues/2534.""" - __slots__ = ("by", "expr", "options") + __slots__ = ("expr", "by", "options") # noqa: RUF023 expr: ExprIR by: Seq[ExprIR] @@ -486,7 +485,7 @@ class AnonymousExpr(FunctionExpr["MapBatches"]): class Filter(ExprIR): - __slots__ = ("by", "expr") + __slots__ = ("expr", "by") # noqa: RUF023 expr: ExprIR by: ExprIR @@ -528,7 +527,7 @@ class WindowExpr(ExprIR): - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L840-L876 """ - __slots__ = ("expr", "options", "partition_by") + __slots__ = ("expr", "partition_by", "options") # noqa: RUF023 expr: ExprIR """Renamed from `function`. 
@@ -548,12 +547,6 @@ class WindowExpr(ExprIR): def __repr__(self) -> str: return f"{self.expr!r}.over({list(self.partition_by)!r})" - def __str__(self) -> str: - args = ( - f"expr={self.expr}, partition_by={self.partition_by}, options={self.options}" - ) - return f"{type(self).__name__}({args})" - def iter_left(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_left() for e in self.partition_by: @@ -586,7 +579,7 @@ def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: # TODO @dangotbanned: Reduce repetition from `WindowExpr` class OrderedWindowExpr(WindowExpr): - __slots__ = ("expr", "options", "order_by", "partition_by", "sort_options") + __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 expr: ExprIR partition_by: Seq[ExprIR] @@ -606,11 +599,6 @@ def __repr__(self) -> str: args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" return f"{self.expr!r}.over({args})" - def __str__(self) -> str: - order_by = f"({self.order_by}, {self.sort_options})" - args = f"expr={self.expr}, partition_by={self.partition_by}, order_by={order_by}, options={self.options}" - return f"{type(self).__name__}({args})" - def iter_left(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_left() for e in self.partition_by: @@ -750,7 +738,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: class Ternary(ExprIR): """When-Then-Otherwise.""" - __slots__ = ("falsy", "predicate", "truthy") + __slots__ = ("predicate", "truthy", "falsy") # noqa: RUF023 predicate: ExprIR truthy: ExprIR @@ -760,15 +748,6 @@ class Ternary(ExprIR): def is_scalar(self) -> bool: return self.predicate.is_scalar and self.truthy.is_scalar and self.falsy.is_scalar - def __str__(self) -> str: - # NOTE: Default slot ordering made it difficult to read - fields = ( - _field_str("predicate", self.predicate), - _field_str("truthy", self.truthy), - _field_str("falsy", self.falsy), - ) - return f"{type(self).__name__}({', '.join(fields)})" - def __repr__(self) -> str: return ( f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" From 11f1e1bd3142012efa17d727a26da39a96a2ee86 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 22 Jun 2025 12:02:23 +0100 Subject: [PATCH 225/368] refactor: Move, document `GroupByKeys` - The way I've done selectors has meant `replace_selector(key)` is unused - That replacement has moved to `replace_and_add_to_results` on the `Columns` branch - So that part is now parameterized --- narwhals/_plan/expr_expansion.py | 31 ++++++++++++++++--------------- tests/plan/expr_expansion_test.py | 3 +-- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 50dc094d1a..733cfbe765 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -97,6 +97,13 @@ Excluded: TypeAlias = "frozenset[str]" """Internally use a `set`, then freeze before returning.""" +GroupByKeys: TypeAlias = "Seq[ExprIR]" +"""Represents group_by keys. 
+ +- Originates from `polars_plan::plans::conversion::dsl_to_ir::resolve_group_by` +- Not fully utilized in `narwhals` version yet +""" + _FrozenSchemaHash: TypeAlias = "Seq[tuple[str, DType]]" _T2 = TypeVar("_T2") @@ -336,13 +343,7 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) -def replace_selector( - ir: ExprIR, - /, - keys: Seq[ExprIR], # noqa: ARG001 - *, - schema: FrozenSchema, -) -> ExprIR: +def replace_selector(ir: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: """Fully diverging from `polars`, we'll see how that goes.""" def fn(child: ExprIR, /) -> ExprIR: @@ -374,10 +375,8 @@ def expand_selector(selector: SelectorIR, *, schema: FrozenSchema) -> Columns: def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, - keys: Seq[ - ExprIR - ], # NOTE: Mutable (empty) array initialized on call (except in `polars_plan::plans::conversion::dsl_to_ir::resolve_group_by`) - *, # NOTE: Represents group_by keys + keys: GroupByKeys, + *, schema: FrozenSchema, ) -> Seq[ExprIR]: result: deque[ExprIR] = deque() @@ -385,7 +384,7 @@ def rewrite_projections( expanded = expand_function_inputs(expr, schema=schema) flags = ExpansionFlags.from_ir(expanded) if flags.has_selector: - expanded = replace_selector(expanded, keys, schema=schema) + expanded = replace_selector(expanded, schema=schema) flags = flags.with_multiple_columns() result.extend( replace_and_add_to_results( @@ -398,7 +397,7 @@ def rewrite_projections( def replace_and_add_to_results( origin: ExprIR, /, - keys: Seq[ExprIR], + keys: GroupByKeys, *, col_names: FrozenColumns, flags: ExpansionFlags, @@ -410,7 +409,9 @@ def replace_and_add_to_results( it = (e for e in origin.iter_left() if isinstance(e, (Columns, IndexColumns))) if e := next(it, None): if isinstance(e, Columns): - exclude = prepare_excluded(origin, keys=(), has_exclude=flags.has_exclude) + exclude = prepare_excluded( + origin, keys=keys, has_exclude=flags.has_exclude + ) result.extend(expand_columns(origin, e, exclude=exclude)) else: exclude = prepare_excluded( @@ -437,7 +438,7 @@ def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]: def prepare_excluded( - origin: ExprIR, /, keys: Seq[ExprIR], *, has_exclude: bool + origin: ExprIR, /, keys: GroupByKeys, *, has_exclude: bool ) -> Excluded: """Huge simplification of [`polars_plan::plans::conversion::expr_expansion::prepare_excluded`]. 
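A minimal sketch of the simplified call shape introduced here (not taken from the patch itself; the two-column schema and the string selector are made up for illustration, and the imports assume `freeze_schema`, `replace_selector`, `parse_into_expr_ir` and the `selectors` module are reachable as they are in the test suite — the updated test below exercises the same signature):

    import narwhals as nw
    from narwhals._plan import expr_parsing as parse, selectors as ndcs
    from narwhals._plan.expr_expansion import freeze_schema, replace_selector

    # Freeze a plain mapping of dtypes into the hashable schema proxy.
    schema = freeze_schema(a=nw.Int64(), b=nw.String())
    # An illustrative selector-based IR; any IntoExpr parsed the usual way works.
    ir = parse.parse_into_expr_ir(ndcs.string())
    # group_by keys are no longer threaded through replace_selector;
    # the selector is resolved against the frozen schema only.
    rewritten = replace_selector(ir, schema=schema)
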
diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 270181ea55..dc476a8a77 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -252,8 +252,7 @@ def test_replace_selector( expected: DummyExpr | ExprIR, schema_1: dict[str, DType], ) -> None: - group_by_keys = () - actual = replace_selector(expr._ir, group_by_keys, schema=freeze_schema(**schema_1)) + actual = replace_selector(expr._ir, schema=freeze_schema(**schema_1)) assert_expr_ir_equal(actual, expected) From e17ab21338ee1f212f5279928dd2f38d1fd1024b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 22 Jun 2025 21:30:10 +0100 Subject: [PATCH 226/368] refactor: Add some `is_*_expr` guards Needing these more with the rewrites stuff --- narwhals/_plan/common.py | 16 +++++++++++++--- narwhals/_plan/operators.py | 6 +++--- narwhals/_plan/window.py | 8 +++----- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 191add3451..2f07a91c79 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -15,7 +15,7 @@ from typing_extensions import Never, Self, TypeIs, dataclass_transform from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries - from narwhals._plan.expr import FunctionExpr + from narwhals._plan.expr import FunctionExpr, WindowExpr from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions from narwhals.typing import IntoDType, NonNestedDType, NonNestedLiteral @@ -363,10 +363,20 @@ def is_regex_projection(name: str) -> bool: return name.startswith("^") and name.endswith("$") -def is_horizontal_reduction(obj: FunctionExpr[Any] | Any) -> TypeIs[FunctionExpr[Any]]: +def is_window_expr(obj: Any) -> TypeIs[WindowExpr]: + from narwhals._plan.expr import WindowExpr + + return isinstance(obj, WindowExpr) + + +def is_function_expr(obj: Any) -> TypeIs[FunctionExpr[Any]]: from narwhals._plan.expr import FunctionExpr - return isinstance(obj, FunctionExpr) and obj.options.is_input_wildcard_expansion() + return isinstance(obj, FunctionExpr) + + +def is_horizontal_reduction(obj: FunctionExpr[Any] | Any) -> TypeIs[FunctionExpr[Any]]: + return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index a702a60d85..7c22346900 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -3,13 +3,13 @@ import operator from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable +from narwhals._plan.common import Immutable, is_function_expr from narwhals._plan.exceptions import ( binary_expr_length_changing_error, binary_expr_multi_output_error, binary_expr_shape_error, ) -from narwhals._plan.expr import BinarySelector, FunctionExpr +from narwhals._plan.expr import BinarySelector if TYPE_CHECKING: from typing import Any, ClassVar @@ -76,7 +76,7 @@ def __call__(self, lhs: Any, rhs: Any) -> Any: def _is_filtration(ir: ExprIR) -> bool: - if not ir.is_scalar and isinstance(ir, FunctionExpr): + if not ir.is_scalar and is_function_expr(ir): return not ir.options.is_elementwise() return False diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 3d16caa96b..5c1eafa7a4 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING 
-from narwhals._plan.common import Immutable +from narwhals._plan.common import Immutable, is_function_expr, is_window_expr from narwhals._plan.exceptions import ( over_elementwise_error, over_nested_error, @@ -33,11 +33,9 @@ def _validate_over( sort_options: SortOptions | None = None, /, ) -> InvalidOperationError | None: - from narwhals._plan.expr import FunctionExpr, WindowExpr - - if isinstance(expr, WindowExpr): + if is_window_expr(expr): return over_nested_error(expr, partition_by, order_by, sort_options) - if isinstance(expr, FunctionExpr): + if is_function_expr(expr): if expr.options.is_elementwise(): return over_elementwise_error(expr, partition_by, order_by, sort_options) if expr.options.is_row_separable(): From 12ebe0ccb0a7ccc1b1b41d00051841c7c2e2e6c1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 22 Jun 2025 21:42:04 +0100 Subject: [PATCH 227/368] docs: lil note on `prepare_projection` --- narwhals/_plan/expr_expansion.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 733cfbe765..9a1062e0ed 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -252,6 +252,18 @@ def with_multiple_columns(self) -> ExpansionFlags: def prepare_projection( exprs: Sequence[ExprIR], schema: Mapping[str, DType] | FrozenSchema ) -> tuple[Seq[ExprIR], FrozenSchema]: + """Expand IRs into named column selections. + + **Primary entry-point**, will be used by `select`, `with_columns`, + and any other context that requires resolving expression names. + + Arguments: + exprs: IRs that *may* contain things like `Columns`, `SelectorIR`, `Exclude`, etc. + schema: Scope to expand multi-column selectors in. + + Returns: + `exprs`, rewritten using `Column(name)` only. + """ frozen_schema = ( schema if isinstance(schema, FrozenSchema) else freeze_schema(**schema) ) From 44f7602f27f0d32d068df8b8c327767352ddf639 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 22 Jun 2025 21:53:17 +0100 Subject: [PATCH 228/368] feat(DRAFT): Add `rewrite_elementwise_over` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Should do the job for what @MarcoGorelli wanted - Run out of time today, so pushing w/ many todos 😄 --- narwhals/_plan/expr_rewrites.py | 68 +++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 narwhals/_plan/expr_rewrites.py diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py new file mode 100644 index 0000000000..d8fa83d51c --- /dev/null +++ b/narwhals/_plan/expr_rewrites.py @@ -0,0 +1,68 @@ +"""Post-`expr_expansion` rewrites, in a similar style.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan.common import is_function_expr, is_window_expr +from narwhals._plan.expr_expansion import prepare_projection + +if TYPE_CHECKING: + from narwhals._plan.common import ExprIR + +select_context_ish = prepare_projection + + +# TODO @dangotbanned: Tests +# TODO @dangotbanned: Review if `inputs` is always `len(1)`` after `prepare_projection` +def rewrite_elementwise_over(origin: ExprIR, /) -> ExprIR: + """Requested in [discord-0]. 
+ + Before: + + nw.col("a").sum().abs().over("b") + + After: + + nw.col("a").sum().over("b").abs() + + [discord-0]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384807793512677398 + """ + + def fn(child: ExprIR, /) -> ExprIR: + if ( + is_window_expr(child) + and is_function_expr(child.expr) + and child.expr.options.is_elementwise() + ): + # NOTE: Aliasing isn't required, but it does help readability + window = child + func = child.expr + if len(func.input) != 1: + msg = ( + f"Expected function inputs to have been expanded, " + f"but got {len(func.input)!r} inputs at: {func}" + ) + raise NotImplementedError(msg) + return func.with_input([window.with_expr(func.input[0])]) + return child + + return origin.map_ir(fn) + + +# TODO @dangotbanned: Full implementation +def rewrite_binary_agg_over(origin: ExprIR, /) -> ExprIR: + """Requested in [discord-1], clarified in [discord-2]. + + Before: + + (nw.col("a") - nw.col("a").mean()).over("b") + + After: + + nw.col("a") - nw.col("a").mean().over("b") + + [discord-1]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384850753008435372 + [discord-2]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384869107203047588 + """ + raise NotImplementedError From 576afa9b39707137570cb373760357b6ac38b938 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 23 Jun 2025 13:54:17 +0100 Subject: [PATCH 229/368] feat: Add a basic rewrite composer --- narwhals/_plan/expr_rewrites.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index d8fa83d51c..518524a3db 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -4,13 +4,37 @@ from typing import TYPE_CHECKING +from narwhals._plan import expr_parsing as parse from narwhals._plan.common import is_function_expr, is_window_expr from narwhals._plan.expr_expansion import prepare_projection if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from narwhals._plan.common import ExprIR + from narwhals._plan.typing import IntoExpr, MapIR, Seq + from narwhals.dtypes import DType + + +def rewrite_all( + *exprs: IntoExpr, schema: Mapping[str, DType], rewrites: Sequence[MapIR] +) -> Seq[ExprIR]: + """Very naive approach, but should work for a demo. 
+ + - Assumes all of `rewrites` ends with a `ExprIR.map_ir` call + - Applying multiple functions should be happening at a lower level + - Currently we do a full traversal of each tree per-rewrite function + - There's no caching *after* `prepare_projection` yet + """ + out_irs, _ = prepare_projection(parse.parse_into_seq_of_expr_ir(*exprs), schema) + return tuple(_rewrite_sequential(ir, rewrites) for ir in out_irs) + -select_context_ish = prepare_projection +def _rewrite_sequential(origin: ExprIR, rewrites: Sequence[MapIR], /) -> ExprIR: + result = origin + for fn in rewrites: + result = fn(result) + return result # TODO @dangotbanned: Tests From 2ceadc43044e4e9e0fd6a3e414fe76ce0fed9b28 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 23 Jun 2025 13:55:15 +0100 Subject: [PATCH 230/368] test: `test_rewrite_elementwise_over_(simple|multiple)` --- tests/plan/expr_rewrites_test.py | 75 ++++++++++++++++++++++++++++++++ tests/plan/utils.py | 25 +++++++++-- 2 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 tests/plan/expr_rewrites_test.py diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py new file mode 100644 index 0000000000..6883664f53 --- /dev/null +++ b/tests/plan/expr_rewrites_test.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from narwhals._plan import demo as nwd, expr_parsing as parse +from narwhals._plan.expr import WindowExpr +from narwhals._plan.expr_rewrites import rewrite_all, rewrite_elementwise_over +from narwhals._plan.window import Over +from narwhals.exceptions import InvalidOperationError +from tests.plan.utils import assert_expr_ir_equal + +if TYPE_CHECKING: + from narwhals._plan.typing import IntoExpr + from narwhals.dtypes import DType + + +@pytest.fixture +def schema_2() -> dict[str, DType]: + return { + "a": nw.Int64(), + "b": nw.Int64(), + "c": nw.Int64(), + "d": nw.Int64(), + "e": nw.Int64(), + "f": nw.String(), + "g": nw.String(), + "h": nw.String(), + "i": nw.Boolean(), + "j": nw.Boolean(), + "k": nw.Boolean(), + } + + +def _to_window_expr(into_expr: IntoExpr, *partition_by: IntoExpr) -> WindowExpr: + return WindowExpr( + expr=parse.parse_into_expr_ir(into_expr), + partition_by=parse.parse_into_seq_of_expr_ir(*partition_by), + options=Over(), + ) + + +def test_rewrite_elementwise_over_simple(schema_2: dict[str, DType]) -> None: + with pytest.raises(InvalidOperationError, match=r"over.+elementwise"): + nwd.col("a").sum().abs().over("b") + + # NOTE: Since the requested "before" expression is currently an error (at definition time), + # we need to manually build the IR - to sidestep the validation in `Over.to_window_expr`. + # Later, that error might not be needed if we can do this rewrite. + # If you're here because of a "Did not raise" - just replace everything with the (previously) erroring expr. 
+ expected = nwd.col("a").sum().over("b").abs() + before = _to_window_expr(nwd.col("a").sum().abs(), "b").to_narwhals() + assert_expr_ir_equal(before, "col('a').sum().abs().over([col('b')])") + actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_elementwise_over]) + assert len(actual) == 1 + assert_expr_ir_equal(actual[0], expected) + + +def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: + expected = ( + nwd.col("b").last().over("d").replace_strict({1: 2}), + nwd.col("c").last().over("d").replace_strict({1: 2}), + ) + before = _to_window_expr( + nwd.col("b", "c").last().replace_strict({1: 2}), "d" + ).to_narwhals() + assert_expr_ir_equal( + before, "cols(['b', 'c']).last().replace_strict().over([col('d')])" + ) + actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_elementwise_over]) + assert len(actual) == 2 + for lhs, rhs in zip(actual, expected): + assert_expr_ir_equal(lhs, rhs) diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 641ccddb19..b20e4aba8a 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -5,11 +5,28 @@ from narwhals._plan.common import is_expr if TYPE_CHECKING: + from typing_extensions import LiteralString + from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyExpr -def assert_expr_ir_equal(left: DummyExpr | ExprIR, right: DummyExpr | ExprIR) -> None: - lhs = left._ir if is_expr(left) else left - rhs = right._ir if is_expr(right) else right - assert lhs == rhs +def assert_expr_ir_equal( + actual: DummyExpr | ExprIR, expected: DummyExpr | ExprIR | LiteralString, / +) -> None: + """Assert that `actual` is equivalent to `expected`. + + Arguments: + actual: Result expression or IR to compare. + expected: Target expression, IR, or repr to compare. + + Notes: + Performing a repr comparison is more fragile, so should be avoided + *unless* we raise an error at creation time. 
+ """ + lhs = actual._ir if is_expr(actual) else actual + if isinstance(expected, str): + assert repr(lhs) == expected + else: + rhs = expected._ir if is_expr(expected) else expected + assert lhs == rhs From 357d419daf2184fbbe2dd7aca1554f71099f598f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 23 Jun 2025 15:55:04 +0100 Subject: [PATCH 231/368] perf: Add safe caching to `meta.output_name` --- narwhals/_plan/meta.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 2399ecbaef..b9a85d25f5 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -6,6 +6,7 @@ from __future__ import annotations +from functools import lru_cache from typing import TYPE_CHECKING, Literal, overload from narwhals._plan.common import IRNamespace @@ -75,6 +76,8 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: ok_or_err = _expr_output_name(self._ir) if isinstance(ok_or_err, ComputeError): if raise_if_undetermined: + # NOTE: See (https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2161824883) + _expr_output_name.cache_clear() raise ok_or_err return None return ok_or_err @@ -128,6 +131,7 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: return ComputeError(msg) +@lru_cache(maxsize=32) def _expr_output_name(ir: ExprIR) -> str | ComputeError: from narwhals._plan import expr From 2b8aea5ed01eda3090469493fa8022fef346750c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 23 Jun 2025 16:29:15 +0100 Subject: [PATCH 232/368] prep for `NamedIR` - Realised the rewrites will be simpler if all `Alias` nodes are removed - Now, all the info we need is extracted - so that'll be safe to do next - See https://github.com/narwhals-dev/narwhals/blob/2ceadc43044e4e9e0fd6a3e414fe76ce0fed9b28/tests/plan/expr_expansion_test.py#L316-L319 --- narwhals/_plan/common.py | 19 +++++++++++++- narwhals/_plan/expr_expansion.py | 42 +++++++++++++++---------------- narwhals/_plan/expr_rewrites.py | 2 +- narwhals/_plan/typing.py | 1 + tests/plan/expr_expansion_test.py | 6 ++--- 5 files changed, 43 insertions(+), 27 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 2f07a91c79..f520e0ebce 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -4,7 +4,7 @@ from decimal import Decimal from typing import TYPE_CHECKING, Generic, TypeVar -from narwhals._plan.typing import ExprT, IRNamespaceT, MapIR, Ns, Seq +from narwhals._plan.typing import ExprIRT, ExprT, IRNamespaceT, MapIR, Ns, Seq from narwhals.dtypes import DType from narwhals.utils import Version @@ -263,6 +263,23 @@ def matches_column(self, name: str, dtype: DType) -> bool: raise NotImplementedError(type(self)) +class NamedIR(Immutable, Generic[ExprIRT]): + """Post-projection expansion wrapper for `ExprIR`. + + - Somewhat similar to [`polars_plan::plans::expr_ir::ExprIR`]. 
+ - The [`polars_plan::plans::aexpr::AExpr`] stage has been skipped (*for now*) + - Parts of that will probably be in here too + - `AExpr` seems like too much duplication when we won't get the memory allocation benefits in python + + [`polars_plan::plans::expr_ir::ExprIR`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/expr_ir.rs#L63-L74 + [`polars_plan::plans::aexpr::AExpr`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/mod.rs#L145-L231 + """ + + __slots__ = ("expr", "name") + expr: ExprIRT + name: str + + class IRNamespace(Immutable): __slots__ = ("_ir",) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 9a1062e0ed..6791f8f3b9 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -69,12 +69,7 @@ col, ) from narwhals.dtypes import DType -from narwhals.exceptions import ( - ColumnNotFoundError, - ComputeError, - DuplicateError, - InvalidOperationError, -) +from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: from collections.abc import ( @@ -104,6 +99,9 @@ - Not fully utilized in `narwhals` version yet """ +OutputNames: TypeAlias = "Seq[str]" +"""Fully expanded, validated output column names, for `NamedIR`s.""" + _FrozenSchemaHash: TypeAlias = "Seq[tuple[str, DType]]" _T2 = TypeVar("_T2") @@ -251,7 +249,7 @@ def with_multiple_columns(self) -> ExpansionFlags: def prepare_projection( exprs: Sequence[ExprIR], schema: Mapping[str, DType] | FrozenSchema -) -> tuple[Seq[ExprIR], FrozenSchema]: +) -> tuple[Seq[ExprIR], FrozenSchema, OutputNames]: """Expand IRs into named column selections. **Primary entry-point**, will be used by `select`, `with_columns`, @@ -268,28 +266,28 @@ def prepare_projection( schema if isinstance(schema, FrozenSchema) else freeze_schema(**schema) ) rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) - if err := ensure_valid_exprs(rewritten, frozen_schema): - raise err - return rewritten, frozen_schema + output_names = ensure_valid_exprs(rewritten, frozen_schema) + # TODO @dangotbanned: (Seq[ExprIR], OutputNames) -> (Seq[NamedIR]) + # TODO @dangotbanned: Return a new schema, with the changes (name only) from projecting exprs + # - `select` (subset from schema, maybe need root names as well?) 
+ # - `with_columns` https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1045-L1079 + return rewritten, frozen_schema, output_names -def ensure_valid_exprs( - exprs: Seq[ExprIR], schema: FrozenSchema -) -> ColumnNotFoundError | DuplicateError | None: - """Return an appropriate error if we can't materialize.""" - if err := _ensure_column_names_unique(exprs): - return err +def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames: + """Raise an appropriate error if we can't materialize.""" + output_names = _ensure_output_names_unique(exprs) root_names = _root_names_unique(exprs) if not (set(schema.names).issuperset(root_names)): - return column_not_found_error(root_names, schema) - return None + raise column_not_found_error(root_names, schema) + return output_names -def _ensure_column_names_unique(exprs: Seq[ExprIR]) -> DuplicateError | None: - names = [e.meta.output_name() for e in exprs] +def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: + names = tuple(e.meta.output_name() for e in exprs) if len(names) != len(set(names)): - return duplicate_error(exprs) - return None + raise duplicate_error(exprs) + return names def _root_names_unique(exprs: Seq[ExprIR]) -> set[str]: diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index 518524a3db..879d194dd7 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -26,7 +26,7 @@ def rewrite_all( - Currently we do a full traversal of each tree per-rewrite function - There's no caching *after* `prepare_projection` yet """ - out_irs, _ = prepare_projection(parse.parse_into_seq_of_expr_ir(*exprs), schema) + out_irs, _, _ = prepare_projection(parse.parse_into_seq_of_expr_ir(*exprs), schema) return tuple(_rewrite_sequential(ir, rewrites) for ir in out_irs) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 107dd188e5..280fc42b78 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -44,6 +44,7 @@ RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") RightT2 = TypeVar("RightT2", bound="ExprIR", default="ExprIR") OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" +ExprIRT = TypeVar("ExprIRT", bound="ExprIR", default="ExprIR") SelectorT = TypeVar("SelectorT", bound="SelectorIR", default="SelectorIR") LeftSelectorT = TypeVar("LeftSelectorT", bound="SelectorIR", default="SelectorIR") diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index dc476a8a77..fefbb93aa0 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -411,7 +411,7 @@ def test_prepare_projection( schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) - actual, _ = prepare_projection(irs_in, schema_1) + actual, _, _ = prepare_projection(irs_in, schema_1) assert len(actual) == len(expected) for lhs, rhs in zip(actual, expected): assert_expr_ir_equal(lhs, rhs) @@ -541,13 +541,13 @@ def test_prepare_projection_horizontal_alias( expr = function(into_exprs) alias_1 = expr.alias("alias(x1)") irs = parse_into_seq_of_expr_ir(alias_1) - out_irs, _ = prepare_projection(irs, schema_1) + out_irs, _, _ = prepare_projection(irs, schema_1) assert len(out_irs) == 1 assert out_irs[0] == function("a", "b", "c").alias("alias(x1)")._ir alias_2 = alias_1.alias("alias(x2)") irs = parse_into_seq_of_expr_ir(alias_2) - out_irs, _ = prepare_projection(irs, schema_1) + out_irs, _, _ = 
prepare_projection(irs, schema_1) assert len(out_irs) == 1 assert out_irs[0] == function("a", "b", "c").alias("alias(x1)").alias("alias(x2)")._ir From bcc071a78ca91c035e78fabe61c12312f6738b28 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 23 Jun 2025 17:27:43 +0100 Subject: [PATCH 233/368] refactor: Update to rewrite w/ `NamedIR` --- narwhals/_plan/common.py | 13 +++++++++++-- narwhals/_plan/expr_expansion.py | 20 ++++++++++++++++++++ narwhals/_plan/expr_rewrites.py | 27 +++++++++++++++++---------- narwhals/_plan/typing.py | 1 + tests/plan/utils.py | 18 ++++++++++++++---- 5 files changed, 63 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index f520e0ebce..8de566f549 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -2,9 +2,9 @@ import datetime as dt from decimal import Decimal -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar, cast -from narwhals._plan.typing import ExprIRT, ExprT, IRNamespaceT, MapIR, Ns, Seq +from narwhals._plan.typing import ExprIRT, ExprIRT2, ExprT, IRNamespaceT, MapIR, Ns, Seq from narwhals.dtypes import DType from narwhals.utils import Version @@ -279,6 +279,15 @@ class NamedIR(Immutable, Generic[ExprIRT]): expr: ExprIRT name: str + def map_ir(self, function: MapIR, /) -> NamedIR[ExprIR]: + """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" + return self.with_expr(self.expr.map_ir(function)) + + def with_expr(self, expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: + if expr == self.expr: + return cast("NamedIR[ExprIRT2]", self) + return NamedIR(expr=expr, name=self.name) + class IRNamespace(Immutable): __slots__ = ("_ir",) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 6791f8f3b9..070e14d5d1 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -48,6 +48,7 @@ _IMMUTABLE_HASH_NAME, ExprIR, Immutable, + NamedIR, SelectorIR, is_horizontal_reduction, ) @@ -268,12 +269,22 @@ def prepare_projection( rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) output_names = ensure_valid_exprs(rewritten, frozen_schema) # TODO @dangotbanned: (Seq[ExprIR], OutputNames) -> (Seq[NamedIR]) + # See `expr_rewrites.rewrite_all` # TODO @dangotbanned: Return a new schema, with the changes (name only) from projecting exprs # - `select` (subset from schema, maybe need root names as well?) 
# - `with_columns` https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1045-L1079 return rewritten, frozen_schema, output_names +def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]: + if len(exprs) != len(names): + msg = f"zip length mismatch: {len(exprs)} != {len(names)}" + raise ValueError(msg) + return tuple( + NamedIR(expr=remove_alias(ir), name=name) for ir, name in zip(exprs, names) + ) + + def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames: """Raise an appropriate error if we can't materialize.""" output_names = _ensure_output_names_unique(exprs) @@ -326,6 +337,15 @@ def is_index_in_range(index: int, n_fields: int) -> bool: return not (idx < 0 or idx >= n_fields) +def remove_alias(origin: ExprIR, /) -> ExprIR: + def fn(child: ExprIR, /) -> ExprIR: + if isinstance(child, Alias): + return child.expr + return child + + return origin.map_ir(fn) + + def remove_exclude(origin: ExprIR, /) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, Exclude): diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index 879d194dd7..bd0a5882ff 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -5,20 +5,22 @@ from typing import TYPE_CHECKING from narwhals._plan import expr_parsing as parse -from narwhals._plan.common import is_function_expr, is_window_expr -from narwhals._plan.expr_expansion import prepare_projection +from narwhals._plan.common import NamedIR, is_function_expr, is_window_expr +from narwhals._plan.expr_expansion import into_named_irs, prepare_projection if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Callable, Mapping, Sequence from narwhals._plan.common import ExprIR - from narwhals._plan.typing import IntoExpr, MapIR, Seq + from narwhals._plan.typing import IntoExpr, Seq from narwhals.dtypes import DType def rewrite_all( - *exprs: IntoExpr, schema: Mapping[str, DType], rewrites: Sequence[MapIR] -) -> Seq[ExprIR]: + *exprs: IntoExpr, + schema: Mapping[str, DType], + rewrites: Sequence[Callable[[NamedIR], NamedIR]], +) -> Seq[NamedIR]: """Very naive approach, but should work for a demo. - Assumes all of `rewrites` ends with a `ExprIR.map_ir` call @@ -26,11 +28,16 @@ def rewrite_all( - Currently we do a full traversal of each tree per-rewrite function - There's no caching *after* `prepare_projection` yet """ - out_irs, _, _ = prepare_projection(parse.parse_into_seq_of_expr_ir(*exprs), schema) - return tuple(_rewrite_sequential(ir, rewrites) for ir in out_irs) + out_irs, _, names = prepare_projection( + parse.parse_into_seq_of_expr_ir(*exprs), schema + ) + named_irs = into_named_irs(out_irs, names) + return tuple(_rewrite_sequential(ir, rewrites) for ir in named_irs) -def _rewrite_sequential(origin: ExprIR, rewrites: Sequence[MapIR], /) -> ExprIR: +def _rewrite_sequential( + origin: NamedIR, rewrites: Sequence[Callable[[NamedIR], NamedIR]], / +) -> NamedIR: result = origin for fn in rewrites: result = fn(result) @@ -39,7 +46,7 @@ def _rewrite_sequential(origin: ExprIR, rewrites: Sequence[MapIR], /) -> ExprIR: # TODO @dangotbanned: Tests # TODO @dangotbanned: Review if `inputs` is always `len(1)`` after `prepare_projection` -def rewrite_elementwise_over(origin: ExprIR, /) -> ExprIR: +def rewrite_elementwise_over(origin: NamedIR, /) -> NamedIR: """Requested in [discord-0]. 
Before: diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 280fc42b78..baed90e6a3 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -45,6 +45,7 @@ RightT2 = TypeVar("RightT2", bound="ExprIR", default="ExprIR") OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" ExprIRT = TypeVar("ExprIRT", bound="ExprIR", default="ExprIR") +ExprIRT2 = TypeVar("ExprIRT2", bound="ExprIR", default="ExprIR") SelectorT = TypeVar("SelectorT", bound="SelectorIR", default="SelectorIR") LeftSelectorT = TypeVar("LeftSelectorT", bound="SelectorIR", default="SelectorIR") diff --git a/tests/plan/utils.py b/tests/plan/utils.py index b20e4aba8a..0b3b2712dc 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -2,17 +2,27 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import is_expr +from narwhals._plan.common import ExprIR, NamedIR, is_expr if TYPE_CHECKING: from typing_extensions import LiteralString - from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyExpr +def _unwrap_ir(obj: DummyExpr | ExprIR | NamedIR) -> ExprIR: + if is_expr(obj): + return obj._ir + if isinstance(obj, ExprIR): + return obj + if isinstance(obj, NamedIR): + return obj.expr + else: + raise NotImplementedError(type(obj)) + + def assert_expr_ir_equal( - actual: DummyExpr | ExprIR, expected: DummyExpr | ExprIR | LiteralString, / + actual: DummyExpr | ExprIR | NamedIR, expected: DummyExpr | ExprIR | LiteralString, / ) -> None: """Assert that `actual` is equivalent to `expected`. @@ -24,7 +34,7 @@ def assert_expr_ir_equal( Performing a repr comparison is more fragile, so should be avoided *unless* we raise an error at creation time. """ - lhs = actual._ir if is_expr(actual) else actual + lhs = _unwrap_ir(actual) if isinstance(expected, str): assert repr(lhs) == expected else: From d973b832a6ee0a9d26465e09408568f652980f0a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 23 Jun 2025 21:51:25 +0100 Subject: [PATCH 234/368] fix: Make sure to call `function` on result --- narwhals/_plan/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 8de566f549..ee0ff15bf2 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -281,7 +281,7 @@ class NamedIR(Immutable, Generic[ExprIRT]): def map_ir(self, function: MapIR, /) -> NamedIR[ExprIR]: """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" - return self.with_expr(self.expr.map_ir(function)) + return self.with_expr(function(self.expr.map_ir(function))) def with_expr(self, expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: if expr == self.expr: From 5ae792d9ca5fb1cfe498f7a10d131b20fa3cdaf6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 23 Jun 2025 22:31:00 +0100 Subject: [PATCH 235/368] refactor: Add `map_ir` function, un special-case `NamedIR` --- narwhals/_plan/common.py | 23 +++++++++++- narwhals/_plan/expr_rewrites.py | 62 ++++++++++++--------------------- narwhals/_plan/typing.py | 3 +- 3 files changed, 47 insertions(+), 41 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ee0ff15bf2..a1755e0cfd 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -4,7 +4,16 @@ from decimal import Decimal from typing import TYPE_CHECKING, Generic, TypeVar, cast -from narwhals._plan.typing import ExprIRT, ExprIRT2, ExprT, IRNamespaceT, MapIR, Ns, Seq +from 
narwhals._plan.typing import ( + ExprIRT, + ExprIRT2, + ExprT, + IRNamespaceT, + MapIR, + NamedOrExprIRT, + Ns, + Seq, +) from narwhals.dtypes import DType from narwhals.utils import Version @@ -432,3 +441,15 @@ def into_dtype(dtype: IntoDType, /) -> DType: def collect(iterable: Seq[T] | Iterable[T], /) -> Seq[T]: """Collect `iterable` into a `tuple`, *iff* it is not one already.""" return iterable if isinstance(iterable, tuple) else tuple(iterable) + + +def map_ir( + origin: NamedOrExprIRT, function: MapIR, *more_functions: MapIR +) -> NamedOrExprIRT: + """Apply one or more functions, sequentially, to all of `origin`'s children.""" + if more_functions: + result = origin + for fn in (function, *more_functions): + result = result.map_ir(fn) + return result + return origin.map_ir(function) diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index bd0a5882ff..b1dcba3e76 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -5,25 +5,22 @@ from typing import TYPE_CHECKING from narwhals._plan import expr_parsing as parse -from narwhals._plan.common import NamedIR, is_function_expr, is_window_expr +from narwhals._plan.common import NamedIR, is_function_expr, is_window_expr, map_ir from narwhals._plan.expr_expansion import into_named_irs, prepare_projection if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence + from collections.abc import Mapping, Sequence from narwhals._plan.common import ExprIR - from narwhals._plan.typing import IntoExpr, Seq + from narwhals._plan.typing import IntoExpr, MapIR, Seq from narwhals.dtypes import DType def rewrite_all( - *exprs: IntoExpr, - schema: Mapping[str, DType], - rewrites: Sequence[Callable[[NamedIR], NamedIR]], + *exprs: IntoExpr, schema: Mapping[str, DType], rewrites: Sequence[MapIR] ) -> Seq[NamedIR]: """Very naive approach, but should work for a demo. - - Assumes all of `rewrites` ends with a `ExprIR.map_ir` call - Applying multiple functions should be happening at a lower level - Currently we do a full traversal of each tree per-rewrite function - There's no caching *after* `prepare_projection` yet @@ -32,21 +29,12 @@ def rewrite_all( parse.parse_into_seq_of_expr_ir(*exprs), schema ) named_irs = into_named_irs(out_irs, names) - return tuple(_rewrite_sequential(ir, rewrites) for ir in named_irs) - - -def _rewrite_sequential( - origin: NamedIR, rewrites: Sequence[Callable[[NamedIR], NamedIR]], / -) -> NamedIR: - result = origin - for fn in rewrites: - result = fn(result) - return result + return tuple(map_ir(ir, *rewrites) for ir in named_irs) # TODO @dangotbanned: Tests # TODO @dangotbanned: Review if `inputs` is always `len(1)`` after `prepare_projection` -def rewrite_elementwise_over(origin: NamedIR, /) -> NamedIR: +def rewrite_elementwise_over(child: ExprIR, /) -> ExprIR: """Requested in [discord-0]. 
Before: @@ -59,30 +47,26 @@ def rewrite_elementwise_over(origin: NamedIR, /) -> NamedIR: [discord-0]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384807793512677398 """ - - def fn(child: ExprIR, /) -> ExprIR: - if ( - is_window_expr(child) - and is_function_expr(child.expr) - and child.expr.options.is_elementwise() - ): - # NOTE: Aliasing isn't required, but it does help readability - window = child - func = child.expr - if len(func.input) != 1: - msg = ( - f"Expected function inputs to have been expanded, " - f"but got {len(func.input)!r} inputs at: {func}" - ) - raise NotImplementedError(msg) - return func.with_input([window.with_expr(func.input[0])]) - return child - - return origin.map_ir(fn) + if ( + is_window_expr(child) + and is_function_expr(child.expr) + and child.expr.options.is_elementwise() + ): + # NOTE: Aliasing isn't required, but it does help readability + window = child + func = child.expr + if len(func.input) != 1: + msg = ( + f"Expected function inputs to have been expanded, " + f"but got {len(func.input)!r} inputs at: {func}" + ) + raise NotImplementedError(msg) + return func.with_input([window.with_expr(func.input[0])]) + return child # TODO @dangotbanned: Full implementation -def rewrite_binary_agg_over(origin: ExprIR, /) -> ExprIR: +def rewrite_binary_agg_over(child: ExprIR, /) -> ExprIR: """Requested in [discord-1], clarified in [discord-2]. Before: diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index baed90e6a3..753d8d8d6a 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -10,7 +10,7 @@ from narwhals._compliant import CompliantNamespace as Namespace from narwhals._compliant.typing import CompliantExprAny from narwhals._plan import operators as ops - from narwhals._plan.common import ExprIR, Function, IRNamespace, SelectorIR + from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.functions import RollingWindow from narwhals.typing import NonNestedLiteral @@ -46,6 +46,7 @@ OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" ExprIRT = TypeVar("ExprIRT", bound="ExprIR", default="ExprIR") ExprIRT2 = TypeVar("ExprIRT2", bound="ExprIR", default="ExprIR") +NamedOrExprIRT = TypeVar("NamedOrExprIRT", "NamedIR[t.Any]", "ExprIR") SelectorT = TypeVar("SelectorT", bound="SelectorIR", default="SelectorIR") LeftSelectorT = TypeVar("LeftSelectorT", bound="SelectorIR", default="SelectorIR") From ed9d769cd3baa34c8090a17ff68dd0e7e3c57843 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:31:25 +0100 Subject: [PATCH 236/368] docs(typing): `IntoFrozenSchema` alias --- narwhals/_plan/expr_expansion.py | 8 +++++++- narwhals/_plan/expr_rewrites.py | 11 +++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 070e14d5d1..a4ed337bd2 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -88,6 +88,12 @@ from narwhals._plan.typing import Seq from narwhals.dtypes import DType +IntoFrozenSchema: TypeAlias = "Mapping[str, DType] | FrozenSchema" +"""A schema to freeze, or an already frozen one. + +As `DType` instances (`.values()`) are hashable, we can coerce the schema +into a cache-safe proxy structure (`FrozenSchema`). 
+""" FrozenColumns: TypeAlias = "Seq[str]" Excluded: TypeAlias = "frozenset[str]" @@ -249,7 +255,7 @@ def with_multiple_columns(self) -> ExpansionFlags: def prepare_projection( - exprs: Sequence[ExprIR], schema: Mapping[str, DType] | FrozenSchema + exprs: Sequence[ExprIR], schema: IntoFrozenSchema ) -> tuple[Seq[ExprIR], FrozenSchema, OutputNames]: """Expand IRs into named column selections. diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index b1dcba3e76..da0864c2ed 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -6,18 +6,21 @@ from narwhals._plan import expr_parsing as parse from narwhals._plan.common import NamedIR, is_function_expr, is_window_expr, map_ir -from narwhals._plan.expr_expansion import into_named_irs, prepare_projection +from narwhals._plan.expr_expansion import ( + IntoFrozenSchema, + into_named_irs, + prepare_projection, +) if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Sequence from narwhals._plan.common import ExprIR from narwhals._plan.typing import IntoExpr, MapIR, Seq - from narwhals.dtypes import DType def rewrite_all( - *exprs: IntoExpr, schema: Mapping[str, DType], rewrites: Sequence[MapIR] + *exprs: IntoExpr, schema: IntoFrozenSchema, rewrites: Sequence[MapIR] ) -> Seq[NamedIR]: """Very naive approach, but should work for a demo. From 7dcdf86b533f10444234efa804d13701daa3f582 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:55:29 +0100 Subject: [PATCH 237/368] test: Move `meta.output_name` doctests, add failing one Discovered the bug while writing tests for `NamedIR` - which lost track of the alias after enocuntering a literal --- narwhals/_plan/meta.py | 10 ---------- tests/plan/meta_test.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index b9a85d25f5..4e4d8d9092 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -55,8 +55,6 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: >>> a = nwd.col("a") >>> b = a.alias("b") >>> c = b.min().alias("c") - >>> c_over = c.over(nwd.col("e"), nwd.col("f")) - >>> c_over_sort = c_over.sort_by(nwd.nth(9), nwd.col("g", "h")) >>> >>> a.meta.output_name() 'a' @@ -64,14 +62,6 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: 'b' >>> c.meta.output_name() 'c' - >>> c_over.meta.output_name() - 'c' - >>> c_over_sort.meta.output_name() - 'c' - >>> nwd.lit(1).meta.output_name() - 'literal' - >>> nwd.len().meta.output_name() - 'len' """ ok_or_err = _expr_output_name(self._ir) if isinstance(ok_or_err, ComputeError): diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index e4e0a41be4..0711e0a24f 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -54,3 +54,41 @@ def test_meta_root_names( nw_result = nw_expr.meta.root_names() assert nw_result == expected assert nw_result == pl_result + + +@pytest.mark.parametrize( + ("nw_expr", "pl_expr", "expected"), + [ + (nwd.col("a"), pl.col("a"), "a"), + (nwd.lit(1), pl.lit(1), "literal"), + (nwd.len(), pl.len(), "len"), + ( + nwd.col("a") + .alias("b") + .min() + .alias("c") + .over("e", "f") + .sort_by(nwd.nth(9), nwd.col("g", "h")), + pl.col("a") + .alias("b") + .min() + .alias("c") + .over("e", "f") + .sort_by(pl.nth(9), pl.col("g", "h")), + "c", + ), + pytest.param( + 
nwd.col("c").alias("x").fill_null(50), + pl.col("c").alias("x").fill_null(50), + "x", + marks=pytest.mark.xfail( + reason="Incorrectly matched `Literal.name` instead of earlier `Alias.name`." + ), + ), + ], +) +def test_meta_output_name(nw_expr: DummyExpr, pl_expr: pl.Expr, expected: str) -> None: + pl_result = pl_expr.meta.output_name() + nw_result = nw_expr.meta.output_name() + assert nw_result == expected + assert nw_result == pl_result From 54d078113814e84e815cdbbaa380fe3b6ffca35d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 14:36:10 +0100 Subject: [PATCH 238/368] fix: Get the right name from `FunctionExpr` Following up with more tests/fixes for other complex nodes --- narwhals/_plan/common.py | 8 ++++++++ narwhals/_plan/expr.py | 16 ++++++++++++++++ narwhals/_plan/meta.py | 2 +- tests/plan/meta_test.py | 4 +--- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index a1755e0cfd..3e8cac987f 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -244,6 +244,14 @@ def iter_root_names(self) -> Iterator[ExprIR]: """ yield from self.iter_left() + def iter_output_name(self) -> Iterator[ExprIR]: + """Override for different iteration behavior in `ExprIR.meta.output_name`. + + Note: + Identical to `iter_right` by default. + """ + yield from self.iter_right() + @property def meta(self) -> IRMetaNamespace: from narwhals._plan.meta import IRMetaNamespace diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 82fbed478d..b0ecf7a268 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -463,6 +463,22 @@ def iter_right(self) -> t.Iterator[ExprIR]: for e in reversed(self.input): yield from e.iter_right() + def iter_output_name(self) -> t.Iterator[ExprIR]: + """When we have multiple inputs, we want the name of the left-most expression. + + For expr: + + col("c").alias("x").fill_null(50) + + We are interested in the name which comes from the root: + + FunctionExpr(..., [Alias(..., name='...'), Literal(...), ...]) + # ^^^^^ ^^^ + """ + yield self + for e in self.input: + yield from e.iter_output_name() + def __init__( self, *, diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 4e4d8d9092..6297894053 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -125,7 +125,7 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: def _expr_output_name(ir: ExprIR) -> str | ComputeError: from narwhals._plan import expr - for e in ir.iter_right(): + for e in ir.iter_output_name(): if isinstance(e, (expr.WindowExpr, expr.SortBy)): # Don't follow `over(partition_by=...)` or `sort_by(by=...) return _expr_output_name(e.expr) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 0711e0a24f..6685442110 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -81,9 +81,7 @@ def test_meta_root_names( nwd.col("c").alias("x").fill_null(50), pl.col("c").alias("x").fill_null(50), "x", - marks=pytest.mark.xfail( - reason="Incorrectly matched `Literal.name` instead of earlier `Alias.name`." 
- ), + id="FunctionExpr-Literal", ), ], ) From e8106c4bcb883797094415656fa33885a2c61c83 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 18:38:04 +0100 Subject: [PATCH 239/368] test: Lots of `output_name` coverage 5 need fixing --- tests/plan/meta_test.py | 88 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 6685442110..e10d764e2e 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -56,6 +56,11 @@ def test_meta_root_names( assert nw_result == pl_result +XFAIL_WRONG_ALIAS = pytest.mark.xfail( + reason="Found the wrong alias.\nNeed to add `iter_output_name` override." +) + + @pytest.mark.parametrize( ("nw_expr", "pl_expr", "expected"), [ @@ -83,6 +88,89 @@ def test_meta_root_names( "x", id="FunctionExpr-Literal", ), + pytest.param( + ( + nwd.col("ROOT") + .alias("ROOT-ALIAS") + .filter(nwd.col("b") >= 30, nwd.col("c").alias("d") == 7) + + nwd.col("RHS").alias("RHS-ALIAS") + ), + ( + pl.col("ROOT") + .alias("ROOT-ALIAS") + .filter(pl.col("b") >= 30, pl.col("c").alias("d") == 7) + + pl.col("RHS").alias("RHS-ALIAS") + ), + "ROOT-ALIAS", + id="BinaryExpr-Multiple", + marks=XFAIL_WRONG_ALIAS, + ), + pytest.param( + nwd.col("ROOT").alias("ROOT-ALIAS").mean().over(nwd.col("a").alias("b")), + pl.col("ROOT").alias("ROOT-ALIAS").mean().over(pl.col("a").alias("b")), + "ROOT-ALIAS", + id="WindowExpr", + ), + pytest.param( + nwd.when(nwd.col("a").alias("a?")).then(10), + pl.when(pl.col("a").alias("a?")).then(10), + "literal", + id="When-Literal", + marks=XFAIL_WRONG_ALIAS, + ), + pytest.param( + nwd.when(nwd.col("a").alias("a?")).then(nwd.col("b")).otherwise(20), + pl.when(pl.col("a").alias("a?")).then(pl.col("b")).otherwise(20), + "b", + id="When-Column-Literal", + marks=XFAIL_WRONG_ALIAS, + ), + pytest.param( + nwd.when(a=1).then(10).otherwise(nwd.col("c").alias("c?")), + pl.when(a=1).then(10).otherwise(pl.col("c").alias("c?")), + "literal", + id="When-Literal-Alias", + ), + pytest.param( + ( + nwd.when(nwd.col("a").alias("a?")) + .then(1) + .when(nwd.col("b") == 1) + .then(nwd.col("c")) + ), + ( + pl.when(pl.col("a").alias("a?")) + .then(1) + .when(pl.col("b") == 1) + .then(pl.col("c")) + ), + "literal", + id="When-Literal-BinaryExpr-Column", + marks=XFAIL_WRONG_ALIAS, + ), + pytest.param( + ( + nwd.when(nwd.col("foo") > 2, nwd.col("bar") < 3) + .then(nwd.lit("Yes")) + .otherwise(nwd.lit("No")) + .alias("TARGET") + ), + ( + pl.when(pl.col("foo") > 2, pl.col("bar") < 3) + .then(pl.lit("Yes")) + .otherwise(pl.lit("No")) + .alias("TARGET") + ), + "TARGET", + id="When2-Literal-Literal-Alias", + ), + pytest.param( + (nwd.col("ROOT").alias("ROOT-ALIAS").filter(nwd.col("c") <= 1).mean()), + (pl.col("ROOT").alias("ROOT-ALIAS").filter(pl.col("c") <= 1).mean()), + "ROOT-ALIAS", + id="Filter", + marks=XFAIL_WRONG_ALIAS, + ), ], ) def test_meta_output_name(nw_expr: DummyExpr, pl_expr: pl.Expr, expected: str) -> None: From b79181a0904e090798b96b0292bf53e09e7d7ef3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 18:54:25 +0100 Subject: [PATCH 240/368] test: Backcompat `len`, `nth` https://github.com/narwhals-dev/narwhals/actions/runs/15848535213/job/44676174572?pr=2572 https://github.com/narwhals-dev/narwhals/actions/runs/15848535214/job/44676174562?pr=2572 --- tests/plan/meta_test.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 
deletions(-) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index e10d764e2e..cb43e865dd 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -22,6 +22,10 @@ ) else: # pragma: no cover OVER_CASE = (nwd.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) +if POLARS_VERSION >= (0, 20, 5): + LEN_CASE = (nwd.len(), pl.len(), "len") +else: # pragma: no cover + LEN_CASE = (nwd.len().alias("count"), pl.count(), "count") @pytest.mark.parametrize( @@ -66,21 +70,26 @@ def test_meta_root_names( [ (nwd.col("a"), pl.col("a"), "a"), (nwd.lit(1), pl.lit(1), "literal"), - (nwd.len(), pl.len(), "len"), - ( - nwd.col("a") - .alias("b") - .min() - .alias("c") - .over("e", "f") - .sort_by(nwd.nth(9), nwd.col("g", "h")), - pl.col("a") - .alias("b") - .min() - .alias("c") - .over("e", "f") - .sort_by(pl.nth(9), pl.col("g", "h")), + LEN_CASE, + pytest.param( + ( + nwd.col("a") + .alias("b") + .min() + .alias("c") + .over("e", "f") + .sort_by(nwd.col("i"), nwd.col("g", "h")) + ), + ( + pl.col("a") + .alias("b") + .min() + .alias("c") + .over("e", "f") + .sort_by(pl.col("i"), pl.col("g", "h")) + ), "c", + id="Kitchen-Sink", ), pytest.param( nwd.col("c").alias("x").fill_null(50), From 26716c5f1f1b1e3c769f203d61e82c3005fe426d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 19:00:45 +0100 Subject: [PATCH 241/368] refactor: Handle `SortBy`, `WindowExpr` internally This was inherited from the `rust` impl, but makes more sense in the class itself --- narwhals/_plan/expr.py | 6 ++++++ narwhals/_plan/meta.py | 3 --- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index b0ecf7a268..7737c8562a 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -388,6 +388,9 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from e.iter_right() yield from self.expr.iter_right() + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + def map_ir(self, function: MapIR, /) -> ExprIR: by = (ir.map_ir(function) for ir in self.by) return function(self.with_expr(self.expr.map_ir(function)).with_by(by)) @@ -575,6 +578,9 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from e.iter_right() yield from self.expr.iter_right() + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + def map_ir(self, function: MapIR, /) -> ExprIR: over = self.with_expr(self.expr.map_ir(function)).with_partition_by( ir.map_ir(function) for ir in self.partition_by diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 6297894053..5cd097ebea 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -126,9 +126,6 @@ def _expr_output_name(ir: ExprIR) -> str | ComputeError: from narwhals._plan import expr for e in ir.iter_output_name(): - if isinstance(e, (expr.WindowExpr, expr.SortBy)): - # Don't follow `over(partition_by=...)` or `sort_by(by=...) 
- return _expr_output_name(e.expr) if isinstance(e, (expr.Column, expr.Alias, expr.Literal, expr.Len)): return e.name if isinstance(e, (expr.All, expr.KeepName, expr.RenameAlias)): From 3a1c3759e8a2a97f339411f33022ea8156a97184 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 19:04:55 +0100 Subject: [PATCH 242/368] refactor: Simplify `FunctionExpr` version --- narwhals/_plan/expr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 7737c8562a..28ce04ad65 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -478,8 +478,7 @@ def iter_output_name(self) -> t.Iterator[ExprIR]: FunctionExpr(..., [Alias(..., name='...'), Literal(...), ...]) # ^^^^^ ^^^ """ - yield self - for e in self.input: + for e in self.input[:1]: yield from e.iter_output_name() def __init__( From bd5c33f6cf843009f36bd8feb15999e6a8d17511 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 19:21:13 +0100 Subject: [PATCH 243/368] fix: Ensure `output_name` matches upstream Resolves https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2163659077 --- narwhals/_plan/aggregation.py | 3 +++ narwhals/_plan/expr.py | 15 +++++++++++++++ tests/plan/meta_test.py | 10 ---------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 8e1cfc780d..0e5665fcb0 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -39,6 +39,9 @@ def iter_right(self) -> Iterator[ExprIR]: yield self yield from self.expr.iter_right() + def iter_output_name(self) -> Iterator[ExprIR]: + yield from self.expr.iter_output_name() + def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 28ce04ad65..da8b6333af 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -285,6 +285,9 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from self.right.iter_right() yield from self.left.iter_right() + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.left.iter_output_name() + def with_left(self, left: LeftT2, /) -> BinaryExpr[LeftT2, OperatorT, RightT]: if left == self.left: return t.cast("BinaryExpr[LeftT2, OperatorT, RightT]", self) @@ -324,6 +327,9 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield self yield from self.expr.iter_right() + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) @@ -353,6 +359,9 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield self yield from self.expr.iter_right() + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) @@ -525,6 +534,9 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from self.by.iter_right() yield from self.expr.iter_right() + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + def map_ir(self, function: MapIR, /) -> ExprIR: expr = self.expr.map_ir(function) by = self.by.map_ir(function) @@ -786,6 +798,9 @@ def iter_right(self) -> t.Iterator[ExprIR]: yield from self.falsy.iter_right() yield from 
self.truthy.iter_right() + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.truthy.iter_output_name() + def map_ir(self, function: MapIR, /) -> ExprIR: predicate = self.predicate.map_ir(function) truthy = self.truthy.map_ir(function) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index cb43e865dd..fcb22d09a3 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -60,11 +60,6 @@ def test_meta_root_names( assert nw_result == pl_result -XFAIL_WRONG_ALIAS = pytest.mark.xfail( - reason="Found the wrong alias.\nNeed to add `iter_output_name` override." -) - - @pytest.mark.parametrize( ("nw_expr", "pl_expr", "expected"), [ @@ -112,7 +107,6 @@ def test_meta_root_names( ), "ROOT-ALIAS", id="BinaryExpr-Multiple", - marks=XFAIL_WRONG_ALIAS, ), pytest.param( nwd.col("ROOT").alias("ROOT-ALIAS").mean().over(nwd.col("a").alias("b")), @@ -125,14 +119,12 @@ def test_meta_root_names( pl.when(pl.col("a").alias("a?")).then(10), "literal", id="When-Literal", - marks=XFAIL_WRONG_ALIAS, ), pytest.param( nwd.when(nwd.col("a").alias("a?")).then(nwd.col("b")).otherwise(20), pl.when(pl.col("a").alias("a?")).then(pl.col("b")).otherwise(20), "b", id="When-Column-Literal", - marks=XFAIL_WRONG_ALIAS, ), pytest.param( nwd.when(a=1).then(10).otherwise(nwd.col("c").alias("c?")), @@ -155,7 +147,6 @@ def test_meta_root_names( ), "literal", id="When-Literal-BinaryExpr-Column", - marks=XFAIL_WRONG_ALIAS, ), pytest.param( ( @@ -178,7 +169,6 @@ def test_meta_root_names( (pl.col("ROOT").alias("ROOT-ALIAS").filter(pl.col("c") <= 1).mean()), "ROOT-ALIAS", id="Filter", - marks=XFAIL_WRONG_ALIAS, ), ], ) From a84ae052df2d0296b22e31e5500beae89295df65 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 19:49:37 +0100 Subject: [PATCH 244/368] feat: Add `NamedIR.(__repr__|_repr_html_)` --- narwhals/_plan/common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 3e8cac987f..ed5b2dea7e 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -305,6 +305,12 @@ def with_expr(self, expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: return cast("NamedIR[ExprIRT2]", self) return NamedIR(expr=expr, name=self.name) + def __repr__(self) -> str: + return f"{self.name}={self.expr!r}" + + def _repr_html_(self) -> str: + return f"{self.name}={self.expr._repr_html_()}" + class IRNamespace(Immutable): __slots__ = ("_ir",) From a46b3e51e495f6655fef77f774ac16d44cb24eda Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 21:59:20 +0100 Subject: [PATCH 245/368] test: Add `test_rewrite_elementwise_over_complex` --- tests/plan/expr_rewrites_test.py | 45 +++++++++++++++++++++++++++++++- tests/plan/utils.py | 6 ++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 6883664f53..455627936e 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -5,7 +5,8 @@ import pytest import narwhals as nw -from narwhals._plan import demo as nwd, expr_parsing as parse +from narwhals._plan import demo as nwd, expr_parsing as parse, selectors as ndcs +from narwhals._plan.common import ExprIR, NamedIR, is_expr from narwhals._plan.expr import WindowExpr from narwhals._plan.expr_rewrites import rewrite_all, rewrite_elementwise_over from narwhals._plan.window import Over @@ -13,6 +14,7 @@ from tests.plan.utils import 
assert_expr_ir_equal if TYPE_CHECKING: + from narwhals._plan.dummy import DummyExpr from narwhals._plan.typing import IntoExpr from narwhals.dtypes import DType @@ -73,3 +75,44 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: assert len(actual) == 2 for lhs, rhs in zip(actual, expected): assert_expr_ir_equal(lhs, rhs) + + +def named_ir(name: str, expr: DummyExpr | ExprIR, /) -> NamedIR[ExprIR]: + """Helper constructor for test compare.""" + ir = expr._ir if is_expr(expr) else expr + return NamedIR(expr=ir, name=name) + + +def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: + expected = ( + named_ir("a", nwd.col("a")), + named_ir("b", nwd.col("b").cast(nw.String())), + named_ir("x2", nwd.col("c").max().over("a").fill_null(50)), + named_ir("d**", ~nwd.col("d").is_duplicated().over("b")), + named_ir("f_some", nwd.col("f").str.contains("some")), + named_ir("g_some", nwd.col("g").str.contains("some")), + named_ir("h_some", nwd.col("h").str.contains("some")), + named_ir("D", nwd.col("d").null_count().over("f", "g", "j").sqrt()), + named_ir("E", nwd.col("e").null_count().over("f", "g", "j").sqrt()), + named_ir("B", nwd.col("b").null_count().over("f", "g", "j").sqrt()), + ) + before = ( + nwd.col("a"), + nwd.col("b").cast(nw.String()), + ( + _to_window_expr(nwd.col("c").max().alias("x").fill_null(50), "a") + .to_narwhals() + .alias("x2") + ), + ~(nwd.col("d").is_duplicated().alias("d*")).alias("d**").over("b"), + ndcs.string().str.contains("some").name.suffix("_some"), + ( + _to_window_expr(nwd.nth(3, 4, 1).null_count().sqrt(), "f", "g", "j") + .to_narwhals() + .name.to_uppercase() + ), + ) + actual = rewrite_all(*before, schema=schema_2, rewrites=[rewrite_elementwise_over]) + assert len(actual) == len(expected) + for lhs, rhs in zip(actual, expected): + assert_expr_ir_equal(lhs, rhs) diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 0b3b2712dc..d284fb4821 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -22,7 +22,9 @@ def _unwrap_ir(obj: DummyExpr | ExprIR | NamedIR) -> ExprIR: def assert_expr_ir_equal( - actual: DummyExpr | ExprIR | NamedIR, expected: DummyExpr | ExprIR | LiteralString, / + actual: DummyExpr | ExprIR | NamedIR, + expected: DummyExpr | ExprIR | NamedIR | LiteralString, + /, ) -> None: """Assert that `actual` is equivalent to `expected`. 
@@ -37,6 +39,8 @@ def assert_expr_ir_equal( lhs = _unwrap_ir(actual) if isinstance(expected, str): assert repr(lhs) == expected + elif isinstance(actual, NamedIR) and isinstance(expected, NamedIR): + assert actual == expected else: rhs = expected._ir if is_expr(expected) else expected assert lhs == rhs From 83e2b586f87ad05c5d6a941191f653d2e617ef35 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 24 Jun 2025 22:00:33 +0100 Subject: [PATCH 246/368] fix: Handle `*args` in `rewrite_elementwise_over` --- narwhals/_plan/expr_rewrites.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index da0864c2ed..3d883fbae1 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -35,9 +35,7 @@ def rewrite_all( return tuple(map_ir(ir, *rewrites) for ir in named_irs) -# TODO @dangotbanned: Tests -# TODO @dangotbanned: Review if `inputs` is always `len(1)`` after `prepare_projection` -def rewrite_elementwise_over(child: ExprIR, /) -> ExprIR: +def rewrite_elementwise_over(window: ExprIR, /) -> ExprIR: """Requested in [discord-0]. Before: @@ -51,21 +49,14 @@ def rewrite_elementwise_over(child: ExprIR, /) -> ExprIR: [discord-0]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384807793512677398 """ if ( - is_window_expr(child) - and is_function_expr(child.expr) - and child.expr.options.is_elementwise() + is_window_expr(window) + and is_function_expr(window.expr) + and window.expr.options.is_elementwise() ): - # NOTE: Aliasing isn't required, but it does help readability - window = child - func = child.expr - if len(func.input) != 1: - msg = ( - f"Expected function inputs to have been expanded, " - f"but got {len(func.input)!r} inputs at: {func}" - ) - raise NotImplementedError(msg) - return func.with_input([window.with_expr(func.input[0])]) - return child + func = window.expr + parent, *args = func.input + return func.with_input((window.with_expr(parent), *args)) + return window # TODO @dangotbanned: Full implementation From ac2b1fffec20a4891f7ca3ed8b5ac5ccb358ce8e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 25 Jun 2025 15:25:31 +0100 Subject: [PATCH 247/368] feat: Add `int_range` Related #2722 -Replaces `RangeLiteral`, which was a lower-level version --- .pre-commit-config.yaml | 4 ++- narwhals/_plan/common.py | 15 +++++++--- narwhals/_plan/demo.py | 23 +++++++++++++++ narwhals/_plan/expr.py | 32 +++++++++++++++++++++ narwhals/_plan/literal.py | 14 --------- narwhals/_plan/options.py | 8 +++++- narwhals/_plan/ranges.py | 51 +++++++++++++++++++++++++++++++++ narwhals/_plan/typing.py | 13 +++++++-- tests/plan/expr_parsing_test.py | 7 ++++- tests/plan/meta_test.py | 9 ++++++ 10 files changed, 152 insertions(+), 24 deletions(-) create mode 100644 narwhals/_plan/ranges.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0840890d7b..beaa29f49f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,7 +79,9 @@ repos: narwhals/_utils\.py| narwhals/stable/v1/_dtypes.py| narwhals/.*__init__.py| - narwhals/.*typing\.py + narwhals/.*typing\.py| + narwhals/_plan/demo\.py| + narwhals/_plan/ranges\.py ) - id: pull-request-target name: don't use `pull_request_target` diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ed5b2dea7e..ed238b92ef 100644 --- a/narwhals/_plan/common.py +++ 
b/narwhals/_plan/common.py @@ -2,15 +2,17 @@ import datetime as dt from decimal import Decimal -from typing import TYPE_CHECKING, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Generic, TypeVar, cast, overload from narwhals._plan.typing import ( + DTypeT, ExprIRT, ExprIRT2, ExprT, IRNamespaceT, MapIR, NamedOrExprIRT, + NonNestedDTypeT, Ns, Seq, ) @@ -27,7 +29,7 @@ from narwhals._plan.expr import FunctionExpr, WindowExpr from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions - from narwhals.typing import IntoDType, NonNestedDType, NonNestedLiteral + from narwhals.typing import NonNestedDType, NonNestedLiteral else: # NOTE: This isn't important to the proposal, just wanted IDE support @@ -446,9 +448,14 @@ def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) return mapping.get(type(obj), dtypes.Unknown)() -def into_dtype(dtype: IntoDType, /) -> DType: +@overload +def into_dtype(dtype: type[NonNestedDTypeT], /) -> NonNestedDTypeT: ... +@overload +def into_dtype(dtype: DTypeT, /) -> DTypeT: ... +def into_dtype(dtype: DTypeT | type[NonNestedDTypeT], /) -> DTypeT | NonNestedDTypeT: if isinstance(dtype, type) and issubclass(dtype, DType): - return dtype() + # NOTE: `mypy` needs to learn intersections + return dtype() # type: ignore[return-value] return dtype diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 0bc829c19b..d5bbfc5d2f 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -18,6 +18,7 @@ from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth from narwhals._plan.literal import ScalarLiteral, SeriesLiteral +from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatHorizontal from narwhals._plan.when_then import When from narwhals._utils import Version, flatten @@ -29,6 +30,7 @@ from narwhals._plan.dummy import DummyExpr from narwhals._plan.expr import SortBy from narwhals._plan.typing import IntoExpr, IntoExprColumn + from narwhals.dtypes import IntegerType from narwhals.typing import IntoDType, NonNestedLiteral @@ -174,6 +176,27 @@ def when( return When._from_ir(condition) +def int_range( + start: int | IntoExprColumn = 0, + end: int | IntoExprColumn | None = None, + step: int = 1, + *, + dtype: IntegerType | type[IntegerType] = Version.MAIN.dtypes.Int64, + eager: bool = False, +) -> DummyExpr: + if end is None: + end = start + start = 0 + if eager: + msg = f"{eager=}" + raise NotImplementedError(msg) + return ( + IntRange(step=step, dtype=into_dtype(dtype)) + .to_function_expr(*parse.parse_into_seq_of_expr_ir(start, end)) + .to_narwhals() + ) + + def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]: """In theory, we could add other nodes to this check.""" from narwhals._plan.expr import SortBy diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index da8b6333af..e2f4db2000 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -26,6 +26,7 @@ MapIR, Ns, OperatorT, + RangeT, RightSelectorT, RightT, RightT2, @@ -35,6 +36,7 @@ Seq, ) from narwhals._utils import flatten +from narwhals.exceptions import InvalidOperationError if t.TYPE_CHECKING: from typing_extensions import Self @@ -511,6 +513,36 @@ class AnonymousExpr(FunctionExpr["MapBatches"]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" +class RangeExpr(FunctionExpr[RangeT]): + """E.g. `int_range(...)`. 
+ + Special-cased as it is only allowed scalar inputs, and is row_separable. + + Contradicts the check in `FunctionExpr`, so we've got something *like* [`ensure_range_bounds_contain_exactly_one_value`]. + + [`ensure_range_bounds_contain_exactly_one_value`]:https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/function_expr/range/int_range.rs#L9-L14 + """ + + def __init__( + self, + *, + input: Seq[ExprIR], # noqa: A002 + function: RangeT, + options: FunctionOptions, + **kwds: t.Any, + ) -> None: + # NOTE: `IntRange` has 2x scalar inputs, so always triggered error in parent + if len(input) < 2: + msg = f"Expected at least 2 inputs for `{function!r}()`, but got `{len(input)}`.\n`{input}`" + raise InvalidOperationError(msg) + super(ExprIR, self).__init__( + **dict(input=input, function=function, options=options, **kwds) + ) + + def __repr__(self) -> str: + return f"{self.function!r}({list(self.input)!r})" + + class Filter(ExprIR): __slots__ = ("expr", "by") # noqa: RUF023 diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index b878a02fce..c8a9cdcb6f 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -81,20 +81,6 @@ def unwrap(self) -> DummySeries: return self.value -class RangeLiteral(LiteralValue): - """Don't need yet, but might push forward the discussions. - - - https://github.com/narwhals-dev/narwhals/issues/2463#issuecomment-2844654064 - - https://github.com/narwhals-dev/narwhals/issues/2307#issuecomment-2832422364. - """ - - __slots__ = ("dtype", "high", "low") - - low: int - high: int - dtype: DType - - def _is_scalar( obj: ScalarLiteral[NonNestedLiteralT] | Any, ) -> TypeIs[ScalarLiteral[NonNestedLiteralT]]: diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 1193993987..922b1cfa63 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -27,7 +27,13 @@ class FunctionFlags(enum.Flag): """Automatically explode on unit length if it ran as final aggregation.""" ROW_SEPARABLE = 1 << 8 - """`drop_nulls` is the only one we've got that is *just* this. + """Given a function `f` and a column of values `[v1, ..., vn]`. + + `f` is row-separable *iff*: + + f([v1, ..., vn]) = concat(f(v1, ... vm), f(vm+1, ..., vn)) + + In isolation, used on `drop_nulls`, `int_range` https://github.com/pola-rs/polars/pull/22573 """ diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py new file mode 100644 index 0000000000..a00d6b0e53 --- /dev/null +++ b/narwhals/_plan/ranges.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan.common import ExprIR, Function +from narwhals._plan.options import FunctionOptions + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan.expr import RangeExpr + from narwhals.dtypes import IntegerType + + +class RangeFunction(Function): + def __repr__(self) -> str: + tp = type(self) + if tp is RangeFunction: + return tp.__name__ + m: dict[type[RangeFunction], str] = {IntRange: "int_range"} + return m[tp] + + def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: + from narwhals._plan.expr import RangeExpr + + return RangeExpr(input=inputs, function=self, options=self.function_options) + + +class IntRange(RangeFunction): + """Not implemented yet, but might push forward [#2722]. 
+ + See [`rust` entrypoint], which is roughly: + + Expr::Function { [start, end], FunctionExpr::Range(RangeFunction::IntRange { step, dtype }) } + + `narwhals` equivalent: + + FunctionExpr(input=(start, end), function=IntRange(step=step, dtype=dtype)) + + [#2722]: https://github.com/narwhals-dev/narwhals/issues/2722 + [`rust` entrypoint]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/dsl/functions/range.rs#L14-L23 + """ + + __slots__ = ("step", "dtype") # noqa: RUF023 + + step: int + dtype: IntegerType + + @property + def function_options(self) -> FunctionOptions: + return FunctionOptions.row_separable() diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 753d8d8d6a..3f3b512b52 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -7,13 +7,15 @@ if t.TYPE_CHECKING: from typing_extensions import TypeAlias + from narwhals import dtypes from narwhals._compliant import CompliantNamespace as Namespace from narwhals._compliant.typing import CompliantExprAny from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.functions import RollingWindow - from narwhals.typing import NonNestedLiteral + from narwhals._plan.ranges import RangeFunction + from narwhals.typing import NonNestedDType, NonNestedLiteral __all__ = [ "FunctionT", @@ -26,6 +28,7 @@ "NonNestedLiteralT", "OperatorFn", "OperatorT", + "RangeT", "RightSelectorT", "RightT", "RollingT", @@ -36,8 +39,9 @@ ] -FunctionT = TypeVar("FunctionT", bound="Function") -RollingT = TypeVar("RollingT", bound="RollingWindow") +FunctionT = TypeVar("FunctionT", bound="Function", default="Function") +RollingT = TypeVar("RollingT", bound="RollingWindow", default="RollingWindow") +RangeT = TypeVar("RangeT", bound="RangeFunction", default="RangeFunction") LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") LeftT2 = TypeVar("LeftT2", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") @@ -56,6 +60,9 @@ ) IRNamespaceT = TypeVar("IRNamespaceT", bound="IRNamespace") +DTypeT = TypeVar("DTypeT", bound="dtypes.DType") +NonNestedDTypeT = TypeVar("NonNestedDTypeT", bound="NonNestedDType") + NonNestedLiteralT = TypeVar( "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" ) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 61e3c4d926..35826a0dad 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -16,7 +16,7 @@ ) from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr, DummySeries -from narwhals._plan.expr import BinaryExpr, FunctionExpr +from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals.exceptions import ( InvalidIntoExprError, InvalidOperationError, @@ -156,6 +156,11 @@ def test_invalid_agg_non_elementwise() -> None: nwd.col("a").min().diff() +def test_agg_non_elementwise_range_special() -> None: + e = nwd.int_range(0, 100) + assert isinstance(e._ir, RangeExpr) + + # NOTE: Non-`polars`` rule def test_invalid_over() -> None: pattern = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index fcb22d09a3..67f039a656 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -170,6 +170,15 @@ def test_meta_root_names( "ROOT-ALIAS", 
id="Filter", ), + pytest.param( + nwd.int_range(0, 10), pl.int_range(0, 10), "literal", id="IntRange-Literal" + ), + pytest.param( + nwd.int_range(nwd.col("b"), 10), + pl.int_range(pl.col("b"), 10), + "b", + id="IntRange-Column", + ), ], ) def test_meta_output_name(nw_expr: DummyExpr, pl_expr: pl.Expr, expected: str) -> None: From 7324753a942431d12eb3229333209bae6b33deb4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 25 Jun 2025 16:47:43 +0100 Subject: [PATCH 248/368] feat: Add `NamedIR.is_elementwise_top_level` --- narwhals/_plan/common.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ed238b92ef..32ff7cad21 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -25,6 +25,7 @@ from typing_extensions import Never, Self, TypeIs, dataclass_transform + from narwhals._plan import expr from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries from narwhals._plan.expr import FunctionExpr, WindowExpr from narwhals._plan.meta import IRMetaNamespace @@ -313,6 +314,26 @@ def __repr__(self) -> str: def _repr_html_(self) -> str: return f"{self.name}={self.expr._repr_html_()}" + def is_elementwise_top_level(self) -> bool: + """Return True if the outermost node is elementwise. + + Based on [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`] + + This check: + - Is not recursive + - Is not valid on `ExprIR` *prior* to being expanded + + [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/properties.rs#L16-L44 + """ + from narwhals._plan import expr + + ir = self.expr + if is_function_expr(ir): + return ir.options.is_elementwise() + if is_literal(ir): + return ir.is_scalar + return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.Ternary, expr.Cast)) + class IRNamespace(Immutable): __slots__ = ("_ir",) @@ -426,6 +447,12 @@ def is_function_expr(obj: Any) -> TypeIs[FunctionExpr[Any]]: return isinstance(obj, FunctionExpr) +def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]: + from narwhals._plan import expr + + return isinstance(obj, expr.Literal) + + def is_horizontal_reduction(obj: FunctionExpr[Any] | Any) -> TypeIs[FunctionExpr[Any]]: return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() From 1f6e1da1ce96b5e92589b9cca4b670c2285bbf6a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 26 Jun 2025 16:35:44 +0100 Subject: [PATCH 249/368] feat: Initial `rewrite_binary_agg_over` impl - Seems enough for `(nw.col("a") - nw.col("a").mean()).over("b")` - Need to try out more complex examples in tests --- narwhals/_plan/common.py | 22 ++++++++++++++++++++-- narwhals/_plan/expr_rewrites.py | 23 +++++++++++++++++++---- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 32ff7cad21..b40fa2eae5 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -2,7 +2,7 @@ import datetime as dt from decimal import Decimal -from typing import TYPE_CHECKING, Generic, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, TypeIs, TypeVar, cast, overload from narwhals._plan.typing import ( DTypeT, @@ -27,7 +27,7 @@ from narwhals._plan import expr from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries - from 
narwhals._plan.expr import FunctionExpr, WindowExpr + from narwhals._plan.expr import Agg, BinaryExpr, FunctionExpr, WindowExpr from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -447,6 +447,24 @@ def is_function_expr(obj: Any) -> TypeIs[FunctionExpr[Any]]: return isinstance(obj, FunctionExpr) +def is_binary_expr(obj: Any) -> TypeIs[BinaryExpr]: + from narwhals._plan.expr import BinaryExpr + + return isinstance(obj, BinaryExpr) + + +# TODO @dangotbanned: Rename `Agg` -> `AggExpr` +def is_agg_expr(obj: Any) -> TypeIs[Agg]: + from narwhals._plan.expr import Agg + + return isinstance(obj, Agg) + + +def is_aggregation(obj: Any) -> TypeIs[Agg | FunctionExpr[Any]]: + """Superset of `ExprIR.is_scalar`, excludes literals & len.""" + return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar) + + def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]: from narwhals._plan import expr diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index 3d883fbae1..b3e31bccd1 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -5,7 +5,14 @@ from typing import TYPE_CHECKING from narwhals._plan import expr_parsing as parse -from narwhals._plan.common import NamedIR, is_function_expr, is_window_expr, map_ir +from narwhals._plan.common import ( + NamedIR, + is_aggregation, + is_binary_expr, + is_function_expr, + is_window_expr, + map_ir, +) from narwhals._plan.expr_expansion import ( IntoFrozenSchema, into_named_irs, @@ -59,8 +66,8 @@ def rewrite_elementwise_over(window: ExprIR, /) -> ExprIR: return window -# TODO @dangotbanned: Full implementation -def rewrite_binary_agg_over(child: ExprIR, /) -> ExprIR: +# TODO @dangotbanned: Tests +def rewrite_binary_agg_over(window: ExprIR, /) -> ExprIR: """Requested in [discord-1], clarified in [discord-2]. 
Before: @@ -74,4 +81,12 @@ def rewrite_binary_agg_over(child: ExprIR, /) -> ExprIR: [discord-1]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384850753008435372 [discord-2]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384869107203047588 """ - raise NotImplementedError + if ( + is_window_expr(window) + and is_binary_expr(window.expr) + and (is_aggregation(window.expr.right)) + ): + binary_expr = window.expr + rhs = window.expr.right + return binary_expr.with_right(window.with_expr(rhs)) + return window From f2ac1c6483169dfad9d8c727fc4a46eacb636780 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 26 Jun 2025 16:38:23 +0100 Subject: [PATCH 250/368] fix: undo `TypeIs` import https://github.com/narwhals-dev/narwhals/actions/runs/15906219835/job/44861709053?pr=2572 --- narwhals/_plan/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index b40fa2eae5..955494780e 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -2,7 +2,7 @@ import datetime as dt from decimal import Decimal -from typing import TYPE_CHECKING, Any, Generic, TypeIs, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload from narwhals._plan.typing import ( DTypeT, From e7268654f98aa0c579bfa1b7931b8931c66cfae4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 26 Jun 2025 17:27:06 +0100 Subject: [PATCH 251/368] fix: Ensure lhs gets leaf name used Discovered while writing a different test --- narwhals/_plan/meta.py | 2 ++ tests/plan/expr_expansion_test.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 5cd097ebea..1a5cbb8eac 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -153,6 +153,8 @@ def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: for e in ir.iter_right(): if isinstance(e, (expr.WindowExpr, expr.SortBy, expr.Filter)): return get_single_leaf_name(e.expr) + if isinstance(e, expr.BinaryExpr): + return get_single_leaf_name(e.left) # NOTE: `polars` doesn't include `Literal` here if isinstance(e, (expr.Column, expr.Len)): return e.name diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index fefbb93aa0..a5361a671d 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -403,6 +403,23 @@ def test_replace_selector( .over(order_by=[nwd.col("k")]) ], ), + pytest.param( + (ndcs.by_name("a", "b", "c") / nwd.col("e").first()) + .over("g", "f", order_by="f") + .name.prefix("hi_"), + [ + (nwd.col("a") / nwd.col("e").first()) + .over("g", "f", order_by="f") + .alias("hi_a"), + (nwd.col("b") / nwd.col("e").first()) + .over("g", "f", order_by="f") + .alias("hi_b"), + (nwd.col("c") / nwd.col("e").first()) + .over("g", "f", order_by="f") + .alias("hi_c"), + ], + id="Selector-BinaryExpr-Over-Prefix", + ), ], ) def test_prepare_projection( From 6edc52c205e4e17f4404362c83066c51194cacbb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 26 Jun 2025 17:34:49 +0100 Subject: [PATCH 252/368] test: Add some `rewrite_binary_agg_over` --- narwhals/_plan/expr_rewrites.py | 2 +- tests/plan/expr_rewrites_test.py | 37 +++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/expr_rewrites.py 
b/narwhals/_plan/expr_rewrites.py index b3e31bccd1..870085b33d 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -66,7 +66,7 @@ def rewrite_elementwise_over(window: ExprIR, /) -> ExprIR: return window -# TODO @dangotbanned: Tests +# TODO @dangotbanned: Tests (single ✔️, multiple ✔️, complex ❌) def rewrite_binary_agg_over(window: ExprIR, /) -> ExprIR: """Requested in [discord-1], clarified in [discord-2]. diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 455627936e..6d1d9576b0 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -8,7 +8,11 @@ from narwhals._plan import demo as nwd, expr_parsing as parse, selectors as ndcs from narwhals._plan.common import ExprIR, NamedIR, is_expr from narwhals._plan.expr import WindowExpr -from narwhals._plan.expr_rewrites import rewrite_all, rewrite_elementwise_over +from narwhals._plan.expr_rewrites import ( + rewrite_all, + rewrite_binary_agg_over, + rewrite_elementwise_over, +) from narwhals._plan.window import Over from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_expr_ir_equal @@ -116,3 +120,34 @@ def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: assert len(actual) == len(expected) for lhs, rhs in zip(actual, expected): assert_expr_ir_equal(lhs, rhs) + + +def test_rewrite_binary_agg_over_simple(schema_2: dict[str, DType]) -> None: + expected = ( + nwd.col("a") - nwd.col("a").mean().over("b"), + nwd.col("c") * nwd.col("c").abs().null_count().over("d"), + ) + before = ( + (nwd.col("a") - nwd.col("a").mean()).over("b"), + (nwd.col("c") * nwd.col("c").abs().null_count()).over("d"), + ) + actual = rewrite_all(*before, schema=schema_2, rewrites=[rewrite_binary_agg_over]) + assert len(actual) == 2 + for lhs, rhs in zip(actual, expected): + assert_expr_ir_equal(lhs, rhs) + + +def test_rewrite_binary_agg_over_multiple(schema_2: dict[str, DType]) -> None: + expected = ( + named_ir("hi_a", nwd.col("a") / nwd.col("e").drop_nulls().first().over("g")), + named_ir("hi_b", nwd.col("b") / nwd.col("e").drop_nulls().first().over("g")), + named_ir("hi_c", nwd.col("c") / nwd.col("e").drop_nulls().first().over("g")), + named_ir("hi_d", nwd.col("d") / nwd.col("e").drop_nulls().first().over("g")), + ) + before = ( + (nwd.col("a", "b", "c", "d") / nwd.col("e").drop_nulls().first()).over("g") + ).name.prefix("hi_") + actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_binary_agg_over]) + assert len(actual) == 4 + for lhs, rhs in zip(actual, expected): + assert_expr_ir_equal(lhs, rhs) From 463d75af30497beea930e9db28232c08ea85a5fa Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 13:51:34 +0100 Subject: [PATCH 253/368] feat: `FrozenSchema` repr# --- narwhals/_plan/expr_expansion.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index a4ed337bd2..aace32cba4 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -173,6 +173,11 @@ def __getitem__(self, key: str, /) -> DType: def __len__(self) -> int: return self._mapping.__len__() + def __repr__(self) -> str: + sep, nl, indent = ",", "\n", " " + items = f"{sep}{nl}{indent}".join(repr(tuple(els)) for els in self.items()) + return f"{type(self).__name__}([{nl}{indent}{items}{sep}{nl}])" + def freeze_schema(**schema: DType) -> FrozenSchema: schema_hash = tuple(schema.items()) 
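For readers following the series: the `__repr__` added in the patch above renders one `(name, dtype)` pair of the schema per line. The snippet below is an illustration only (not committed code), assuming that narwhals dtype instances repr as their bare class names and that the indent is four spaces, which the collapsed whitespace in this copy of the diff does not preserve. At this point in the series `freeze_schema` still lives in `narwhals._plan.expr_expansion`; the next patch moves it into `narwhals._plan.schema`.

    from narwhals.dtypes import Int64, String

    # Hypothetical usage sketch: import location matches patch 253,
    # before the split into narwhals._plan.schema in the next patch.
    from narwhals._plan.expr_expansion import freeze_schema

    schema = freeze_schema(a=Int64(), b=String())
    print(repr(schema))
    # Expected shape of the output, per the __repr__ added above
    # (assumes DType reprs are the class names, four-space indent):
    # FrozenSchema([
    #     ('a', Int64),
    #     ('b', String),
    # ])
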
From 08dcfca2f8c43b2a454a4ab6575d00bc6ca17b42 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 14:22:37 +0100 Subject: [PATCH 254/368] refactor: Split out `FrozenSchema` --- narwhals/_plan/expr_expansion.py | 114 +++--------------------------- narwhals/_plan/expr_rewrites.py | 7 +- narwhals/_plan/schema.py | 114 ++++++++++++++++++++++++++++++ tests/plan/expr_expansion_test.py | 2 +- 4 files changed, 126 insertions(+), 111 deletions(-) create mode 100644 narwhals/_plan/schema.py diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index aace32cba4..1075dc885e 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -41,11 +41,9 @@ from collections import deque from functools import lru_cache from itertools import chain -from types import MappingProxyType -from typing import TYPE_CHECKING, TypeVar, overload +from typing import TYPE_CHECKING from narwhals._plan.common import ( - _IMMUTABLE_HASH_NAME, ExprIR, Immutable, NamedIR, @@ -69,18 +67,17 @@ _ColumnSelection, col, ) +from narwhals._plan.schema import ( + FrozenColumns, + FrozenSchema, + IntoFrozenSchema, + freeze_schema, +) from narwhals.dtypes import DType from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: - from collections.abc import ( - ItemsView, - Iterator, - KeysView, - Mapping, - Sequence, - ValuesView, - ) + from collections.abc import Iterator, Sequence from typing_extensions import TypeAlias @@ -88,14 +85,7 @@ from narwhals._plan.typing import Seq from narwhals.dtypes import DType -IntoFrozenSchema: TypeAlias = "Mapping[str, DType] | FrozenSchema" -"""A schema to freeze, or an already frozen one. -As `DType` instances (`.values()`) are hashable, we can coerce the schema -into a cache-safe proxy structure (`FrozenSchema`). 
-""" - -FrozenColumns: TypeAlias = "Seq[str]" Excluded: TypeAlias = "frozenset[str]" """Internally use a `set`, then freeze before returning.""" @@ -109,90 +99,6 @@ OutputNames: TypeAlias = "Seq[str]" """Fully expanded, validated output column names, for `NamedIR`s.""" -_FrozenSchemaHash: TypeAlias = "Seq[tuple[str, DType]]" -_T2 = TypeVar("_T2") - - -# NOTE: Both `_freeze` functions will probably want to be cached -# In the traversal/expand/replacement functions, their returns will be hashable -> safe to cache those as well -class FrozenSchema(Immutable): - """Use `freeze_schema(...)` constructor to trigger caching!""" - - __slots__ = ("_mapping",) - _mapping: MappingProxyType[str, DType] - - @property - def __immutable_hash__(self) -> int: - if hasattr(self, _IMMUTABLE_HASH_NAME): - return self.__immutable_hash_value__ - hash_value = hash((self.__class__, *tuple(self._mapping.items()))) - object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) - return self.__immutable_hash_value__ - - @property - def names(self) -> FrozenColumns: - """Get the column names of the schema.""" - return freeze_columns(self) - - @staticmethod - def _from_mapping(mapping: MappingProxyType[str, DType], /) -> FrozenSchema: - return FrozenSchema(_mapping=mapping) - - @staticmethod - def _from_hash_safe(items: _FrozenSchemaHash, /) -> FrozenSchema: - clone = MappingProxyType(dict(items)) - return FrozenSchema._from_mapping(clone) - - def items(self) -> ItemsView[str, DType]: - return self._mapping.items() - - def keys(self) -> KeysView[str]: - return self._mapping.keys() - - def values(self) -> ValuesView[DType]: - return self._mapping.values() - - @overload - def get(self, key: str, /) -> DType | None: ... - @overload - def get(self, key: str, default: DType | _T2, /) -> DType | _T2: ... - def get(self, key: str, default: DType | _T2 | None = None, /) -> DType | _T2 | None: - if default is not None: - return self._mapping.get(key, default) - return self._mapping.get(key) - - def __iter__(self) -> Iterator[str]: - yield from self._mapping - - def __contains__(self, key: object) -> bool: - return self._mapping.__contains__(key) - - def __getitem__(self, key: str, /) -> DType: - return self._mapping.__getitem__(key) - - def __len__(self) -> int: - return self._mapping.__len__() - - def __repr__(self) -> str: - sep, nl, indent = ",", "\n", " " - items = f"{sep}{nl}{indent}".join(repr(tuple(els)) for els in self.items()) - return f"{type(self).__name__}([{nl}{indent}{items}{sep}{nl}])" - - -def freeze_schema(**schema: DType) -> FrozenSchema: - schema_hash = tuple(schema.items()) - return _freeze_schema_cache(schema_hash) - - -@lru_cache(maxsize=100) -def _freeze_schema_cache(schema: _FrozenSchemaHash, /) -> FrozenSchema: - return FrozenSchema._from_hash_safe(schema) - - -@lru_cache(maxsize=100) -def freeze_columns(schema: FrozenSchema, /) -> FrozenColumns: - return tuple(schema) - class ExpansionFlags(Immutable): """`polars` uses a struct, but we may want to use `enum.Flag`.""" @@ -274,9 +180,7 @@ def prepare_projection( Returns: `exprs`, rewritten using `Column(name)` only. 
""" - frozen_schema = ( - schema if isinstance(schema, FrozenSchema) else freeze_schema(**schema) - ) + frozen_schema = freeze_schema(schema) rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) output_names = ensure_valid_exprs(rewritten, frozen_schema) # TODO @dangotbanned: (Seq[ExprIR], OutputNames) -> (Seq[NamedIR]) diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index 870085b33d..705dec3371 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -13,16 +13,13 @@ is_window_expr, map_ir, ) -from narwhals._plan.expr_expansion import ( - IntoFrozenSchema, - into_named_irs, - prepare_projection, -) +from narwhals._plan.expr_expansion import into_named_irs, prepare_projection if TYPE_CHECKING: from collections.abc import Sequence from narwhals._plan.common import ExprIR + from narwhals._plan.schema import IntoFrozenSchema from narwhals._plan.typing import IntoExpr, MapIR, Seq diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py new file mode 100644 index 0000000000..090a91fc8f --- /dev/null +++ b/narwhals/_plan/schema.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from functools import lru_cache +from types import MappingProxyType +from typing import TYPE_CHECKING, TypeVar, overload + +from narwhals._plan.common import _IMMUTABLE_HASH_NAME, Immutable + +if TYPE_CHECKING: + from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView + + from typing_extensions import TypeAlias + + from narwhals._plan.typing import Seq + from narwhals.dtypes import DType + + +IntoFrozenSchema: TypeAlias = "Mapping[str, DType] | FrozenSchema" +"""A schema to freeze, or an already frozen one. + +As `DType` instances (`.values()`) are hashable, we can coerce the schema +into a cache-safe proxy structure (`FrozenSchema`). +""" + +FrozenColumns: TypeAlias = "Seq[str]" +_FrozenSchemaHash: TypeAlias = "Seq[tuple[str, DType]]" +_T2 = TypeVar("_T2") + + +class FrozenSchema(Immutable): + """Use `freeze_schema(...)` constructor to trigger caching!""" + + __slots__ = ("_mapping",) + _mapping: MappingProxyType[str, DType] + + @property + def __immutable_hash__(self) -> int: + if hasattr(self, _IMMUTABLE_HASH_NAME): + return self.__immutable_hash_value__ + hash_value = hash((self.__class__, *tuple(self._mapping.items()))) + object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) + return self.__immutable_hash_value__ + + @property + def names(self) -> FrozenColumns: + """Get the column names of the schema.""" + return freeze_columns(self) + + @staticmethod + def _from_mapping(mapping: MappingProxyType[str, DType], /) -> FrozenSchema: + return FrozenSchema(_mapping=mapping) + + @staticmethod + def _from_hash_safe(items: _FrozenSchemaHash, /) -> FrozenSchema: + clone = MappingProxyType(dict(items)) + return FrozenSchema._from_mapping(clone) + + def items(self) -> ItemsView[str, DType]: + return self._mapping.items() + + def keys(self) -> KeysView[str]: + return self._mapping.keys() + + def values(self) -> ValuesView[DType]: + return self._mapping.values() + + @overload + def get(self, key: str, /) -> DType | None: ... + @overload + def get(self, key: str, default: DType | _T2, /) -> DType | _T2: ... 
+ def get(self, key: str, default: DType | _T2 | None = None, /) -> DType | _T2 | None: + if default is not None: + return self._mapping.get(key, default) + return self._mapping.get(key) + + def __iter__(self) -> Iterator[str]: + yield from self._mapping + + def __contains__(self, key: object) -> bool: + return self._mapping.__contains__(key) + + def __getitem__(self, key: str, /) -> DType: + return self._mapping.__getitem__(key) + + def __len__(self) -> int: + return self._mapping.__len__() + + def __repr__(self) -> str: + sep, nl, indent = ",", "\n", " " + items = f"{sep}{nl}{indent}".join(repr(tuple(els)) for els in self.items()) + return f"{type(self).__name__}([{nl}{indent}{items}{sep}{nl}])" + + +@overload +def freeze_schema(mapping: IntoFrozenSchema, /) -> FrozenSchema: ... +@overload +def freeze_schema(**schema: DType) -> FrozenSchema: ... +def freeze_schema( + mapping: IntoFrozenSchema | None = None, /, **schema: DType +) -> FrozenSchema: + if mapping and isinstance(mapping, FrozenSchema): + return mapping + schema_hash = tuple((mapping or schema).items()) + return _freeze_schema_cache(schema_hash) + + +@lru_cache(maxsize=100) +def _freeze_schema_cache(schema: _FrozenSchemaHash, /) -> FrozenSchema: + return FrozenSchema._from_hash_safe(schema) + + +@lru_cache(maxsize=100) +def freeze_columns(schema: FrozenSchema, /) -> FrozenColumns: + return tuple(schema) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index a5361a671d..91ebe99752 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -9,12 +9,12 @@ from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.expr import Alias, Columns from narwhals._plan.expr_expansion import ( - freeze_schema, prepare_projection, replace_selector, rewrite_special_aliases, ) from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir +from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError from tests.plan.utils import assert_expr_ir_equal From 17822e8c014f6a44a1b69a2d1e8046772a340de5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 15:41:31 +0100 Subject: [PATCH 255/368] planning schema projection Lots of gaps, but need to start somewhere --- narwhals/_plan/contexts.py | 31 ++++++++++++++++++++++++++ narwhals/_plan/schema.py | 45 +++++++++++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 narwhals/_plan/contexts.py diff --git a/narwhals/_plan/contexts.py b/narwhals/_plan/contexts.py new file mode 100644 index 0000000000..773b699df9 --- /dev/null +++ b/narwhals/_plan/contexts.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import enum + +__all__ = ["ExprContext"] + + +class ExprContext(enum.Enum): + """A [context] to evaluate expressions in. + + [context]: https://docs.pola.rs/user-guide/concepts/expressions-and-contexts/#contexts + """ + + SELECT = "select" + """The output schema has the same order and length as the (expanded) input expressions. + + That order is determined during expansion of selectors in an earlier step. + """ + + WITH_COLUMNS = "with_columns" + """The output schema *derives from* the input schema, but *may* produce a different shape. 
+ + - Expressions producing **new names** are appended to the end of the schema + - Expressions producing **existing names** will replace the existing column positionally + """ + + def is_select(self) -> bool: + return self is ExprContext.SELECT + + def is_with_columns(self) -> bool: + return self is ExprContext.WITH_COLUMNS diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 090a91fc8f..30246993a8 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -1,21 +1,26 @@ from __future__ import annotations +from collections.abc import Mapping from functools import lru_cache from types import MappingProxyType from typing import TYPE_CHECKING, TypeVar, overload -from narwhals._plan.common import _IMMUTABLE_HASH_NAME, Immutable +from narwhals._plan.common import _IMMUTABLE_HASH_NAME, Immutable, NamedIR +from narwhals.dtypes import Unknown if TYPE_CHECKING: - from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView + from collections.abc import ItemsView, Iterator, KeysView, ValuesView from typing_extensions import TypeAlias + from narwhals._plan.contexts import ExprContext from narwhals._plan.typing import Seq from narwhals.dtypes import DType -IntoFrozenSchema: TypeAlias = "Mapping[str, DType] | FrozenSchema" +IntoFrozenSchema: TypeAlias = ( + "Mapping[str, DType] | Iterator[tuple[str, DType]] | FrozenSchema" +) """A schema to freeze, or an already frozen one. As `DType` instances (`.values()`) are hashable, we can coerce the schema @@ -33,6 +38,29 @@ class FrozenSchema(Immutable): __slots__ = ("_mapping",) _mapping: MappingProxyType[str, DType] + def project( + self, exprs: Seq[NamedIR], context: ExprContext + ) -> tuple[Seq[NamedIR], FrozenSchema]: + if context.is_select(): + return exprs, self._select(exprs) + if context.is_with_columns(): + raise NotImplementedError(context) + raise TypeError(context) + + def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: + """Return a new schema, equivalent to performing `df.select(*exprs)`. + + Arguments: + exprs: Expanded, unaliased expressions. + + Notes: + - New columns all use the `Unknown` dtype + - Any `cast` nodes are not reflected in the schema + """ + names = (e.name for e in exprs) + default = Unknown() + return freeze_schema((name, self.get(name, default)) for name in names) + @property def __immutable_hash__(self) -> int: if hasattr(self, _IMMUTABLE_HASH_NAME): @@ -96,12 +124,13 @@ def freeze_schema(mapping: IntoFrozenSchema, /) -> FrozenSchema: ... @overload def freeze_schema(**schema: DType) -> FrozenSchema: ... 
def freeze_schema( - mapping: IntoFrozenSchema | None = None, /, **schema: DType + iterable: IntoFrozenSchema | None = None, /, **schema: DType ) -> FrozenSchema: - if mapping and isinstance(mapping, FrozenSchema): - return mapping - schema_hash = tuple((mapping or schema).items()) - return _freeze_schema_cache(schema_hash) + if isinstance(iterable, FrozenSchema): + return iterable + into = iterable or schema + hashable = tuple(into.items() if isinstance(into, Mapping) else into) + return _freeze_schema_cache(hashable) @lru_cache(maxsize=100) From 8243433aa1f43f20c529b9c6eb3f7f9b43562626 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:00:44 +0100 Subject: [PATCH 256/368] revert: Drop unplanned `impl_arrow` bits --- narwhals/_plan/impl_arrow.py | 56 ++++-------------------------------- 1 file changed, 5 insertions(+), 51 deletions(-) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 6b883318f1..34abfc32cb 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -9,7 +9,7 @@ from functools import singledispatch from narwhals._plan import expr -from narwhals._plan.literal import is_literal_scalar, is_literal_series +from narwhals._plan.literal import is_literal_scalar if t.TYPE_CHECKING: import pyarrow as pa @@ -24,6 +24,8 @@ Evaluated: TypeAlias = t.Sequence[NativeSeries] +# TODO @dangotbanned: Update to operate on the output of `expr_expansion` or `expr_rewrites` +# No longer need: `Alias`, `Columns`, `Nth`, `All`, `Exclude`, `IndexColumns`, `RootSelector`, `BinarySelector`, `RenameAlias`, `KeepName` @singledispatch def evaluate(node: ExprIR, frame: NativeFrame) -> Evaluated: raise NotImplementedError(type(node)) @@ -34,6 +36,7 @@ def col(node: expr.Column, frame: NativeFrame) -> Evaluated: return [frame.column(node.name)] +# TODO @dangotbanned: Remove after updating tests @evaluate.register(expr.Columns) def cols(node: expr.Columns, frame: NativeFrame) -> Evaluated: return frame.select(list(node.names)).columns @@ -49,16 +52,7 @@ def lit( lit: t.Any = pa.scalar array = pa.repeat(lit(node.unwrap()), len(frame)) return [pa.chunked_array([array])] - elif is_literal_series(node): - ca = node.unwrap().to_native() - return [t.cast("NativeSeries", ca)] - else: - raise NotImplementedError(type(node.value)) - - -@evaluate.register(expr.Alias) -def alias(node: expr.Alias, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) + return [node.unwrap().to_native()] @evaluate.register(expr.Len) @@ -66,26 +60,6 @@ def len_(node: expr.Len, frame: NativeFrame) -> Evaluated: raise NotImplementedError(type(node)) -@evaluate.register(expr.Nth) -def nth(node: expr.Nth, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) - - -@evaluate.register(expr.IndexColumns) -def index_columns(node: expr.IndexColumns, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) - - -@evaluate.register(expr.All) -def all_(node: expr.All, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) - - -@evaluate.register(expr.Exclude) -def exclude(node: expr.Exclude, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) - - @evaluate.register(expr.Cast) def cast_(node: expr.Cast, frame: NativeFrame) -> Evaluated: raise NotImplementedError(type(node)) @@ -126,21 +100,6 @@ def window_expr(node: expr.WindowExpr, frame: NativeFrame) -> Evaluated: raise NotImplementedError(type(node)) -@evaluate.register(expr.RootSelector) 
-def selector(node: expr.RootSelector, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) - - -@evaluate.register(expr.BinarySelector) -def binary_selector(node: expr.BinarySelector, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) - - -@evaluate.register(expr.RenameAlias) -def rename_alias(node: expr.RenameAlias, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) - - @evaluate.register(expr.Sort) def sort(node: expr.Sort, frame: NativeFrame) -> Evaluated: raise NotImplementedError(type(node)) @@ -159,8 +118,3 @@ def filter_(node: expr.Filter, frame: NativeFrame) -> Evaluated: @evaluate.register(expr.AnonymousExpr) def anonymous_expr(node: expr.AnonymousExpr, frame: NativeFrame) -> Evaluated: raise NotImplementedError(type(node)) - - -@evaluate.register(expr.KeepName) -def keep_name(node: expr.KeepName, frame: NativeFrame) -> Evaluated: - raise NotImplementedError(type(node)) From 984c07b596d5ebb285fdc2e227ec1caa1da9cc5c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:02:07 +0100 Subject: [PATCH 257/368] fix(typing): `*Series` generic --- narwhals/_plan/dummy.py | 24 +++++++++++++----------- narwhals/_plan/impl_arrow.py | 3 ++- narwhals/_plan/typing.py | 9 ++++++--- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index cb3d25ff50..446033a2c8 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -4,7 +4,7 @@ import math import typing as t -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic from narwhals._plan import ( aggregation as agg, @@ -24,6 +24,7 @@ SortOptions, ) from narwhals._plan.selectors import by_name +from narwhals._plan.typing import NativeSeriesT from narwhals._plan.window import Over from narwhals._utils import Version, _hasattr_static from narwhals.dtypes import DType @@ -45,7 +46,6 @@ ClosedInterval, FillNullStrategy, IntoDType, - NativeSeries, NumericLiteral, RankMethod, RollingInterpolationMethod, @@ -764,8 +764,8 @@ def to_narwhals(self) -> DummyExpr: return DummyExprV1._from_ir(self._ir) -class DummySeries: - _compliant: DummyCompliantSeries +class DummySeries(Generic[NativeSeriesT]): + _compliant: DummyCompliantSeries[NativeSeriesT] _version: t.ClassVar[Version] = Version.MAIN @property @@ -781,24 +781,26 @@ def name(self) -> str: return self._compliant.name @classmethod - def from_native(cls, native: NativeSeries, /) -> Self: + def from_native(cls, native: NativeSeriesT, /) -> Self: obj = cls.__new__(cls) - obj._compliant = DummyCompliantSeries.from_native(native, cls._version) + obj._compliant = DummyCompliantSeries[NativeSeriesT].from_native( + native, cls._version + ) return obj - def to_native(self) -> NativeSeries: + def to_native(self) -> NativeSeriesT: return self._compliant._native def __iter__(self) -> t.Iterator[t.Any]: yield from self.to_native() -class DummySeriesV1(DummySeries): +class DummySeriesV1(DummySeries[NativeSeriesT]): _version: t.ClassVar[Version] = Version.V1 -class DummyCompliantSeries: - _native: NativeSeries +class DummyCompliantSeries(Generic[NativeSeriesT]): + _native: NativeSeriesT _name: str _version: Version @@ -815,7 +817,7 @@ def name(self) -> str: return self._name @classmethod - def from_native(cls, native: NativeSeries, /, version: Version) -> Self: + def from_native(cls, native: NativeSeriesT, /, version: Version) -> Self: name: str = "" if _hasattr_static(native, "name"): name = 
getattr(native, "name", name) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 34abfc32cb..b255d02c3d 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -44,7 +44,8 @@ def cols(node: expr.Columns, frame: NativeFrame) -> Evaluated: @evaluate.register(expr.Literal) def lit( - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries], frame: NativeFrame + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[NativeSeries]], + frame: NativeFrame, ) -> Evaluated: import pyarrow as pa diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 3f3b512b52..e191854511 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -15,7 +15,7 @@ from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.functions import RollingWindow from narwhals._plan.ranges import RangeFunction - from narwhals.typing import NonNestedDType, NonNestedLiteral + from narwhals.typing import NativeSeries, NonNestedDType, NonNestedLiteral __all__ = [ "FunctionT", @@ -66,7 +66,10 @@ NonNestedLiteralT = TypeVar( "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" ) -LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | DummySeries", default=t.Any) +NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries") +LiteralT = TypeVar( + "LiteralT", bound="NonNestedLiteral | DummySeries[t.Any]", default=t.Any +) MapIR: TypeAlias = "t.Callable[[ExprIR], ExprIR]" """A function to apply to all nodes in this tree.""" @@ -90,5 +93,5 @@ Udf: TypeAlias = "t.Callable[[t.Any], t.Any]" """Placeholder for `map_batches(function=...)`.""" -IntoExprColumn: TypeAlias = "DummyExpr | DummySeries | str" +IntoExprColumn: TypeAlias = "DummyExpr | DummySeries[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" From 1468662cbc043ed5d06033e87aa476657cf8c695 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:07:20 +0100 Subject: [PATCH 258/368] ci: Ignore dtypes import --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a5f706fddd..7ba7e68449 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,7 +81,8 @@ repos: narwhals/.*__init__.py| narwhals/.*typing\.py| narwhals/_plan/demo\.py| - narwhals/_plan/ranges\.py + narwhals/_plan/ranges\.py| + narwhals/_plan/schema\.py ) - id: pull-request-target name: don't use `pull_request_target` From 2985bd50a435fc5b83f454a06f8e76bbf38e036f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 19:43:26 +0100 Subject: [PATCH 259/368] feat: Reimpl `pyarrow`, start on `select` --- narwhals/_plan/dummy.py | 82 +++++++++++-- narwhals/_plan/impl_arrow.py | 196 ++++++++++++++++++++++++-------- narwhals/_plan/typing.py | 8 +- tests/plan/to_compliant_test.py | 38 +++---- 4 files changed, 245 insertions(+), 79 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 446033a2c8..84be362f80 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -24,7 +24,7 @@ SortOptions, ) from narwhals._plan.selectors import by_name -from narwhals._plan.typing import NativeSeriesT +from narwhals._plan.typing import NativeFrameT, NativeSeriesT from narwhals._plan.window import Over from narwhals._utils import Version, _hasattr_static from narwhals.dtypes import DType @@ -764,6 
+764,60 @@ def to_narwhals(self) -> DummyExpr: return DummyExprV1._from_ir(self._ir) +class DummyFrame(Generic[NativeFrameT, NativeSeriesT]): + _compliant: DummyCompliantFrame[NativeFrameT, NativeSeriesT] + _version: t.ClassVar[Version] = Version.MAIN + + @property + def version(self) -> Version: + return self._version + + @property + def _series(self) -> type[DummySeries]: + return DummySeries + + @classmethod + def from_native(cls, native: NativeFrameT, /) -> Self: + obj = cls.__new__(cls) + obj._compliant = DummyCompliantFrame[NativeFrameT, NativeSeriesT].from_native( + native, cls._version + ) + return obj + + def to_native(self) -> NativeFrameT: + return self._compliant.native + + def __len__(self) -> int: + return len(self._compliant) + + +class DummyCompliantFrame(Generic[NativeFrameT, NativeSeriesT]): + _native: NativeFrameT + _version: Version + + @property + def native(self) -> NativeFrameT: + return self._native + + @property + def version(self) -> Version: + return self._version + + @property + def _series(self) -> type[DummyCompliantSeries[NativeSeriesT]]: + return DummyCompliantSeries[NativeSeriesT] + + @classmethod + def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: + obj = cls.__new__(cls) + obj._native = native + obj._version = version + return obj + + def __len__(self) -> int: + raise NotImplementedError + + class DummySeries(Generic[NativeSeriesT]): _compliant: DummyCompliantSeries[NativeSeriesT] _version: t.ClassVar[Version] = Version.MAIN @@ -781,15 +835,15 @@ def name(self) -> str: return self._compliant.name @classmethod - def from_native(cls, native: NativeSeriesT, /) -> Self: + def from_native(cls, native: NativeSeriesT, name: str = "", /) -> Self: obj = cls.__new__(cls) obj._compliant = DummyCompliantSeries[NativeSeriesT].from_native( - native, cls._version + native, name, version=cls._version ) return obj def to_native(self) -> NativeSeriesT: - return self._compliant._native + return self._compliant.native def __iter__(self) -> t.Iterator[t.Any]: yield from self.to_native() @@ -804,6 +858,10 @@ class DummyCompliantSeries(Generic[NativeSeriesT]): _name: str _version: Version + @property + def native(self) -> NativeSeriesT: + return self._native + @property def version(self) -> Version: return self._version @@ -817,12 +875,20 @@ def name(self) -> str: return self._name @classmethod - def from_native(cls, native: NativeSeriesT, /, version: Version) -> Self: - name: str = "" - if _hasattr_static(native, "name"): - name = getattr(native, "name", name) + def from_native( + cls, native: NativeSeriesT, name: str = "", /, *, version: Version + ) -> Self: + name = name or ( + getattr(native, "name", name) if _hasattr_static(native, "name") else name + ) obj = cls.__new__(cls) obj._native = native obj._name = name obj._version = version return obj + + def _with_native(self, native: NativeSeriesT) -> Self: + return self.from_native(native, self.name, version=self.version) + + def alias(self, name: str) -> Self: + return self.from_native(self.native, name, version=self.version) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index b255d02c3d..5455611018 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -7,115 +7,211 @@ import typing as t from functools import singledispatch +from itertools import chain from narwhals._plan import expr +from narwhals._plan.contexts import ExprContext +from narwhals._plan.dummy import DummyCompliantFrame, DummyCompliantSeries +from narwhals._plan.expr_expansion import 
into_named_irs, prepare_projection +from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals._plan.literal import is_literal_scalar +from narwhals._utils import Version if t.TYPE_CHECKING: import pyarrow as pa - from typing_extensions import TypeAlias + from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan.common import ExprIR + from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dummy import DummySeries + from narwhals._plan.typing import IntoExpr + from narwhals.dtypes import DType + from narwhals.schema import Schema from narwhals.typing import NonNestedLiteral - NativeFrame: TypeAlias = pa.Table - NativeSeries: TypeAlias = pa.ChunkedArray[t.Any] - Evaluated: TypeAlias = t.Sequence[NativeSeries] - -# TODO @dangotbanned: Update to operate on the output of `expr_expansion` or `expr_rewrites` -# No longer need: `Alias`, `Columns`, `Nth`, `All`, `Exclude`, `IndexColumns`, `RootSelector`, `BinarySelector`, `RenameAlias`, `KeepName` +NativeFrame: TypeAlias = "pa.Table" +NativeSeries: TypeAlias = "pa.ChunkedArray[t.Any]" + + +def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: + return isinstance(obj, ArrowSeries) + + +class ArrowDataFrame(DummyCompliantFrame[NativeFrame, NativeSeries]): + @property + def _series(self) -> type[ArrowSeries]: + return ArrowSeries + + @property + def columns(self) -> list[str]: + return self.native.column_names + + @property + def schema(self) -> dict[str, DType]: + from narwhals._arrow.utils import native_to_narwhals_dtype + + schema = self.native.schema + return { + name: native_to_narwhals_dtype(dtype, self._version) + for name, dtype in zip(schema.names, schema.types) + } + + def __len__(self) -> int: + return len(self.native) + + @classmethod + def from_series( + cls, series: t.Iterable[ArrowSeries] | ArrowSeries, *more_series: ArrowSeries + ) -> Self: + lhs = (series,) if is_series(series) else series + it = chain(lhs, more_series) if more_series else lhs + return cls.from_dict({s.name: s.native for s in it}) + + @classmethod + def from_dict( + cls, + data: t.Mapping[str, t.Any], + /, + *, + schema: t.Mapping[str, DType] | Schema | None = None, + ) -> Self: + import pyarrow as pa + + from narwhals.schema import Schema + + pa_schema = Schema(schema).to_arrow() if schema is not None else schema + native = pa.Table.from_pydict(data, schema=pa_schema) + return cls.from_native(native, version=Version.MAIN) + + def iter_columns(self) -> t.Iterator[ArrowSeries]: + for name, series in zip(self.columns, self.native.itercolumns()): + yield ArrowSeries.from_native(series, name, version=self.version) + + @t.overload + def to_dict(self, *, as_series: t.Literal[True]) -> dict[str, ArrowSeries]: ... + + @t.overload + def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... 
+ + def to_dict( + self, *, as_series: bool + ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: + it = self.iter_columns() + if as_series: + return {ser.name: ser for ser in it} + return {ser.name: ser.to_list() for ser in it} + + def _evaluate_irs( + self, nodes: t.Iterable[NamedIR[ExprIR]], / + ) -> t.Iterator[ArrowSeries]: + for node in nodes: + yield self._series.from_native( + _evaluate_inner(node.expr, self), node.name, version=self.version + ) + + def select( + self, *exprs: IntoExpr | t.Iterable[IntoExpr], **named_exprs: t.Any + ) -> Self: + irs, schema_frozen, output_names = prepare_projection( + parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema + ) + named_irs = into_named_irs(irs, output_names) + named_irs, schema_projected = schema_frozen.project(named_irs, ExprContext.SELECT) + return self.from_series(self._evaluate_irs(named_irs)) + + +class ArrowSeries(DummyCompliantSeries[NativeSeries]): + def to_list(self) -> list[t.Any]: + return self.native.to_pylist() + + +# NOTE: Should mean we produce 1x CompliantSeries for the entire expression +# Multi-output have already been separated +# No intermediate CompliantSeries need to be created, just assign a name to the final one @singledispatch -def evaluate(node: ExprIR, frame: NativeFrame) -> Evaluated: +def _evaluate_inner(node: ExprIR, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.Column) -def col(node: expr.Column, frame: NativeFrame) -> Evaluated: - return [frame.column(node.name)] - - -# TODO @dangotbanned: Remove after updating tests -@evaluate.register(expr.Columns) -def cols(node: expr.Columns, frame: NativeFrame) -> Evaluated: - return frame.select(list(node.names)).columns +@_evaluate_inner.register(expr.Column) +def col(node: expr.Column, frame: ArrowDataFrame) -> NativeSeries: + return frame.native.column(node.name) -@evaluate.register(expr.Literal) -def lit( +@_evaluate_inner.register(expr.Literal) +def lit_( node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[NativeSeries]], - frame: NativeFrame, -) -> Evaluated: + frame: ArrowDataFrame, +) -> NativeSeries: import pyarrow as pa if is_literal_scalar(node): lit: t.Any = pa.scalar array = pa.repeat(lit(node.unwrap()), len(frame)) - return [pa.chunked_array([array])] - return [node.unwrap().to_native()] + return pa.chunked_array([array]) + return node.unwrap().to_native() -@evaluate.register(expr.Len) -def len_(node: expr.Len, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.Len) +def len_(node: expr.Len, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.Cast) -def cast_(node: expr.Cast, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.Cast) +def cast_(node: expr.Cast, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.Ternary) -def ternary(node: expr.Ternary, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.Ternary) +def ternary(node: expr.Ternary, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.Agg) -def agg(node: expr.Agg, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.Agg) +def agg(node: expr.Agg, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.OrderableAgg) -def orderable_agg(node: expr.OrderableAgg, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.OrderableAgg) +def orderable_agg(node: 
expr.OrderableAgg, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.BinaryExpr) -def binary_expr(node: expr.BinaryExpr, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.BinaryExpr) +def binary_expr(node: expr.BinaryExpr, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.FunctionExpr) -def function_expr(node: expr.FunctionExpr[t.Any], frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.FunctionExpr) +def function_expr(node: expr.FunctionExpr[t.Any], frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.RollingExpr) -def rolling_expr(node: expr.RollingExpr[t.Any], frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.RollingExpr) +def rolling_expr(node: expr.RollingExpr[t.Any], frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.WindowExpr) -def window_expr(node: expr.WindowExpr, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.WindowExpr) +def window_expr(node: expr.WindowExpr, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.Sort) -def sort(node: expr.Sort, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.Sort) +def sort(node: expr.Sort, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.SortBy) -def sort_by(node: expr.SortBy, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.SortBy) +def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.Filter) -def filter_(node: expr.Filter, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.Filter) +def filter_(node: expr.Filter, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@evaluate.register(expr.AnonymousExpr) -def anonymous_expr(node: expr.AnonymousExpr, frame: NativeFrame) -> Evaluated: +@_evaluate_inner.register(expr.AnonymousExpr) +def anonymous_expr(node: expr.AnonymousExpr, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index e191854511..e7f5fe03bc 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -15,7 +15,12 @@ from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.functions import RollingWindow from narwhals._plan.ranges import RangeFunction - from narwhals.typing import NativeSeries, NonNestedDType, NonNestedLiteral + from narwhals.typing import ( + NativeFrame, + NativeSeries, + NonNestedDType, + NonNestedLiteral, + ) __all__ = [ "FunctionT", @@ -67,6 +72,7 @@ "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" ) NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries") +NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame", default="NativeFrame") LiteralT = TypeVar( "LiteralT", bound="NonNestedLiteral | DummySeries[t.Any]", default=t.Any ) diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index a08dc0f5df..69dc1857dd 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -7,7 +7,7 @@ import narwhals as nw import narwhals._plan.demo as nwd from narwhals._plan.common import is_expr -from narwhals._plan.impl_arrow import evaluate as evaluate_pyarrow +from 
narwhals._plan.impl_arrow import ArrowDataFrame from narwhals.utils import Version from tests.namespace_test import backends @@ -16,6 +16,11 @@ from narwhals._plan.dummy import DummyExpr +@pytest.fixture +def data_small() -> dict[str, Any]: + return {"a": ["A", "B", "A"], "b": [1, 2, 3], "c": [9, 2, 4], "d": [8, 7, 8]} + + def _ids_ir(expr: DummyExpr | Any) -> str: if is_expr(expr): return repr(expr._ir) @@ -44,28 +49,21 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: @pytest.mark.parametrize( ("expr", "expected"), [ - (nwd.col("a"), ["A", "B", "A"]), - (nwd.col("a", "b"), [["A", "B", "A"], [1, 2, 3]]), - (nwd.lit(1), [1, 1, 1]), - (nwd.lit(2.0), [2.0, 2.0, 2.0]), - (nwd.lit(None, nw.String()), [None, None, None]), + (nwd.col("a"), {"a": ["A", "B", "A"]}), + (nwd.col("a", "b"), {"a": ["A", "B", "A"], "b": [1, 2, 3]}), + (nwd.lit(1), {"literal": [1, 1, 1]}), + (nwd.lit(2.0), {"literal": [2.0, 2.0, 2.0]}), + (nwd.lit(None, nw.String()), {"literal": [None, None, None]}), ], ids=_ids_ir, ) -def test_evaluate_pyarrow(expr: DummyExpr, expected: Any) -> None: +def test_select( + expr: DummyExpr, expected: dict[str, Any], data_small: dict[str, Any] +) -> None: pytest.importorskip("pyarrow") import pyarrow as pa - data: dict[str, Any] = { - "a": ["A", "B", "A"], - "b": [1, 2, 3], - "c": [9, 2, 4], - "d": [8, 7, 8], - } - frame = pa.table(data) - result = evaluate_pyarrow(expr._ir, frame) - if len(result) == 1: - assert result[0].to_pylist() == expected - else: - results = [col.to_pylist() for col in result] - assert results == expected + frame = pa.table(data_small) + df = ArrowDataFrame.from_native(frame, Version.MAIN) + result = df.select(expr).to_dict(as_series=False) + assert result == expected From a85fc7e86d04344036e32d1818ac98dd331e6f85 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 20:31:08 +0100 Subject: [PATCH 260/368] feat(pyarrow): Impl `Cast`, `Sort`, `Filter`, `Len` --- narwhals/_plan/impl_arrow.py | 61 +++++++++++++++++++++++------------- narwhals/_plan/options.py | 10 ++++++ 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 5455611018..7ba3e95825 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -5,6 +5,7 @@ from __future__ import annotations +# ruff: noqa: ARG001 import typing as t from functools import singledispatch from itertools import chain @@ -21,12 +22,13 @@ import pyarrow as pa from typing_extensions import Self, TypeAlias, TypeIs + from narwhals._arrow.typing import ScalarAny from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dummy import DummySeries from narwhals._plan.typing import IntoExpr from narwhals.dtypes import DType from narwhals.schema import Schema - from narwhals.typing import NonNestedLiteral + from narwhals.typing import NonNestedLiteral, PythonLiteral NativeFrame: TypeAlias = "pa.Table" @@ -138,28 +140,53 @@ def col(node: expr.Column, frame: ArrowDataFrame) -> NativeSeries: return frame.native.column(node.name) +def _lit_native(value: PythonLiteral | ScalarAny, frame: ArrowDataFrame) -> NativeSeries: + """Will need to support returning a native scalar as well.""" + import pyarrow as pa + + from narwhals._arrow.utils import chunked_array + + lit: t.Any = pa.scalar + scalar: t.Any = value if isinstance(value, pa.Scalar) else lit(value) + array = pa.repeat(scalar, len(frame)) + return chunked_array(array) + + 
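# (Editorial sketch, not part of the patch.) `_lit_native` is the naive
# broadcasting helper used by `Literal` and `Len` below (and, in later
# commits, by the aggregations): a Python literal or `pa.Scalar` is repeated
# to the frame's height and wrapped as a ChunkedArray. Roughly equivalent
# pyarrow calls, assuming a 3-row table:
#
#     import pyarrow as pa
#
#     tbl = pa.table({"a": [1, 2, 3]})
#     repeated = pa.repeat(pa.scalar(1), len(tbl))    # -> [1, 1, 1]
#     column = pa.chunked_array([repeated])           # frame-length column
#
# Keeping scalar results frame-length lets `select` hand every evaluated
# expression straight to `ArrowDataFrame.from_series` without a separate
# broadcast step.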
@_evaluate_inner.register(expr.Literal) def lit_( node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[NativeSeries]], frame: ArrowDataFrame, ) -> NativeSeries: - import pyarrow as pa - if is_literal_scalar(node): - lit: t.Any = pa.scalar - array = pa.repeat(lit(node.unwrap()), len(frame)) - return pa.chunked_array([array]) + return _lit_native(node.unwrap(), frame) return node.unwrap().to_native() -@_evaluate_inner.register(expr.Len) -def len_(node: expr.Len, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - @_evaluate_inner.register(expr.Cast) def cast_(node: expr.Cast, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) + from narwhals._arrow.utils import narwhals_to_native_dtype + + data_type = narwhals_to_native_dtype(node.dtype, frame.version) + return _evaluate_inner(node.expr, frame).cast(data_type) + + +@_evaluate_inner.register(expr.Sort) +def sort(node: expr.Sort, frame: ArrowDataFrame) -> NativeSeries: + import pyarrow.compute as pc + + native = _evaluate_inner(node.expr, frame) + sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) + return native.take(sorted_indices) + + +@_evaluate_inner.register(expr.Filter) +def filter_(node: expr.Filter, frame: ArrowDataFrame) -> NativeSeries: + return _evaluate_inner(node.expr, frame).filter(_evaluate_inner(node.by, frame)) + + +@_evaluate_inner.register(expr.Len) +def len_(node: expr.Len, frame: ArrowDataFrame) -> NativeSeries: + return _lit_native(len(frame), frame) @_evaluate_inner.register(expr.Ternary) @@ -197,21 +224,11 @@ def window_expr(node: expr.WindowExpr, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@_evaluate_inner.register(expr.Sort) -def sort(node: expr.Sort, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - @_evaluate_inner.register(expr.SortBy) def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@_evaluate_inner.register(expr.Filter) -def filter_(node: expr.Filter, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - @_evaluate_inner.register(expr.AnonymousExpr) def anonymous_expr(node: expr.AnonymousExpr, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 922b1cfa63..1164358458 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -6,6 +6,8 @@ from narwhals._plan.common import Immutable if TYPE_CHECKING: + import pyarrow.compute as pc + from narwhals._plan.typing import Seq from narwhals.typing import RankMethod @@ -147,6 +149,14 @@ def __repr__(self) -> str: def default() -> SortOptions: return SortOptions(descending=False, nulls_last=False) + def to_arrow(self) -> pc.ArraySortOptions: + import pyarrow.compute as pc + + return pc.ArraySortOptions( + order=("descending" if self.descending else "ascending"), + null_placement=("at_end" if self.nulls_last else "at_start"), + ) + class SortMultipleOptions(Immutable): __slots__ = ("descending", "nulls_last") From faa91ec5a0559f9efae4e73f03a90d9107e46526 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 2 Jul 2025 21:52:42 +0100 Subject: [PATCH 261/368] feat(pyarrow): Impl `SortBy` Similar to (#2547), but had less to work with --- narwhals/_plan/dummy.py | 7 +++++-- narwhals/_plan/impl_arrow.py | 33 ++++++++++++++++++++++++++------- 2 files changed, 31 
insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 84be362f80..e56125708b 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -28,7 +28,7 @@ from narwhals._plan.window import Over from narwhals._utils import Version, _hasattr_static from narwhals.dtypes import DType -from narwhals.exceptions import ComputeError +from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: from typing_extensions import Never, Self @@ -172,6 +172,9 @@ def sort_by( nulls_last: bool | t.Iterable[bool] = False, ) -> Self: sort_by = parse.parse_into_seq_of_expr_ir(by, *more_by) + if length_changing := next((e for e in sort_by if e.is_scalar), None): + msg = f"All expressions passed to `sort_by` must preserve length, but got:\n{length_changing!r}" + raise InvalidOperationError(msg) desc = (descending,) if isinstance(descending, bool) else tuple(descending) nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) options = SortMultipleOptions(descending=desc, nulls_last=nulls) @@ -876,7 +879,7 @@ def name(self) -> str: @classmethod def from_native( - cls, native: NativeSeriesT, name: str = "", /, *, version: Version + cls, native: NativeSeriesT, name: str = "", /, *, version: Version = Version.MAIN ) -> Self: name = name or ( getattr(native, "name", name) if _hasattr_static(native, "name") else name diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 7ba3e95825..8e29f4b89d 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -8,7 +8,7 @@ # ruff: noqa: ARG001 import typing as t from functools import singledispatch -from itertools import chain +from itertools import chain, repeat from narwhals._plan import expr from narwhals._plan.contexts import ExprContext @@ -22,7 +22,7 @@ import pyarrow as pa from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._arrow.typing import ScalarAny + from narwhals._arrow.typing import Order, ScalarAny # type: ignore[attr-defined] from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dummy import DummySeries from narwhals._plan.typing import IntoExpr @@ -179,6 +179,30 @@ def sort(node: expr.Sort, frame: ArrowDataFrame) -> NativeSeries: return native.take(sorted_indices) +@_evaluate_inner.register(expr.SortBy) +def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> NativeSeries: + opts = node.options + if len(opts.nulls_last) != 1: + msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {opts.nulls_last!r}" + raise NotImplementedError(msg) + placement = "at_end" if opts.nulls_last[0] else "at_start" + from_native = ArrowSeries.from_native + by = ( + from_native(_evaluate_inner(e, frame), str(idx)) for idx, e in enumerate(node.by) + ) + df = frame.from_series(from_native(_evaluate_inner(node.expr, frame), ""), *by) + names = df.columns[1:] + if len(opts.descending) == 1: + descending: t.Iterable[bool] = repeat(opts.descending[0], len(names)) + else: + descending = opts.descending + sorting: list[tuple[str, Order]] = [ + (key, "descending" if desc else "ascending") + for key, desc in zip(names, descending) + ] + return df.native.sort_by(sorting, null_placement=placement).column(0) + + @_evaluate_inner.register(expr.Filter) def filter_(node: expr.Filter, frame: ArrowDataFrame) -> NativeSeries: return _evaluate_inner(node.expr, frame).filter(_evaluate_inner(node.by, frame)) @@ -224,11 +248,6 @@ def window_expr(node: expr.WindowExpr, frame: ArrowDataFrame) -> NativeSeries: raise 
NotImplementedError(type(node)) -@_evaluate_inner.register(expr.SortBy) -def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - @_evaluate_inner.register(expr.AnonymousExpr) def anonymous_expr(node: expr.AnonymousExpr, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) From 925d601024b728e67549f73b2e25dd78b4224887 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 3 Jul 2025 16:17:28 +0100 Subject: [PATCH 262/368] feat(pyarrow): Impl `First`, `Last` --- narwhals/_plan/impl_arrow.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 8e29f4b89d..d3b3b52ce0 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -10,7 +10,7 @@ from functools import singledispatch from itertools import chain, repeat -from narwhals._plan import expr +from narwhals._plan import aggregation, expr from narwhals._plan.contexts import ExprContext from narwhals._plan.dummy import DummyCompliantFrame, DummyCompliantSeries from narwhals._plan.expr_expansion import into_named_irs, prepare_projection @@ -218,9 +218,17 @@ def ternary(node: expr.Ternary, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@_evaluate_inner.register(expr.Agg) -def agg(node: expr.Agg, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) +@_evaluate_inner.register(aggregation.Last) +@_evaluate_inner.register(aggregation.First) +def first_last( + node: aggregation.First | aggregation.Last, frame: ArrowDataFrame +) -> NativeSeries: + native = _evaluate_inner(node.expr, frame) + if height := len(native): + result = native[height - 1 if isinstance(node, aggregation.Last) else 0] + else: + result = None + return _lit_native(result, frame) @_evaluate_inner.register(expr.OrderableAgg) @@ -228,6 +236,11 @@ def orderable_agg(node: expr.OrderableAgg, frame: ArrowDataFrame) -> NativeSerie raise NotImplementedError(type(node)) +@_evaluate_inner.register(expr.Agg) +def agg(node: expr.Agg, frame: ArrowDataFrame) -> NativeSeries: + raise NotImplementedError(type(node)) + + @_evaluate_inner.register(expr.BinaryExpr) def binary_expr(node: expr.BinaryExpr, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) From c823d78a2267b2d2c81cd7796255f5ff2bd3ec94 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 3 Jul 2025 16:22:19 +0100 Subject: [PATCH 263/368] feat(pyarrow): Impl all aggregations --- narwhals/_plan/impl_arrow.py | 78 +++++++++++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index d3b3b52ce0..a1a9ccbaa2 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -10,7 +10,7 @@ from functools import singledispatch from itertools import chain, repeat -from narwhals._plan import aggregation, expr +from narwhals._plan import aggregation as agg, expr from narwhals._plan.contexts import ExprContext from narwhals._plan.dummy import DummyCompliantFrame, DummyCompliantSeries from narwhals._plan.expr_expansion import into_named_irs, prepare_projection @@ -34,6 +34,8 @@ NativeFrame: TypeAlias = "pa.Table" NativeSeries: TypeAlias = "pa.ChunkedArray[t.Any]" +UnaryFn: TypeAlias = "t.Callable[[NativeSeries], ScalarAny]" + def is_series(obj: t.Any) -> 
TypeIs[ArrowSeries]: return isinstance(obj, ArrowSeries) @@ -218,26 +220,80 @@ def ternary(node: expr.Ternary, frame: ArrowDataFrame) -> NativeSeries: raise NotImplementedError(type(node)) -@_evaluate_inner.register(aggregation.Last) -@_evaluate_inner.register(aggregation.First) -def first_last( - node: aggregation.First | aggregation.Last, frame: ArrowDataFrame -) -> NativeSeries: +@_evaluate_inner.register(agg.Last) +@_evaluate_inner.register(agg.First) +def agg_first_last(node: agg.First | agg.Last, frame: ArrowDataFrame) -> NativeSeries: native = _evaluate_inner(node.expr, frame) if height := len(native): - result = native[height - 1 if isinstance(node, aggregation.Last) else 0] + result = native[height - 1 if isinstance(node, agg.Last) else 0] else: result = None return _lit_native(result, frame) -@_evaluate_inner.register(expr.OrderableAgg) -def orderable_agg(node: expr.OrderableAgg, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) +@_evaluate_inner.register(agg.ArgMax) +@_evaluate_inner.register(agg.ArgMin) +def agg_arg_min_max(node: agg.ArgMin | agg.ArgMax, frame: ArrowDataFrame) -> NativeSeries: + import pyarrow.compute as pc + + native = _evaluate_inner(node.expr, frame) + fn = pc.min if isinstance(node, agg.ArgMin) else pc.max + result = pc.index(native, fn(native)) + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.Sum) +def agg_sum(node: agg.Sum, frame: ArrowDataFrame) -> NativeSeries: + import pyarrow.compute as pc + + result = pc.sum(_evaluate_inner(node.expr, frame), min_count=0) + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.NUnique) +def agg_n_unique(node: agg.NUnique, frame: ArrowDataFrame) -> NativeSeries: + import pyarrow.compute as pc + + result = pc.count(_evaluate_inner(node.expr, frame).unique(), mode="all") + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.Var) +@_evaluate_inner.register(agg.Std) +def agg_std_var(node: agg.Std | agg.Var, frame: ArrowDataFrame) -> NativeSeries: + import pyarrow.compute as pc + + fn = pc.stddev if isinstance(node, agg.Std) else pc.variance + result = fn(_evaluate_inner(node.expr, frame), ddof=node.ddof) + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.Quantile) +def agg_quantile(node: agg.Quantile, frame: ArrowDataFrame) -> NativeSeries: + import pyarrow.compute as pc + + result = pc.quantile( + _evaluate_inner(node.expr, frame), + q=node.quantile, + interpolation=node.interpolation, + )[0] + return _lit_native(result, frame) @_evaluate_inner.register(expr.Agg) -def agg(node: expr.Agg, frame: ArrowDataFrame) -> NativeSeries: +def agg_expr(node: expr.Agg, frame: ArrowDataFrame) -> NativeSeries: + import pyarrow.compute as pc + + mapping: dict[type[expr.Agg], UnaryFn] = { + agg.Count: pc.count, + agg.Max: pc.max, + agg.Mean: pc.mean, + agg.Median: pc.approximate_median, + agg.Min: pc.min, + } + if fn := mapping.get(type(node)): + result = fn(_evaluate_inner(node.expr, frame)) + return _lit_native(result, frame) raise NotImplementedError(type(node)) From b069348496a0e67503fa28fbd0efcfe685105d75 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 3 Jul 2025 16:30:30 +0100 Subject: [PATCH 264/368] docs: Note on broacasting --- narwhals/_plan/impl_arrow.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index a1a9ccbaa2..5e2e171f5e 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py 
@@ -142,6 +142,9 @@ def col(node: expr.Column, frame: ArrowDataFrame) -> NativeSeries: return frame.native.column(node.name) +# NOTE: Using a very naïve approach to broadcasting **for now** +# - We already have something that works in main +# - Another approach would be to keep everything wrapped (or aggregated into) `expr.Literal` def _lit_native(value: PythonLiteral | ScalarAny, frame: ArrowDataFrame) -> NativeSeries: """Will need to support returning a native scalar as well.""" import pyarrow as pa From a4d90d88a16dc6d98e55a33879e548de8c58e215 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 3 Jul 2025 21:36:58 +0100 Subject: [PATCH 265/368] feat(DRAFT): Prepare new broadcasting layer --- narwhals/_plan/dummy.py | 3 + narwhals/_plan/impl_arrow.py | 134 ++++++++++++++++++++++++++++++++++- 2 files changed, 134 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index e56125708b..133fab566a 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -895,3 +895,6 @@ def _with_native(self, native: NativeSeriesT) -> Self: def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) + + def __len__(self) -> int: + return len(self.native) diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py index 5e2e171f5e..c62979e32b 100644 --- a/narwhals/_plan/impl_arrow.py +++ b/narwhals/_plan/impl_arrow.py @@ -5,36 +5,46 @@ from __future__ import annotations -# ruff: noqa: ARG001 import typing as t +from collections.abc import Sequence, Sized + +# ruff: noqa: ARG001 from functools import singledispatch from itertools import chain, repeat from narwhals._plan import aggregation as agg, expr +from narwhals._plan.common import into_dtype from narwhals._plan.contexts import ExprContext from narwhals._plan.dummy import DummyCompliantFrame, DummyCompliantSeries from narwhals._plan.expr_expansion import into_named_irs, prepare_projection from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals._plan.literal import is_literal_scalar +from narwhals._typing_compat import TypeVar from narwhals._utils import Version +from narwhals.exceptions import InvalidOperationError, ShapeError if t.TYPE_CHECKING: import pyarrow as pa from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._arrow.typing import Order, ScalarAny # type: ignore[attr-defined] + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + Incomplete, + Order, + ScalarAny, + ) from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dummy import DummySeries from narwhals._plan.typing import IntoExpr from narwhals.dtypes import DType from narwhals.schema import Schema - from narwhals.typing import NonNestedLiteral, PythonLiteral + from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral NativeFrame: TypeAlias = "pa.Table" NativeSeries: TypeAlias = "pa.ChunkedArray[t.Any]" UnaryFn: TypeAlias = "t.Callable[[NativeSeries], ScalarAny]" +SeriesT = TypeVar("SeriesT") def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: @@ -129,6 +139,124 @@ def to_list(self) -> list[t.Any]: return self.native.to_pylist() +class SupportsBroadcast(Sized, t.Protocol[SeriesT]): + """Minimal broadcasting for `Expr` results.""" + + @classmethod + def from_series(cls, series: SeriesT, /) -> Self: ... + def to_series(self) -> SeriesT: ... + def broadcast(self, length: int, /) -> SeriesT: ... 
+ @classmethod + def align(cls, *exprs: SupportsBroadcast[SeriesT]) -> Sequence[SeriesT]: + lengths = [len(e) for e in exprs] + max_length = max(lengths) + fast_path = all(len_ == max_length for len_ in lengths) + if fast_path: + return [e.to_series() for e in exprs] + return [e.broadcast(max_length) for e in exprs] + + +# NOTE: General expression result +# Mostly elementwise +class ArrowExpr(SupportsBroadcast[ArrowSeries]): + _series: ArrowSeries + + @classmethod + def from_series(cls, series: ArrowSeries) -> Self: + obj = cls.__new__(cls) + obj._series = series + return obj + + def to_series(self) -> ArrowSeries: + return self._series + + def __len__(self) -> int: + return len(self._series) + + def broadcast(self, length: int, /) -> ArrowSeries: + if (actual_len := len(self)) != length: + msg = f"Expected object of length {length}, got {actual_len}." + raise ShapeError(msg) + return self._series + + +# NOTE: Aggregation result or scalar +# Should handle broadcasting, without exposing it +class ArrowLiteral(SupportsBroadcast[ArrowSeries]): + _native_scalar: ScalarAny + _name: str + + @property + def name(self) -> str: + return self._name + + def __len__(self) -> int: + return 1 + + def broadcast(self, length: int, /) -> ArrowSeries: + import pyarrow as pa + + from narwhals._arrow.utils import chunked_array + + if length == 1: + chunked = chunked_array([[self._native_scalar]]) + else: + # NOTE: Same issue as `pa.scalar` overlapping overloads + # https://github.com/zen-xu/pyarrow-stubs/pull/209 + pa_repeat: Incomplete = pa.repeat + arr = pa_repeat(self._native_scalar, length) + chunked = chunked_array(arr) + return ArrowSeries.from_native(chunked, self.name) + + @classmethod + def from_series(cls, series: ArrowSeries) -> Self: + if len(series) == 1: + return cls.from_scalar(series.native[0], series.name) + elif len(series) == 0: + return cls.from_python(None, series.name, dtype=series.dtype) + else: + msg = f"Too long {len(series)!r}" + raise InvalidOperationError(msg) + + def to_series(self) -> ArrowSeries: + return self.broadcast(1) + + @classmethod + def from_python( + cls, + value: PythonLiteral, + name: str = "literal", + /, + *, + dtype: IntoDType | None = None, + ) -> Self: + import pyarrow as pa + + from narwhals._arrow.utils import narwhals_to_native_dtype + + version = Version.MAIN + dtype_pa: pa.DataType | None = None + if dtype: + dtype = into_dtype(dtype) + if not isinstance(dtype, version.dtypes.Unknown): + dtype_pa = narwhals_to_native_dtype(dtype, version) + # NOTE: PR that fixed this was closed + # https://github.com/zen-xu/pyarrow-stubs/pull/208 + lit: Incomplete = pa.scalar + return cls.from_scalar(lit(value, dtype_pa), name) + + @classmethod + def from_scalar(cls, scalar: ScalarAny, name: str = "literal", /) -> Self: + obj = cls.__new__(cls) + obj._native_scalar = scalar + obj._name = name + return obj + + @classmethod + def from_ir(cls, value: expr.Literal[NonNestedLiteral], /) -> Self: + return cls.from_python(value.unwrap(), value.name) + + # NOTE: Should mean we produce 1x CompliantSeries for the entire expression # Multi-output have already been separated # No intermediate CompliantSeries need to be created, just assign a name to the final one From 0ba05ebdecbc1d3e594c1e09518e601d1183712f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 4 Jul 2025 18:13:40 +0100 Subject: [PATCH 266/368] refactor: Move `flatten_hash_safe` --- narwhals/_plan/common.py | 30 +++++++++++++++++++++++++----- 
narwhals/_plan/selectors.py | 21 ++++----------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 955494780e..0b35a9411a 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime as dt +from collections.abc import Iterable from decimal import Decimal from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload @@ -20,13 +21,18 @@ from narwhals.utils import Version if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Iterator from typing import Any, Callable, Literal from typing_extensions import Never, Self, TypeIs, dataclass_transform from narwhals._plan import expr - from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries + from narwhals._plan.dummy import ( + DummyCompliantSeries, + DummyExpr, + DummySelector, + DummySeries, + ) from narwhals._plan.expr import Agg, BinaryExpr, FunctionExpr, WindowExpr from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions @@ -425,10 +431,12 @@ def is_series(obj: Any) -> TypeIs[DummySeries]: return isinstance(obj, DummySeries) -def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | DummySeries]: - from narwhals._plan.dummy import DummySeries +def is_iterable_reject( + obj: Any, +) -> TypeIs[str | bytes | DummySeries | DummyCompliantSeries]: + from narwhals._plan.dummy import DummyCompliantSeries, DummySeries - return isinstance(obj, (str, bytes, DummySeries)) + return isinstance(obj, (str, bytes, DummySeries, DummyCompliantSeries)) def is_regex_projection(name: str) -> bool: @@ -519,3 +527,15 @@ def map_ir( result = result.map_ir(fn) return result return origin.map_ir(function) + + +def flatten_hash_safe(iterable: Iterable[T | Iterable[T]], /) -> Iterator[T]: + """Fully unwrap all levels of nesting. + + Aiming to reduce the chances of passing an unhashable argument. 
+ """ + for element in iterable: + if isinstance(element, Iterable) and not is_iterable_reject(element): + yield from flatten_hash_safe(element) + else: + yield element # type: ignore[misc] diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index af3524180f..8124cb0c9e 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -7,14 +7,13 @@ from __future__ import annotations import re -from collections.abc import Iterable from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, is_iterable_reject +from narwhals._plan.common import Immutable, flatten_hash_safe from narwhals._utils import Version, _parse_time_unit_and_time_zone if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterable, Iterator from datetime import timezone from typing import TypeVar @@ -55,7 +54,7 @@ class ByDType(Selector): def from_dtypes( *dtypes: DType | type[DType] | Iterable[DType | type[DType]], ) -> ByDType: - return ByDType(dtypes=frozenset(_flatten_hash_safe(dtypes))) + return ByDType(dtypes=frozenset(flatten_hash_safe(dtypes))) def __repr__(self) -> str: els = ", ".join( @@ -131,7 +130,7 @@ def from_string(pattern: str, /) -> Matches: @staticmethod def from_names(*names: str | Iterable[str]) -> Matches: """Implements `cs.by_name` to support `__r__` with column selections.""" - it: Iterator[str] = _flatten_hash_safe(names) + it: Iterator[str] = flatten_hash_safe(names) pattern = f"^({'|'.join(re.escape(name) for name in it)})$" return Matches.from_string(pattern) @@ -201,15 +200,3 @@ def numeric() -> DummySelector: def string() -> DummySelector: return String().to_selector().to_narwhals() - - -def _flatten_hash_safe(iterable: Iterable[T | Iterable[T]], /) -> Iterator[T]: - """Fully unwrap all levels of nesting. - - Aiming to reduce the chances of passing an unhashable argument. 
- """ - for element in iterable: - if isinstance(element, Iterable) and not is_iterable_reject(element): - yield from _flatten_hash_safe(element) - else: - yield element # type: ignore[misc] From 3a4776f59bed83059ad17ca9596e1dac83464498 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 4 Jul 2025 20:59:30 +0100 Subject: [PATCH 267/368] =?UTF-8?q?more=20giant=20refactors=20=F0=9F=98=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - split up `impl_arrow` - iron out some broadcasting issues --- narwhals/_plan/arrow/__init__.py | 0 narwhals/_plan/arrow/dataframe.py | 117 ++++++++ narwhals/_plan/arrow/evaluate.py | 247 ++++++++++++++++ narwhals/_plan/arrow/expr.py | 128 +++++++++ narwhals/_plan/arrow/series.py | 13 + narwhals/_plan/impl_arrow.py | 453 ------------------------------ narwhals/_plan/protocols.py | 69 +++++ tests/plan/to_compliant_test.py | 3 +- 8 files changed, 576 insertions(+), 454 deletions(-) create mode 100644 narwhals/_plan/arrow/__init__.py create mode 100644 narwhals/_plan/arrow/dataframe.py create mode 100644 narwhals/_plan/arrow/evaluate.py create mode 100644 narwhals/_plan/arrow/expr.py create mode 100644 narwhals/_plan/arrow/series.py delete mode 100644 narwhals/_plan/impl_arrow.py create mode 100644 narwhals/_plan/protocols.py diff --git a/narwhals/_plan/arrow/__init__.py b/narwhals/_plan/arrow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py new file mode 100644 index 0000000000..e44440593e --- /dev/null +++ b/narwhals/_plan/arrow/dataframe.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import typing as t +from itertools import chain + +import pyarrow as pa + +from narwhals._arrow.utils import native_to_narwhals_dtype +from narwhals._plan.arrow.expr import ArrowExpr, ArrowLiteral +from narwhals._plan.arrow.series import ArrowSeries +from narwhals._plan.contexts import ExprContext +from narwhals._plan.dummy import DummyCompliantFrame +from narwhals._plan.expr_expansion import into_named_irs, prepare_projection +from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir +from narwhals._utils import Version + +if t.TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + from typing_extensions import Self, TypeAlias, TypeIs + + from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny + from narwhals._plan.common import ExprIR, NamedIR + from narwhals._plan.typing import IntoExpr + from narwhals.dtypes import DType + from narwhals.schema import Schema + + +UnaryFn: TypeAlias = "t.Callable[[ChunkedArrayAny], ScalarAny]" + + +def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: + return isinstance(obj, ArrowSeries) + + +class ArrowDataFrame(DummyCompliantFrame["pa.Table", "ChunkedArrayAny"]): + @property + def _series(self) -> type[ArrowSeries]: + return ArrowSeries + + @property + def _expr(self) -> type[ArrowExpr]: + return ArrowExpr + + @property + def _lit(self) -> type[ArrowLiteral]: + return ArrowLiteral + + @property + def columns(self) -> list[str]: + return self.native.column_names + + @property + def schema(self) -> dict[str, DType]: + schema = self.native.schema + return { + name: native_to_narwhals_dtype(dtype, self._version) + for name, dtype in zip(schema.names, schema.types) + } + + def __len__(self) -> int: + return len(self.native) + + @classmethod + def from_series( + cls, series: t.Iterable[ArrowSeries] | ArrowSeries, 
*more_series: ArrowSeries + ) -> Self: + lhs = (series,) if is_series(series) else series + it = chain(lhs, more_series) if more_series else lhs + return cls.from_dict({s.name: s.native for s in it}) + + @classmethod + def from_dict( + cls, + data: t.Mapping[str, t.Any], + /, + *, + schema: t.Mapping[str, DType] | Schema | None = None, + ) -> Self: + from narwhals.schema import Schema + + pa_schema = Schema(schema).to_arrow() if schema is not None else schema + native = pa.Table.from_pydict(data, schema=pa_schema) + return cls.from_native(native, version=Version.MAIN) + + def iter_columns(self) -> t.Iterator[ArrowSeries]: + for name, series in zip(self.columns, self.native.itercolumns()): + yield ArrowSeries.from_native(series, name, version=self.version) + + @t.overload + def to_dict(self, *, as_series: t.Literal[True]) -> dict[str, ArrowSeries]: ... + + @t.overload + def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... + + def to_dict( + self, *, as_series: bool + ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: + it = self.iter_columns() + if as_series: + return {ser.name: ser for ser in it} + return {ser.name: ser.to_list() for ser in it} + + def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSeries]: + from narwhals._plan.arrow.evaluate import evaluate + + yield from self._expr.align(evaluate(e, self) for e in nodes) + + def select( + self, *exprs: IntoExpr | t.Iterable[IntoExpr], **named_exprs: t.Any + ) -> Self: + irs, schema_frozen, output_names = prepare_projection( + parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema + ) + named_irs = into_named_irs(irs, output_names) + named_irs, schema_projected = schema_frozen.project(named_irs, ExprContext.SELECT) + return self.from_series(self._evaluate_irs(named_irs)) diff --git a/narwhals/_plan/arrow/evaluate.py b/narwhals/_plan/arrow/evaluate.py new file mode 100644 index 0000000000..40a78358b4 --- /dev/null +++ b/narwhals/_plan/arrow/evaluate.py @@ -0,0 +1,247 @@ +"""Translating `ExprIR` nodes for pyarrow.""" + +from __future__ import annotations + +import typing as t + +# ruff: noqa: ARG001 +from functools import singledispatch +from itertools import repeat + +from narwhals._plan import aggregation as agg, expr +from narwhals._plan.arrow.series import ArrowSeries +from narwhals._plan.literal import is_literal_scalar + +if t.TYPE_CHECKING: + from typing_extensions import TypeAlias, TypeIs + + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + ChunkedArrayAny, + Order, + ScalarAny, + ) + from narwhals._plan.arrow.dataframe import ArrowDataFrame + from narwhals._plan.common import ExprIR, NamedIR + from narwhals._plan.dummy import DummySeries + from narwhals._plan.protocols import SupportsBroadcast + from narwhals.typing import NonNestedLiteral, PythonLiteral + + +UnaryFn: TypeAlias = "t.Callable[[ChunkedArrayAny], ScalarAny]" + + +def is_scalar(obj: t.Any) -> TypeIs[ScalarAny]: + import pyarrow as pa + + return isinstance(obj, pa.Scalar) + + +def evaluate( + node: NamedIR[ExprIR], frame: ArrowDataFrame +) -> SupportsBroadcast[ArrowSeries]: + result = _evaluate_inner(node.expr, frame) + if is_scalar(result): + return frame._lit.from_scalar(result, node.name) + return frame._expr.from_native(result, node.name) + + +# NOTE: Should mean we produce 1x CompliantSeries for the entire expression +# Multi-output have already been separated +# No intermediate CompliantSeries need to be created, just assign a name to the final one +@singledispatch +def 
_evaluate_inner(node: ExprIR, frame: ArrowDataFrame) -> ChunkedArrayAny: + raise NotImplementedError(type(node)) + + +@_evaluate_inner.register(expr.Column) +def col(node: expr.Column, frame: ArrowDataFrame) -> ChunkedArrayAny: + return frame.native.column(node.name) + + +# NOTE: Using a very naïve approach to broadcasting **for now** +# - We already have something that works in main +# - Another approach would be to keep everything wrapped (or aggregated into) `expr.Literal` +def _lit_native( + value: PythonLiteral | ScalarAny, frame: ArrowDataFrame +) -> ChunkedArrayAny: + """Will need to support returning a native scalar as well.""" + import pyarrow as pa + + from narwhals._arrow.utils import chunked_array + + lit: t.Any = pa.scalar + scalar: t.Any = value if isinstance(value, pa.Scalar) else lit(value) + array = pa.repeat(scalar, len(frame)) + return chunked_array(array) + + +@_evaluate_inner.register(expr.Literal) +def lit_( + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], + frame: ArrowDataFrame, +) -> ChunkedArrayAny: + if is_literal_scalar(node): + return _lit_native(node.unwrap(), frame) + return node.unwrap().to_native() + + +@_evaluate_inner.register(expr.Cast) +def cast_(node: expr.Cast, frame: ArrowDataFrame) -> ChunkedArrayAny: + from narwhals._arrow.utils import narwhals_to_native_dtype + + data_type = narwhals_to_native_dtype(node.dtype, frame.version) + return _evaluate_inner(node.expr, frame).cast(data_type) + + +@_evaluate_inner.register(expr.Sort) +def sort(node: expr.Sort, frame: ArrowDataFrame) -> ChunkedArrayAny: + import pyarrow.compute as pc + + native = _evaluate_inner(node.expr, frame) + sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) + return native.take(sorted_indices) + + +@_evaluate_inner.register(expr.SortBy) +def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> ChunkedArrayAny: + opts = node.options + if len(opts.nulls_last) != 1: + msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {opts.nulls_last!r}" + raise NotImplementedError(msg) + placement = "at_end" if opts.nulls_last[0] else "at_start" + from_native = ArrowSeries.from_native + by = ( + from_native(_evaluate_inner(e, frame), str(idx)) for idx, e in enumerate(node.by) + ) + df = frame.from_series(from_native(_evaluate_inner(node.expr, frame), ""), *by) + names = df.columns[1:] + if len(opts.descending) == 1: + descending: t.Iterable[bool] = repeat(opts.descending[0], len(names)) + else: + descending = opts.descending + sorting: list[tuple[str, Order]] = [ + (key, "descending" if desc else "ascending") + for key, desc in zip(names, descending) + ] + return df.native.sort_by(sorting, null_placement=placement).column(0) + + +@_evaluate_inner.register(expr.Filter) +def filter_(node: expr.Filter, frame: ArrowDataFrame) -> ChunkedArrayAny: + return _evaluate_inner(node.expr, frame).filter(_evaluate_inner(node.by, frame)) + + +@_evaluate_inner.register(expr.Len) +def len_(node: expr.Len, frame: ArrowDataFrame) -> ChunkedArrayAny: + return _lit_native(len(frame), frame) + + +@_evaluate_inner.register(expr.Ternary) +def ternary(node: expr.Ternary, frame: ArrowDataFrame) -> ChunkedArrayAny: + raise NotImplementedError(type(node)) + + +@_evaluate_inner.register(agg.Last) +@_evaluate_inner.register(agg.First) +def agg_first_last(node: agg.First | agg.Last, frame: ArrowDataFrame) -> ChunkedArrayAny: + native = _evaluate_inner(node.expr, frame) + if height := len(native): + result = native[height - 1 if isinstance(node, 
agg.Last) else 0] + else: + result = None + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.ArgMax) +@_evaluate_inner.register(agg.ArgMin) +def agg_arg_min_max( + node: agg.ArgMin | agg.ArgMax, frame: ArrowDataFrame +) -> ChunkedArrayAny: + import pyarrow.compute as pc + + native = _evaluate_inner(node.expr, frame) + fn = pc.min if isinstance(node, agg.ArgMin) else pc.max + result = pc.index(native, fn(native)) + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.Sum) +def agg_sum(node: agg.Sum, frame: ArrowDataFrame) -> ChunkedArrayAny: + import pyarrow.compute as pc + + result = pc.sum(_evaluate_inner(node.expr, frame), min_count=0) + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.NUnique) +def agg_n_unique(node: agg.NUnique, frame: ArrowDataFrame) -> ChunkedArrayAny: + import pyarrow.compute as pc + + result = pc.count(_evaluate_inner(node.expr, frame).unique(), mode="all") + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.Var) +@_evaluate_inner.register(agg.Std) +def agg_std_var(node: agg.Std | agg.Var, frame: ArrowDataFrame) -> ChunkedArrayAny: + import pyarrow.compute as pc + + fn = pc.stddev if isinstance(node, agg.Std) else pc.variance + result = fn(_evaluate_inner(node.expr, frame), ddof=node.ddof) + return _lit_native(result, frame) + + +@_evaluate_inner.register(agg.Quantile) +def agg_quantile(node: agg.Quantile, frame: ArrowDataFrame) -> ChunkedArrayAny: + import pyarrow.compute as pc + + result = pc.quantile( + _evaluate_inner(node.expr, frame), + q=node.quantile, + interpolation=node.interpolation, + )[0] + return _lit_native(result, frame) + + +@_evaluate_inner.register(expr.Agg) +def agg_expr(node: expr.Agg, frame: ArrowDataFrame) -> ChunkedArrayAny: + import pyarrow.compute as pc + + mapping: dict[type[expr.Agg], UnaryFn] = { + agg.Count: pc.count, + agg.Max: pc.max, + agg.Mean: pc.mean, + agg.Median: pc.approximate_median, + agg.Min: pc.min, + } + if fn := mapping.get(type(node)): + result = fn(_evaluate_inner(node.expr, frame)) + return _lit_native(result, frame) + raise NotImplementedError(type(node)) + + +@_evaluate_inner.register(expr.BinaryExpr) +def binary_expr(node: expr.BinaryExpr, frame: ArrowDataFrame) -> ChunkedArrayAny: + raise NotImplementedError(type(node)) + + +@_evaluate_inner.register(expr.FunctionExpr) +def function_expr( + node: expr.FunctionExpr[t.Any], frame: ArrowDataFrame +) -> ChunkedArrayAny: + raise NotImplementedError(type(node)) + + +@_evaluate_inner.register(expr.RollingExpr) +def rolling_expr(node: expr.RollingExpr[t.Any], frame: ArrowDataFrame) -> ChunkedArrayAny: + raise NotImplementedError(type(node)) + + +@_evaluate_inner.register(expr.WindowExpr) +def window_expr(node: expr.WindowExpr, frame: ArrowDataFrame) -> ChunkedArrayAny: + raise NotImplementedError(type(node)) + + +@_evaluate_inner.register(expr.AnonymousExpr) +def anonymous_expr(node: expr.AnonymousExpr, frame: ArrowDataFrame) -> ChunkedArrayAny: + raise NotImplementedError(type(node)) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py new file mode 100644 index 0000000000..3cd2d4977f --- /dev/null +++ b/narwhals/_plan/arrow/expr.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pyarrow as pa + +from narwhals._arrow.utils import chunked_array, narwhals_to_native_dtype +from narwhals._plan.arrow.series import ArrowSeries +from narwhals._plan.common import into_dtype +from narwhals._plan.protocols import SupportsBroadcast +from 
narwhals._utils import Version +from narwhals.exceptions import InvalidOperationError, ShapeError + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._arrow.typing import ChunkedArrayAny, Incomplete, ScalarAny + from narwhals._plan import expr + from narwhals._plan.dummy import DummySeries + from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral + + +# NOTE: General expression result +# Mostly elementwise +class ArrowExpr(SupportsBroadcast[ArrowSeries]): + _compliant: ArrowSeries + + @classmethod + def from_series(cls, series: ArrowSeries) -> Self: + obj = cls.__new__(cls) + obj._compliant = series + return obj + + @classmethod + def from_native( + cls, + native: ChunkedArrayAny, + name: str = "", + /, + *, + version: Version = Version.MAIN, + ) -> Self: + return cls.from_series(ArrowSeries.from_native(native, name, version=version)) + + @classmethod + def from_ir(cls, value: expr.Literal[DummySeries[ChunkedArrayAny]], /) -> Self: + return cls.from_native(value.unwrap().to_native(), value.name) + + def to_series(self) -> ArrowSeries: + return self._compliant + + def __len__(self) -> int: + return len(self._compliant) + + def broadcast(self, length: int, /) -> ArrowSeries: + if (actual_len := len(self)) != length: + msg = f"Expected object of length {length}, got {actual_len}." + raise ShapeError(msg) + return self._compliant + + +# NOTE: Aggregation result or scalar +# Should handle broadcasting, without exposing it +class ArrowLiteral(SupportsBroadcast[ArrowSeries]): + _native_scalar: ScalarAny + _name: str + + @property + def name(self) -> str: + return self._name + + def __len__(self) -> int: + return 1 + + def broadcast(self, length: int, /) -> ArrowSeries: + if length == 1: + chunked = chunked_array([[self._native_scalar]]) + else: + # NOTE: Same issue as `pa.scalar` overlapping overloads + # https://github.com/zen-xu/pyarrow-stubs/pull/209 + pa_repeat: Incomplete = pa.repeat + arr = pa_repeat(self._native_scalar, length) + chunked = chunked_array(arr) + return ArrowSeries.from_native(chunked, self.name) + + @classmethod + def from_series(cls, series: ArrowSeries) -> Self: + if len(series) == 1: + return cls.from_scalar(series.native[0], series.name) + elif len(series) == 0: + return cls.from_python(None, series.name, dtype=series.dtype) + else: + msg = f"Too long {len(series)!r}" + raise InvalidOperationError(msg) + + def to_series(self) -> ArrowSeries: + return self.broadcast(1) + + @classmethod + def from_python( + cls, + value: PythonLiteral, + name: str = "literal", + /, + *, + dtype: IntoDType | None = None, + ) -> Self: + version = Version.MAIN + dtype_pa: pa.DataType | None = None + if dtype: + dtype = into_dtype(dtype) + if not isinstance(dtype, version.dtypes.Unknown): + dtype_pa = narwhals_to_native_dtype(dtype, version) + # NOTE: PR that fixed this was closed + # https://github.com/zen-xu/pyarrow-stubs/pull/208 + lit: Incomplete = pa.scalar + return cls.from_scalar(lit(value, dtype_pa), name) + + @classmethod + def from_scalar(cls, scalar: ScalarAny, name: str = "literal", /) -> Self: + obj = cls.__new__(cls) + obj._native_scalar = scalar + obj._name = name + return obj + + @classmethod + def from_ir(cls, value: expr.Literal[NonNestedLiteral], /) -> Self: + return cls.from_python(value.unwrap(), value.name) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py new file mode 100644 index 0000000000..7e9bc17ede --- /dev/null +++ b/narwhals/_plan/arrow/series.py @@ -0,0 +1,13 @@ +from __future__ import annotations 
+ +from typing import TYPE_CHECKING, Any + +from narwhals._plan.dummy import DummyCompliantSeries + +if TYPE_CHECKING: + from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401 + + +class ArrowSeries(DummyCompliantSeries["ChunkedArrayAny"]): + def to_list(self) -> list[Any]: + return self.native.to_pylist() diff --git a/narwhals/_plan/impl_arrow.py b/narwhals/_plan/impl_arrow.py deleted file mode 100644 index c62979e32b..0000000000 --- a/narwhals/_plan/impl_arrow.py +++ /dev/null @@ -1,453 +0,0 @@ -"""Translating `ExprIR` nodes for pyarrow. - -Acting like a trimmed down, native-only `CompliantExpr`, `CompliantSeries`, etc. -""" - -from __future__ import annotations - -import typing as t -from collections.abc import Sequence, Sized - -# ruff: noqa: ARG001 -from functools import singledispatch -from itertools import chain, repeat - -from narwhals._plan import aggregation as agg, expr -from narwhals._plan.common import into_dtype -from narwhals._plan.contexts import ExprContext -from narwhals._plan.dummy import DummyCompliantFrame, DummyCompliantSeries -from narwhals._plan.expr_expansion import into_named_irs, prepare_projection -from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir -from narwhals._plan.literal import is_literal_scalar -from narwhals._typing_compat import TypeVar -from narwhals._utils import Version -from narwhals.exceptions import InvalidOperationError, ShapeError - -if t.TYPE_CHECKING: - import pyarrow as pa - from typing_extensions import Self, TypeAlias, TypeIs - - from narwhals._arrow.typing import ( # type: ignore[attr-defined] - Incomplete, - Order, - ScalarAny, - ) - from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.dummy import DummySeries - from narwhals._plan.typing import IntoExpr - from narwhals.dtypes import DType - from narwhals.schema import Schema - from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral - - -NativeFrame: TypeAlias = "pa.Table" -NativeSeries: TypeAlias = "pa.ChunkedArray[t.Any]" - -UnaryFn: TypeAlias = "t.Callable[[NativeSeries], ScalarAny]" -SeriesT = TypeVar("SeriesT") - - -def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: - return isinstance(obj, ArrowSeries) - - -class ArrowDataFrame(DummyCompliantFrame[NativeFrame, NativeSeries]): - @property - def _series(self) -> type[ArrowSeries]: - return ArrowSeries - - @property - def columns(self) -> list[str]: - return self.native.column_names - - @property - def schema(self) -> dict[str, DType]: - from narwhals._arrow.utils import native_to_narwhals_dtype - - schema = self.native.schema - return { - name: native_to_narwhals_dtype(dtype, self._version) - for name, dtype in zip(schema.names, schema.types) - } - - def __len__(self) -> int: - return len(self.native) - - @classmethod - def from_series( - cls, series: t.Iterable[ArrowSeries] | ArrowSeries, *more_series: ArrowSeries - ) -> Self: - lhs = (series,) if is_series(series) else series - it = chain(lhs, more_series) if more_series else lhs - return cls.from_dict({s.name: s.native for s in it}) - - @classmethod - def from_dict( - cls, - data: t.Mapping[str, t.Any], - /, - *, - schema: t.Mapping[str, DType] | Schema | None = None, - ) -> Self: - import pyarrow as pa - - from narwhals.schema import Schema - - pa_schema = Schema(schema).to_arrow() if schema is not None else schema - native = pa.Table.from_pydict(data, schema=pa_schema) - return cls.from_native(native, version=Version.MAIN) - - def iter_columns(self) -> t.Iterator[ArrowSeries]: - for name, series in zip(self.columns, 
self.native.itercolumns()): - yield ArrowSeries.from_native(series, name, version=self.version) - - @t.overload - def to_dict(self, *, as_series: t.Literal[True]) -> dict[str, ArrowSeries]: ... - - @t.overload - def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... - - def to_dict( - self, *, as_series: bool - ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: - it = self.iter_columns() - if as_series: - return {ser.name: ser for ser in it} - return {ser.name: ser.to_list() for ser in it} - - def _evaluate_irs( - self, nodes: t.Iterable[NamedIR[ExprIR]], / - ) -> t.Iterator[ArrowSeries]: - for node in nodes: - yield self._series.from_native( - _evaluate_inner(node.expr, self), node.name, version=self.version - ) - - def select( - self, *exprs: IntoExpr | t.Iterable[IntoExpr], **named_exprs: t.Any - ) -> Self: - irs, schema_frozen, output_names = prepare_projection( - parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema - ) - named_irs = into_named_irs(irs, output_names) - named_irs, schema_projected = schema_frozen.project(named_irs, ExprContext.SELECT) - return self.from_series(self._evaluate_irs(named_irs)) - - -class ArrowSeries(DummyCompliantSeries[NativeSeries]): - def to_list(self) -> list[t.Any]: - return self.native.to_pylist() - - -class SupportsBroadcast(Sized, t.Protocol[SeriesT]): - """Minimal broadcasting for `Expr` results.""" - - @classmethod - def from_series(cls, series: SeriesT, /) -> Self: ... - def to_series(self) -> SeriesT: ... - def broadcast(self, length: int, /) -> SeriesT: ... - @classmethod - def align(cls, *exprs: SupportsBroadcast[SeriesT]) -> Sequence[SeriesT]: - lengths = [len(e) for e in exprs] - max_length = max(lengths) - fast_path = all(len_ == max_length for len_ in lengths) - if fast_path: - return [e.to_series() for e in exprs] - return [e.broadcast(max_length) for e in exprs] - - -# NOTE: General expression result -# Mostly elementwise -class ArrowExpr(SupportsBroadcast[ArrowSeries]): - _series: ArrowSeries - - @classmethod - def from_series(cls, series: ArrowSeries) -> Self: - obj = cls.__new__(cls) - obj._series = series - return obj - - def to_series(self) -> ArrowSeries: - return self._series - - def __len__(self) -> int: - return len(self._series) - - def broadcast(self, length: int, /) -> ArrowSeries: - if (actual_len := len(self)) != length: - msg = f"Expected object of length {length}, got {actual_len}." 
- raise ShapeError(msg) - return self._series - - -# NOTE: Aggregation result or scalar -# Should handle broadcasting, without exposing it -class ArrowLiteral(SupportsBroadcast[ArrowSeries]): - _native_scalar: ScalarAny - _name: str - - @property - def name(self) -> str: - return self._name - - def __len__(self) -> int: - return 1 - - def broadcast(self, length: int, /) -> ArrowSeries: - import pyarrow as pa - - from narwhals._arrow.utils import chunked_array - - if length == 1: - chunked = chunked_array([[self._native_scalar]]) - else: - # NOTE: Same issue as `pa.scalar` overlapping overloads - # https://github.com/zen-xu/pyarrow-stubs/pull/209 - pa_repeat: Incomplete = pa.repeat - arr = pa_repeat(self._native_scalar, length) - chunked = chunked_array(arr) - return ArrowSeries.from_native(chunked, self.name) - - @classmethod - def from_series(cls, series: ArrowSeries) -> Self: - if len(series) == 1: - return cls.from_scalar(series.native[0], series.name) - elif len(series) == 0: - return cls.from_python(None, series.name, dtype=series.dtype) - else: - msg = f"Too long {len(series)!r}" - raise InvalidOperationError(msg) - - def to_series(self) -> ArrowSeries: - return self.broadcast(1) - - @classmethod - def from_python( - cls, - value: PythonLiteral, - name: str = "literal", - /, - *, - dtype: IntoDType | None = None, - ) -> Self: - import pyarrow as pa - - from narwhals._arrow.utils import narwhals_to_native_dtype - - version = Version.MAIN - dtype_pa: pa.DataType | None = None - if dtype: - dtype = into_dtype(dtype) - if not isinstance(dtype, version.dtypes.Unknown): - dtype_pa = narwhals_to_native_dtype(dtype, version) - # NOTE: PR that fixed this was closed - # https://github.com/zen-xu/pyarrow-stubs/pull/208 - lit: Incomplete = pa.scalar - return cls.from_scalar(lit(value, dtype_pa), name) - - @classmethod - def from_scalar(cls, scalar: ScalarAny, name: str = "literal", /) -> Self: - obj = cls.__new__(cls) - obj._native_scalar = scalar - obj._name = name - return obj - - @classmethod - def from_ir(cls, value: expr.Literal[NonNestedLiteral], /) -> Self: - return cls.from_python(value.unwrap(), value.name) - - -# NOTE: Should mean we produce 1x CompliantSeries for the entire expression -# Multi-output have already been separated -# No intermediate CompliantSeries need to be created, just assign a name to the final one -@singledispatch -def _evaluate_inner(node: ExprIR, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.Column) -def col(node: expr.Column, frame: ArrowDataFrame) -> NativeSeries: - return frame.native.column(node.name) - - -# NOTE: Using a very naïve approach to broadcasting **for now** -# - We already have something that works in main -# - Another approach would be to keep everything wrapped (or aggregated into) `expr.Literal` -def _lit_native(value: PythonLiteral | ScalarAny, frame: ArrowDataFrame) -> NativeSeries: - """Will need to support returning a native scalar as well.""" - import pyarrow as pa - - from narwhals._arrow.utils import chunked_array - - lit: t.Any = pa.scalar - scalar: t.Any = value if isinstance(value, pa.Scalar) else lit(value) - array = pa.repeat(scalar, len(frame)) - return chunked_array(array) - - -@_evaluate_inner.register(expr.Literal) -def lit_( - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[NativeSeries]], - frame: ArrowDataFrame, -) -> NativeSeries: - if is_literal_scalar(node): - return _lit_native(node.unwrap(), frame) - return node.unwrap().to_native() - - 
-@_evaluate_inner.register(expr.Cast) -def cast_(node: expr.Cast, frame: ArrowDataFrame) -> NativeSeries: - from narwhals._arrow.utils import narwhals_to_native_dtype - - data_type = narwhals_to_native_dtype(node.dtype, frame.version) - return _evaluate_inner(node.expr, frame).cast(data_type) - - -@_evaluate_inner.register(expr.Sort) -def sort(node: expr.Sort, frame: ArrowDataFrame) -> NativeSeries: - import pyarrow.compute as pc - - native = _evaluate_inner(node.expr, frame) - sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) - return native.take(sorted_indices) - - -@_evaluate_inner.register(expr.SortBy) -def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> NativeSeries: - opts = node.options - if len(opts.nulls_last) != 1: - msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {opts.nulls_last!r}" - raise NotImplementedError(msg) - placement = "at_end" if opts.nulls_last[0] else "at_start" - from_native = ArrowSeries.from_native - by = ( - from_native(_evaluate_inner(e, frame), str(idx)) for idx, e in enumerate(node.by) - ) - df = frame.from_series(from_native(_evaluate_inner(node.expr, frame), ""), *by) - names = df.columns[1:] - if len(opts.descending) == 1: - descending: t.Iterable[bool] = repeat(opts.descending[0], len(names)) - else: - descending = opts.descending - sorting: list[tuple[str, Order]] = [ - (key, "descending" if desc else "ascending") - for key, desc in zip(names, descending) - ] - return df.native.sort_by(sorting, null_placement=placement).column(0) - - -@_evaluate_inner.register(expr.Filter) -def filter_(node: expr.Filter, frame: ArrowDataFrame) -> NativeSeries: - return _evaluate_inner(node.expr, frame).filter(_evaluate_inner(node.by, frame)) - - -@_evaluate_inner.register(expr.Len) -def len_(node: expr.Len, frame: ArrowDataFrame) -> NativeSeries: - return _lit_native(len(frame), frame) - - -@_evaluate_inner.register(expr.Ternary) -def ternary(node: expr.Ternary, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(agg.Last) -@_evaluate_inner.register(agg.First) -def agg_first_last(node: agg.First | agg.Last, frame: ArrowDataFrame) -> NativeSeries: - native = _evaluate_inner(node.expr, frame) - if height := len(native): - result = native[height - 1 if isinstance(node, agg.Last) else 0] - else: - result = None - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.ArgMax) -@_evaluate_inner.register(agg.ArgMin) -def agg_arg_min_max(node: agg.ArgMin | agg.ArgMax, frame: ArrowDataFrame) -> NativeSeries: - import pyarrow.compute as pc - - native = _evaluate_inner(node.expr, frame) - fn = pc.min if isinstance(node, agg.ArgMin) else pc.max - result = pc.index(native, fn(native)) - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.Sum) -def agg_sum(node: agg.Sum, frame: ArrowDataFrame) -> NativeSeries: - import pyarrow.compute as pc - - result = pc.sum(_evaluate_inner(node.expr, frame), min_count=0) - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.NUnique) -def agg_n_unique(node: agg.NUnique, frame: ArrowDataFrame) -> NativeSeries: - import pyarrow.compute as pc - - result = pc.count(_evaluate_inner(node.expr, frame).unique(), mode="all") - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.Var) -@_evaluate_inner.register(agg.Std) -def agg_std_var(node: agg.Std | agg.Var, frame: ArrowDataFrame) -> NativeSeries: - import pyarrow.compute as pc - - fn = pc.stddev if isinstance(node, agg.Std) 
else pc.variance - result = fn(_evaluate_inner(node.expr, frame), ddof=node.ddof) - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.Quantile) -def agg_quantile(node: agg.Quantile, frame: ArrowDataFrame) -> NativeSeries: - import pyarrow.compute as pc - - result = pc.quantile( - _evaluate_inner(node.expr, frame), - q=node.quantile, - interpolation=node.interpolation, - )[0] - return _lit_native(result, frame) - - -@_evaluate_inner.register(expr.Agg) -def agg_expr(node: expr.Agg, frame: ArrowDataFrame) -> NativeSeries: - import pyarrow.compute as pc - - mapping: dict[type[expr.Agg], UnaryFn] = { - agg.Count: pc.count, - agg.Max: pc.max, - agg.Mean: pc.mean, - agg.Median: pc.approximate_median, - agg.Min: pc.min, - } - if fn := mapping.get(type(node)): - result = fn(_evaluate_inner(node.expr, frame)) - return _lit_native(result, frame) - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.BinaryExpr) -def binary_expr(node: expr.BinaryExpr, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.FunctionExpr) -def function_expr(node: expr.FunctionExpr[t.Any], frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.RollingExpr) -def rolling_expr(node: expr.RollingExpr[t.Any], frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.WindowExpr) -def window_expr(node: expr.WindowExpr, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.AnonymousExpr) -def anonymous_expr(node: expr.AnonymousExpr, frame: ArrowDataFrame) -> NativeSeries: - raise NotImplementedError(type(node)) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py new file mode 100644 index 0000000000..971f502f64 --- /dev/null +++ b/narwhals/_plan/protocols.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from collections.abc import Iterable, Iterator, Sized +from typing import TYPE_CHECKING, Any, Protocol + +from narwhals._plan.common import flatten_hash_safe +from narwhals._typing_compat import TypeVar + +if TYPE_CHECKING: + from typing_extensions import Self + +SeriesT = TypeVar("SeriesT") + + +class SupportsBroadcast(Sized, Protocol[SeriesT]): + """Minimal broadcasting for `Expr` results.""" + + @classmethod + def from_series(cls, series: SeriesT, /) -> Self: ... + def to_series(self) -> SeriesT: ... + def broadcast(self, length: int, /) -> SeriesT: ... + @classmethod + def align( + cls, *exprs: SupportsBroadcast[SeriesT] | Iterable[SupportsBroadcast[SeriesT]] + ) -> Iterator[SeriesT]: + exprs = tuple[SupportsBroadcast[SeriesT], ...](flatten_hash_safe(exprs)) + lengths = [len(e) for e in exprs] + max_length = max(lengths) + fast_path = all(len_ == max_length for len_ in lengths) + if fast_path: + for e in exprs: + yield e.to_series() + else: + for e in exprs: + yield e.broadcast(max_length) + + +class CompliantExpr(Protocol): + """Getting a bit tricky, just storing notes. 
+
+    - Separating series/scalar makes a lot of sense
+    - Handling the recursive case *without* intermediate (non-pyarrow) objects seems unachievable
+      - Everywhere would need to first check if it is a scalar, which isn't ergonomic
+    - Broadcasting being separated is working
+    - A lot of `pyarrow.compute` (section 2) can work on either scalar or series (`FunctionExpr`)
+      - Aggregation can't, but that is already handled in `ExprIR`
+    - `polars` noops on aggregating a scalar, which we might be able to support this way
+    """
+
+    # scalar allowed
+    def cast(self, *args: Any, **kwds: Any) -> Any: ...
+    # array only (section 3)
+    def sort(self, *args: Any, **kwds: Any) -> Any: ...
+    def sort_by(self, *args: Any, **kwds: Any) -> Any: ...
+    def filter(self, *args: Any, **kwds: Any) -> Any: ...
+    def first(self, *args: Any, **kwds: Any) -> Any: ...
+    def last(self, *args: Any, **kwds: Any) -> Any: ...
+    def arg_min(self, *args: Any, **kwds: Any) -> Any: ...
+    def arg_max(self, *args: Any, **kwds: Any) -> Any: ...
+    def sum(self, *args: Any, **kwds: Any) -> Any: ...
+    def n_unique(self, *args: Any, **kwds: Any) -> Any: ...
+    def std(self, *args: Any, **kwds: Any) -> Any: ...
+    def var(self, *args: Any, **kwds: Any) -> Any: ...
+    def quantile(self, *args: Any, **kwds: Any) -> Any: ...
+    def count(self, *args: Any, **kwds: Any) -> Any: ...
+    def max(self, *args: Any, **kwds: Any) -> Any: ...
+    def mean(self, *args: Any, **kwds: Any) -> Any: ...
+    def median(self, *args: Any, **kwds: Any) -> Any: ...
+    def min(self, *args: Any, **kwds: Any) -> Any: ...
diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py
index 69dc1857dd..837fbd2c52 100644
--- a/tests/plan/to_compliant_test.py
+++ b/tests/plan/to_compliant_test.py
@@ -7,7 +7,6 @@
 import narwhals as nw
 import narwhals._plan.demo as nwd
 from narwhals._plan.common import is_expr
-from narwhals._plan.impl_arrow import ArrowDataFrame
 from narwhals.utils import Version
 from tests.namespace_test import backends
 
@@ -63,6 +62,8 @@ def test_select(
     pytest.importorskip("pyarrow")
     import pyarrow as pa
 
+    from narwhals._plan.arrow.dataframe import ArrowDataFrame
+
     frame = pa.table(data_small)
     df = ArrowDataFrame.from_native(frame, Version.MAIN)
     result = df.select(expr).to_dict(as_series=False)
From 5ca704f5db3840c0f43bfef1f5368dc0ba398458 Mon Sep 17 00:00:00 2001
From: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
Date: Fri, 4 Jul 2025 21:38:36 +0100
Subject: [PATCH 268/368] ignore-banned-import

---
 narwhals/_plan/arrow/dataframe.py | 2 +-
 narwhals/_plan/arrow/evaluate.py  | 4 ++--
 narwhals/_plan/arrow/expr.py      | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py
index e44440593e..e72766cf18 100644
--- a/narwhals/_plan/arrow/dataframe.py
+++ b/narwhals/_plan/arrow/dataframe.py
@@ -3,7 +3,7 @@
 import typing as t
 from itertools import chain
 
-import pyarrow as pa
+import pyarrow as pa  # ignore-banned-import
 
 from narwhals._arrow.utils import native_to_narwhals_dtype
 from narwhals._plan.arrow.expr import ArrowExpr, ArrowLiteral
diff --git a/narwhals/_plan/arrow/evaluate.py b/narwhals/_plan/arrow/evaluate.py
index 40a78358b4..ec8c9f919b 100644
--- a/narwhals/_plan/arrow/evaluate.py
+++ b/narwhals/_plan/arrow/evaluate.py
@@ -31,7 +31,7 @@
 
 
 def is_scalar(obj: t.Any) -> TypeIs[ScalarAny]:
-    import pyarrow as pa
+    import pyarrow as pa  # ignore-banned-import
 
     return isinstance(obj, pa.Scalar)
 
@@ -65,7 +65,7 @@ def _lit_native(
     value: PythonLiteral | 
ScalarAny, frame: ArrowDataFrame ) -> ChunkedArrayAny: """Will need to support returning a native scalar as well.""" - import pyarrow as pa + import pyarrow as pa # ignore-banned-import from narwhals._arrow.utils import chunked_array diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 3cd2d4977f..f44e59a7c9 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -import pyarrow as pa +import pyarrow as pa # ignore-banned-import from narwhals._arrow.utils import chunked_array, narwhals_to_native_dtype from narwhals._plan.arrow.series import ArrowSeries From 8e40fea90df04c0bb7301afd4c44ceccd9743ceb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 4 Jul 2025 21:46:14 +0100 Subject: [PATCH 269/368] fix: alias removed exception types https://github.com/narwhals-dev/narwhals/pull/2752#pullrequestreview-2986828399 --- narwhals/_plan/demo.py | 2 +- narwhals/_plan/exceptions.py | 2 +- tests/plan/expr_parsing_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index d5bbfc5d2f..d2835f029e 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -22,7 +22,7 @@ from narwhals._plan.strings import ConcatHorizontal from narwhals._plan.when_then import When from narwhals._utils import Version, flatten -from narwhals.exceptions import OrderDependentExprError +from narwhals.exceptions import InvalidOperationError as OrderDependentExprError if t.TYPE_CHECKING: from typing_extensions import TypeIs diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 8a7a76b02e..0028e3f7a5 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -12,7 +12,7 @@ DuplicateError, InvalidIntoExprError, InvalidOperationError, - LengthChangingExprError, + InvalidOperationError as LengthChangingExprError, MultiOutputExpressionError, ShapeError, ) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 35826a0dad..3567eab69a 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -20,7 +20,7 @@ from narwhals.exceptions import ( InvalidIntoExprError, InvalidOperationError, - LengthChangingExprError, + InvalidOperationError as LengthChangingExprError, MultiOutputExpressionError, ShapeError, ) From e45118d2fcf29029b95da0acd2f3058c5b0772f2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 5 Jul 2025 00:10:58 +0100 Subject: [PATCH 270/368] refactor: Handle expansion, projection at narwhals-level --- narwhals/_plan/arrow/dataframe.py | 27 ++--- narwhals/_plan/dummy.py | 165 +++++++++++++++++++++++++++--- tests/plan/expr_parsing_test.py | 6 +- tests/plan/to_compliant_test.py | 4 +- 4 files changed, 163 insertions(+), 39 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index e72766cf18..49e375f53e 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -8,10 +8,7 @@ from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow.expr import ArrowExpr, ArrowLiteral from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.contexts import ExprContext -from narwhals._plan.dummy import DummyCompliantFrame -from narwhals._plan.expr_expansion import into_named_irs, prepare_projection -from narwhals._plan.expr_parsing import 
parse_into_seq_of_expr_ir +from narwhals._plan.dummy import DummyCompliantFrame, DummyFrame from narwhals._utils import Version if t.TYPE_CHECKING: @@ -21,7 +18,6 @@ from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.typing import IntoExpr from narwhals.dtypes import DType from narwhals.schema import Schema @@ -33,7 +29,7 @@ def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: return isinstance(obj, ArrowSeries) -class ArrowDataFrame(DummyCompliantFrame["pa.Table", "ChunkedArrayAny"]): +class ArrowDataFrame(DummyCompliantFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): @property def _series(self) -> type[ArrowSeries]: return ArrowSeries @@ -61,6 +57,9 @@ def schema(self) -> dict[str, DType]: def __len__(self) -> int: return len(self.native) + def to_narwhals(self) -> DummyFrame[pa.Table, ChunkedArrayAny]: + return DummyFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) + @classmethod def from_series( cls, series: t.Iterable[ArrowSeries] | ArrowSeries, *more_series: ArrowSeries @@ -89,10 +88,12 @@ def iter_columns(self) -> t.Iterator[ArrowSeries]: @t.overload def to_dict(self, *, as_series: t.Literal[True]) -> dict[str, ArrowSeries]: ... - @t.overload def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... - + @t.overload + def to_dict( + self, *, as_series: bool + ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: ... def to_dict( self, *, as_series: bool ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: @@ -105,13 +106,3 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSe from narwhals._plan.arrow.evaluate import evaluate yield from self._expr.align(evaluate(e, self) for e in nodes) - - def select( - self, *exprs: IntoExpr | t.Iterable[IntoExpr], **named_exprs: t.Any - ) -> Self: - irs, schema_frozen, output_names = prepare_projection( - parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema - ) - named_irs = into_named_irs(irs, output_names) - named_irs, schema_projected = schema_frozen.project(named_irs, ExprContext.SELECT) - return self.from_series(self._evaluate_irs(named_irs)) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 133fab566a..b60750a771 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -10,11 +10,13 @@ aggregation as agg, boolean, expr, + expr_expansion, expr_parsing as parse, functions as F, # noqa: N812 operators as ops, ) -from narwhals._plan.common import is_column, is_expr, is_series +from narwhals._plan.common import NamedIR, is_column, is_expr, is_series +from narwhals._plan.contexts import ExprContext from narwhals._plan.options import ( EWMOptions, RankOptions, @@ -26,18 +28,25 @@ from narwhals._plan.selectors import by_name from narwhals._plan.typing import NativeFrameT, NativeSeriesT from narwhals._plan.window import Over +from narwhals._typing_compat import TypeVar from narwhals._utils import Version, _hasattr_static +from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table from narwhals.dtypes import DType from narwhals.exceptions import ComputeError, InvalidOperationError +from narwhals.schema import Schema if TYPE_CHECKING: - from typing_extensions import Never, Self + from collections.abc import Iterable, Iterator, Mapping + + import pyarrow as pa + from typing_extensions import Never, Self, TypeAlias from narwhals._plan.categorical import ExprCatNamespace from narwhals._plan.common import ExprIR from narwhals._plan.lists import 
ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace + from narwhals._plan.schema import FrozenSchema from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace from narwhals._plan.temporal import ExprDateTimeNamespace @@ -46,6 +55,8 @@ ClosedInterval, FillNullStrategy, IntoDType, + NativeFrame, + NativeSeries, NumericLiteral, RankMethod, RollingInterpolationMethod, @@ -53,6 +64,10 @@ ) +CompliantSeriesT = TypeVar("CompliantSeriesT", bound="DummyCompliantSeries[t.Any]") +CompliantFrame: TypeAlias = "DummyCompliantFrame[t.Any, NativeFrameT, NativeSeriesT]" + + # NOTE: Overly simplified placeholders for mocking typing # Entirely ignoring namespace + function binding class DummyExpr: @@ -768,7 +783,7 @@ def to_narwhals(self) -> DummyExpr: class DummyFrame(Generic[NativeFrameT, NativeSeriesT]): - _compliant: DummyCompliantFrame[NativeFrameT, NativeSeriesT] + _compliant: CompliantFrame[NativeFrameT, NativeSeriesT] _version: t.ClassVar[Version] = Version.MAIN @property @@ -776,25 +791,76 @@ def version(self) -> Version: return self._version @property - def _series(self) -> type[DummySeries]: - return DummySeries + def _series(self) -> type[DummySeries[NativeSeriesT]]: + return DummySeries[NativeSeriesT] + + @property + def schema(self) -> Schema: + return Schema(self._compliant.schema.items()) + + @property + def columns(self) -> list[str]: + return self._compliant.columns + # NOTE: Gave up on trying to get typing working for now @classmethod - def from_native(cls, native: NativeFrameT, /) -> Self: + def from_native( + cls, native: NativeFrame, / + ) -> DummyFrame[pa.Table, pa.ChunkedArray[t.Any]]: + if is_pyarrow_table(native): + from narwhals._plan.arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame.from_native(native, cls._version).to_narwhals() + + raise NotImplementedError(type(native)) + + @classmethod + def _from_compliant( + cls, compliant: CompliantFrame[NativeFrameT, NativeSeriesT], / + ) -> Self: obj = cls.__new__(cls) - obj._compliant = DummyCompliantFrame[NativeFrameT, NativeSeriesT].from_native( - native, cls._version - ) + obj._compliant = compliant return obj def to_native(self) -> NativeFrameT: return self._compliant.native + @t.overload + def to_dict( + self, *, as_series: t.Literal[True] = ... + ) -> dict[str, DummySeries[NativeSeriesT]]: ... + + @t.overload + def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... + + @t.overload + def to_dict( + self, *, as_series: bool + ) -> dict[str, DummySeries[NativeSeriesT]] | dict[str, list[t.Any]]: ... 
+ + def to_dict( + self, *, as_series: bool = True + ) -> dict[str, DummySeries[NativeSeriesT]] | dict[str, list[t.Any]]: + if as_series: + return { + key: self._series._from_compliant(value) + for key, value in self._compliant.to_dict(as_series=as_series).items() + } + return self._compliant.to_dict(as_series=as_series) + def __len__(self) -> int: return len(self._compliant) + def select(self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any) -> Self: + irs, schema_frozen, output_names = expr_expansion.prepare_projection( + parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema + ) + named_irs = expr_expansion.into_named_irs(irs, output_names) + named_irs, schema_projected = schema_frozen.project(named_irs, ExprContext.SELECT) + return self._from_compliant(self._compliant.select(named_irs, schema_projected)) + -class DummyCompliantFrame(Generic[NativeFrameT, NativeSeriesT]): +class DummyCompliantFrame(Generic[CompliantSeriesT, NativeFrameT, NativeSeriesT]): _native: NativeFrameT _version: Version @@ -807,8 +873,15 @@ def version(self) -> Version: return self._version @property - def _series(self) -> type[DummyCompliantSeries[NativeSeriesT]]: - return DummyCompliantSeries[NativeSeriesT] + def columns(self) -> list[str]: + raise NotImplementedError + + @property + def _series(self) -> type[CompliantSeriesT]: + raise NotImplementedError + + def to_narwhals(self) -> DummyFrame[NativeFrameT, NativeSeriesT]: + raise NotImplementedError @classmethod def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: @@ -817,9 +890,54 @@ def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: obj._version = version return obj + @classmethod + def from_series( + cls, + series: Iterable[CompliantSeriesT] | CompliantSeriesT, + *more_series: CompliantSeriesT, + ) -> Self: + """Return a new DataFrame, horizontally concatenating multiple Series.""" + raise NotImplementedError + + @classmethod + def from_dict( + cls, + data: Mapping[str, t.Any], + /, + *, + schema: Mapping[str, DType] | Schema | None = None, + ) -> Self: + raise NotImplementedError + + @t.overload + def to_dict(self, *, as_series: t.Literal[True]) -> dict[str, CompliantSeriesT]: ... + @t.overload + def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... + @t.overload + def to_dict( + self, *, as_series: bool + ) -> dict[str, CompliantSeriesT] | dict[str, list[t.Any]]: ... 
+ + def to_dict( + self, *, as_series: bool + ) -> dict[str, CompliantSeriesT] | dict[str, list[t.Any]]: + raise NotImplementedError + def __len__(self) -> int: raise NotImplementedError + @property + def schema(self) -> Mapping[str, DType]: + raise NotImplementedError + + def _evaluate_irs( + self, nodes: Iterable[NamedIR[ExprIR]], / + ) -> Iterator[CompliantSeriesT]: + raise NotImplementedError + + def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: + return self.from_series(self._evaluate_irs(irs)) + class DummySeries(Generic[NativeSeriesT]): _compliant: DummyCompliantSeries[NativeSeriesT] @@ -837,12 +955,24 @@ def dtype(self) -> DType: def name(self) -> str: return self._compliant.name + # NOTE: Gave up on trying to get typing working for now @classmethod - def from_native(cls, native: NativeSeriesT, name: str = "", /) -> Self: + def from_native( + cls, native: NativeSeries, name: str = "", / + ) -> DummySeries[pa.ChunkedArray[t.Any]]: + if is_pyarrow_chunked_array(native): + from narwhals._plan.arrow.series import ArrowSeries + + return ArrowSeries.from_native( + native, name, version=cls._version + ).to_narwhals() + + raise NotImplementedError(type(native)) + + @classmethod + def _from_compliant(cls, compliant: DummyCompliantSeries[NativeSeriesT], /) -> Self: obj = cls.__new__(cls) - obj._compliant = DummyCompliantSeries[NativeSeriesT].from_native( - native, name, version=cls._version - ) + obj._compliant = compliant return obj def to_native(self) -> NativeSeriesT: @@ -877,6 +1007,9 @@ def dtype(self) -> DType: def name(self) -> str: return self._name + def to_narwhals(self) -> DummySeries[NativeSeriesT]: + return DummySeries[NativeSeriesT]._from_compliant(self) + @classmethod def from_native( cls, native: NativeSeriesT, name: str = "", /, *, version: Version = Version.MAIN diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 3567eab69a..90fb8f4992 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -274,10 +274,10 @@ def test_is_in_seq(into_iter: IntoIterable) -> None: def test_is_in_series() -> None: - pytest.importorskip("polars") - import polars as pl + pytest.importorskip("pyarrow") + import pyarrow as pa - native = pl.Series([1, 2, 3]) + native = pa.chunked_array([pa.array([1, 2, 3])]) other = DummySeries.from_native(native) expr = nwd.col("a").is_in(other) ir = expr._ir diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 837fbd2c52..f60478f464 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -62,9 +62,9 @@ def test_select( pytest.importorskip("pyarrow") import pyarrow as pa - from narwhals._plan.arrow.dataframe import ArrowDataFrame + from narwhals._plan.dummy import DummyFrame frame = pa.table(data_small) - df = ArrowDataFrame.from_native(frame, Version.MAIN) + df = DummyFrame.from_native(frame) result = df.select(expr).to_dict(as_series=False) assert result == expected From 86d56adf675c78bbbb9d591397d3b5c6ba9c047f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 5 Jul 2025 11:30:43 +0100 Subject: [PATCH 271/368] feat: Add `NamedIR.from_name` --- narwhals/_plan/common.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 0b35a9411a..cf5ce0160c 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -33,7 +33,7 @@ DummySelector, DummySeries, ) - from narwhals._plan.expr 
import Agg, BinaryExpr, FunctionExpr, WindowExpr + from narwhals._plan.expr import Agg, BinaryExpr, Column, FunctionExpr, WindowExpr from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -305,6 +305,16 @@ class NamedIR(Immutable, Generic[ExprIRT]): expr: ExprIRT name: str + @staticmethod + def from_name(name: str, /) -> NamedIR[Column]: + """Construct as a simple, unaliased `col(name)` expression. + + Intended to be used in `with_columns` from a `FrozenSchema`'s keys. + """ + from narwhals._plan.expr import col + + return NamedIR(expr=col(name), name=name) + def map_ir(self, function: MapIR, /) -> NamedIR[ExprIR]: """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" return self.with_expr(function(self.expr.map_ir(function))) From a193af0d7c69bde16cd4a843b04f8e35a560ed60 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 5 Jul 2025 11:51:21 +0100 Subject: [PATCH 272/368] test: Add `test_lit_series_roundtrip` --- narwhals/_plan/arrow/series.py | 6 ++++++ narwhals/_plan/dummy.py | 8 +++++++- tests/plan/expr_parsing_test.py | 21 +++++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 7e9bc17ede..f8189b848d 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -2,12 +2,18 @@ from typing import TYPE_CHECKING, Any +from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.dummy import DummyCompliantSeries if TYPE_CHECKING: from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401 + from narwhals.dtypes import DType class ArrowSeries(DummyCompliantSeries["ChunkedArrayAny"]): def to_list(self) -> list[Any]: return self.native.to_pylist() + + @property + def dtype(self) -> DType: + return native_to_narwhals_dtype(self.native.type, self._version) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index b60750a771..fe4760232e 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -978,6 +978,9 @@ def _from_compliant(cls, compliant: DummyCompliantSeries[NativeSeriesT], /) -> S def to_native(self) -> NativeSeriesT: return self._compliant.native + def to_list(self) -> list[t.Any]: + return self._compliant.to_list() + def __iter__(self) -> t.Iterator[t.Any]: yield from self.to_native() @@ -1001,7 +1004,7 @@ def version(self) -> Version: @property def dtype(self) -> DType: - return self.version.dtypes.Float64() + raise NotImplementedError @property def name(self) -> str: @@ -1031,3 +1034,6 @@ def alias(self, name: str) -> Self: def __len__(self) -> int: return len(self.native) + + def to_list(self) -> list[t.Any]: + raise NotADirectoryError diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 90fb8f4992..132c71a496 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -12,11 +12,13 @@ import narwhals._plan.demo as nwd from narwhals._plan import ( boolean, + expr, functions as F, # noqa: N812 ) from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr +from narwhals._plan.literal import SeriesLiteral from narwhals.exceptions import ( InvalidIntoExprError, InvalidOperationError, @@ -363,3 +365,22 @@ def test_filter_partial_spellings( ) -> None: with context: assert 
nwd.col("a").filter(*predicates, **constraints) + + +def test_lit_series_roundtrip() -> None: + pytest.importorskip("pyarrow") + import pyarrow as pa + + data = ["a", "b", "c"] + native = pa.chunked_array([pa.array(data)]) + series = DummySeries.from_native(native) + lit_series = nwd.lit(series) # type: ignore[arg-type] + assert lit_series.meta.is_literal() + ir = lit_series._ir + assert isinstance(ir, expr.Literal) + assert isinstance(ir.dtype, nw.String) + assert isinstance(ir.value, SeriesLiteral) + unwrapped = ir.unwrap() + assert isinstance(unwrapped, DummySeries) + assert isinstance(unwrapped.to_native(), pa.ChunkedArray) + assert unwrapped.to_list() == data From 8e477886df3f9afed520106b43a04d1c76575c6d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 5 Jul 2025 12:10:35 +0100 Subject: [PATCH 273/368] fix(typing): Propagate `NativeSeriesT` --- narwhals/_plan/boolean.py | 8 +++++--- narwhals/_plan/common.py | 5 ++++- narwhals/_plan/demo.py | 10 +++++----- narwhals/_plan/literal.py | 16 ++++++++++------ tests/plan/expr_parsing_test.py | 2 +- 5 files changed, 25 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index fda6733ed5..1779d6e885 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -12,7 +12,7 @@ from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import Literal # noqa: F401 - from narwhals._plan.typing import Seq # noqa: F401 + from narwhals._plan.typing import NativeSeriesT, Seq # noqa: F401 from narwhals.typing import ClosedInterval OtherT = TypeVar("OtherT") @@ -128,9 +128,11 @@ def from_iterable(cls, other: t.Iterable[t.Any], /) -> IsInSeq: # NOTE: Shouldn't be allowed for lazy backends (maybe besides `polars`) -class IsInSeries(IsIn["Literal[DummySeries]"]): +class IsInSeries(IsIn["Literal[DummySeries[NativeSeriesT]]"]): @classmethod - def from_series(cls, other: DummySeries, /) -> IsInSeries: + def from_series( + cls, other: DummySeries[NativeSeriesT], / + ) -> IsInSeries[NativeSeriesT]: from narwhals._plan.literal import SeriesLiteral return IsInSeries(other=SeriesLiteral(value=other).to_literal()) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index cf5ce0160c..7fc90000b1 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -13,6 +13,7 @@ IRNamespaceT, MapIR, NamedOrExprIRT, + NativeSeriesT, NonNestedDTypeT, Ns, Seq, @@ -435,7 +436,9 @@ def is_column(obj: Any) -> TypeIs[DummyExpr]: return is_expr(obj) and obj.meta.is_column() -def is_series(obj: Any) -> TypeIs[DummySeries]: +def is_series( + obj: DummySeries[NativeSeriesT] | Any, +) -> TypeIs[DummySeries[NativeSeriesT]]: from narwhals._plan.dummy import DummySeries return isinstance(obj, DummySeries) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index d2835f029e..abed55bf2a 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -13,9 +13,9 @@ ExprIR, into_dtype, is_non_nested_literal, + is_series, py_to_narwhals_dtype, ) -from narwhals._plan.dummy import DummySeries from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.ranges import IntRange @@ -27,9 +27,9 @@ if t.TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import SortBy - from 
narwhals._plan.typing import IntoExpr, IntoExprColumn + from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT from narwhals.dtypes import IntegerType from narwhals.typing import IntoDType, NonNestedLiteral @@ -55,9 +55,9 @@ def nth(*indices: int | t.Sequence[int]) -> DummyExpr: def lit( - value: NonNestedLiteral | DummySeries, dtype: IntoDType | None = None + value: NonNestedLiteral | DummySeries[NativeSeriesT], dtype: IntoDType | None = None ) -> DummyExpr: - if isinstance(value, DummySeries): + if is_series(value): return SeriesLiteral(value=value).to_literal().to_narwhals() if not is_non_nested_literal(value): msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index c8a9cdcb6f..9a0cb13e6e 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Generic from narwhals._plan.common import Immutable -from narwhals._plan.typing import LiteralT, NonNestedLiteralT +from narwhals._plan.typing import LiteralT, NativeSeriesT, NonNestedLiteralT if TYPE_CHECKING: from typing_extensions import TypeIs @@ -56,7 +56,7 @@ def unwrap(self) -> NonNestedLiteralT: return self.value -class SeriesLiteral(LiteralValue["DummySeries"]): +class SeriesLiteral(LiteralValue["DummySeries[NativeSeriesT]"]): """We already need this. https://github.com/narwhals-dev/narwhals/blob/e51eba891719a5eb1f7ce91c02a477af39c0baee/narwhals/_expression_parsing.py#L96-L97 @@ -64,7 +64,7 @@ class SeriesLiteral(LiteralValue["DummySeries"]): __slots__ = ("value",) - value: DummySeries + value: DummySeries[NativeSeriesT] @property def dtype(self) -> DType: @@ -77,7 +77,7 @@ def name(self) -> str: def __repr__(self) -> str: return "Series" - def unwrap(self) -> DummySeries: + def unwrap(self) -> DummySeries[NativeSeriesT]: return self.value @@ -87,7 +87,9 @@ def _is_scalar( return isinstance(obj, ScalarLiteral) -def _is_series(obj: Any) -> TypeIs[SeriesLiteral]: +def _is_series( + obj: SeriesLiteral[NativeSeriesT] | Any, +) -> TypeIs[SeriesLiteral[NativeSeriesT]]: return isinstance(obj, SeriesLiteral) @@ -103,5 +105,7 @@ def is_literal_scalar( return is_literal(obj) and _is_scalar(obj.value) -def is_literal_series(obj: Literal[DummySeries] | Any) -> TypeIs[Literal[DummySeries]]: +def is_literal_series( + obj: Literal[DummySeries[NativeSeriesT]] | Any, +) -> TypeIs[Literal[DummySeries[NativeSeriesT]]]: return is_literal(obj) and _is_series(obj.value) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 132c71a496..85badfaeb7 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -374,7 +374,7 @@ def test_lit_series_roundtrip() -> None: data = ["a", "b", "c"] native = pa.chunked_array([pa.array(data)]) series = DummySeries.from_native(native) - lit_series = nwd.lit(series) # type: ignore[arg-type] + lit_series = nwd.lit(series) assert lit_series.meta.is_literal() ir = lit_series._ir assert isinstance(ir, expr.Literal) From 9e2e2b974fbb46353de6d51b442a39c7de233016 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 5 Jul 2025 12:29:00 +0100 Subject: [PATCH 274/368] chore(typing): Link to pyright explainer --- narwhals/_plan/common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 7fc90000b1..68299f1ab4 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -542,6 
+542,9 @@ def map_ir( return origin.map_ir(function) +# TODO @dangotbanned: Review again and try to work around (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021) +# The issue is `T` possibly being `Iterable` +# Ignoring here still leaks the issue to the caller, where you need to annotate the base case def flatten_hash_safe(iterable: Iterable[T | Iterable[T]], /) -> Iterator[T]: """Fully unwrap all levels of nesting. From 154a3a050b61cfc423a14640e6268a1034981d83 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 5 Jul 2025 13:10:37 +0100 Subject: [PATCH 275/368] test: Update tests that shouldn't broadcast New broadcasting behavior hasn't been integrated into `_evaluate_inner`, which is always broadcasting to input length --- tests/plan/to_compliant_test.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index f60478f464..438b7970a0 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -45,14 +45,25 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: assert isinstance(compliant_expr, namespace._expr) +XFAIL_BROADCAST = pytest.mark.xfail( + reason="Shouldn't broadcast when all Series are length 1." +) + + @pytest.mark.parametrize( ("expr", "expected"), [ (nwd.col("a"), {"a": ["A", "B", "A"]}), (nwd.col("a", "b"), {"a": ["A", "B", "A"], "b": [1, 2, 3]}), - (nwd.lit(1), {"literal": [1, 1, 1]}), - (nwd.lit(2.0), {"literal": [2.0, 2.0, 2.0]}), - (nwd.lit(None, nw.String()), {"literal": [None, None, None]}), + pytest.param(nwd.lit(1), {"literal": [1]}, marks=XFAIL_BROADCAST), + pytest.param(nwd.lit(2.0), {"literal": [2.0]}, marks=XFAIL_BROADCAST), + pytest.param( + nwd.lit(None, nw.String()), {"literal": [None]}, marks=XFAIL_BROADCAST + ), + pytest.param( + nwd.col("a", "b").first(), {"a": ["A"], "b": [1]}, marks=XFAIL_BROADCAST + ), + pytest.param(nwd.col("d").max(), {"d": [8]}, marks=XFAIL_BROADCAST), ], ids=_ids_ir, ) From 6bf86d84f2b79f2083b536f2c3e5d709411ce003 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 5 Jul 2025 23:27:47 +0100 Subject: [PATCH 276/368] keep on iterating - Getting closer, but the `Expr`/`Scalar` concept needs more work - It is looking like we could loosen up what kinds of expressions are allowed --- narwhals/_plan/arrow/evaluate.py | 6 +- narwhals/_plan/arrow/expr.py | 267 +++++++++++++++++++++++++++- narwhals/_plan/common.py | 14 +- narwhals/_plan/dummy.py | 2 +- narwhals/_plan/protocols.py | 295 +++++++++++++++++++++++++++---- 5 files changed, 538 insertions(+), 46 deletions(-) diff --git a/narwhals/_plan/arrow/evaluate.py b/narwhals/_plan/arrow/evaluate.py index ec8c9f919b..a408db7158 100644 --- a/narwhals/_plan/arrow/evaluate.py +++ b/narwhals/_plan/arrow/evaluate.py @@ -23,7 +23,7 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dummy import DummySeries - from narwhals._plan.protocols import SupportsBroadcast + from narwhals._plan.protocols import EagerBroadcast from narwhals.typing import NonNestedLiteral, PythonLiteral @@ -36,9 +36,7 @@ def is_scalar(obj: t.Any) -> TypeIs[ScalarAny]: return isinstance(obj, pa.Scalar) -def evaluate( - node: NamedIR[ExprIR], frame: ArrowDataFrame -) -> SupportsBroadcast[ArrowSeries]: +def evaluate(node: NamedIR[ExprIR], frame: ArrowDataFrame) -> 
EagerBroadcast[ArrowSeries]: result = _evaluate_inner(node.expr, frame) if is_scalar(result): return frame._lit.from_scalar(result, node.name) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index f44e59a7c9..e914c6dbec 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,28 +1,281 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from functools import singledispatchmethod +from typing import TYPE_CHECKING, Any import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import chunked_array, narwhals_to_native_dtype +from narwhals._plan import expr from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.common import into_dtype -from narwhals._plan.protocols import SupportsBroadcast +from narwhals._plan.common import ExprIR, NamedIR, into_dtype +from narwhals._plan.protocols import EagerBroadcast, EagerExpr, EagerScalar from narwhals._utils import Version from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._arrow.typing import ChunkedArrayAny, Incomplete, ScalarAny - from narwhals._plan import expr + from narwhals._plan.aggregation import ( + ArgMax, + ArgMin, + Count, + First, + Last, + Max, + Mean, + Median, + Min, + NUnique, + Quantile, + Std, + Sum, + Var, + ) + from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.dummy import DummySeries from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral +NativeScalar: TypeAlias = "pa.Scalar[Any]" + + +class ArrowExpr2(EagerExpr["ArrowDataFrame", ArrowSeries]): + _evaluated: ArrowSeries + + @property + def name(self) -> str: + return self._evaluated.name + + @property + def version(self) -> Version: + return self._evaluated.version + + @classmethod + def from_series(cls, series: ArrowSeries, /) -> Self: + obj = cls.__new__(cls) + obj._evaluated = series + return obj + + @classmethod + def from_native( + cls, native: ChunkedArrayAny, name: str = "", /, version: Version = Version.MAIN + ) -> Self: + return cls.from_series(ArrowSeries.from_native(native, name, version=version)) + + def _with_native( + self, result: ChunkedArrayAny | NativeScalar, name: str = "", / + ) -> Self: + if isinstance(result, pa.Scalar): + # NOTE: Will need to resolve this eventually + # Currently the *least bad* option is the single ignore here + return ArrowScalar.from_native(result, name, version=self.version) # type: ignore[return-value] + return super()._with_native(result, name) + + @property + def native(self) -> ChunkedArrayAny: + return self._evaluated.native + + def to_series(self) -> ArrowSeries: + return self._evaluated + + def broadcast(self, length: int, /) -> ArrowSeries: + if (actual_len := len(self)) != length: + msg = f"Expected object of length {length}, got {actual_len}." 
+ raise ShapeError(msg) + return self._evaluated + + def __len__(self) -> int: + return len(self._evaluated) + + # NOTE: Dispatch is on `ExprIR`, which is recursive + # There is only a top-level `NamedIR` per column + def evaluate(self, named_ir: NamedIR[ExprIR], frame: ArrowDataFrame) -> ArrowExpr2: + return self._evaluate_inner(named_ir.expr, frame, named_ir.name) + + # NOTE: Don't use `Self`, it breaks the descriptor typing + # The implementations *can* use `Self`, just not here + @singledispatchmethod + def _evaluate_inner( + self, node: ExprIR, frame: ArrowDataFrame, name: str + ) -> ArrowExpr2: + raise NotImplementedError(type(node)) + + @_evaluate_inner.register(expr.Cast) + def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> Self: + data_type = narwhals_to_native_dtype(node.dtype, frame.version) + native = self._evaluate_inner(node.expr, frame, name).native + return self._with_native(pc.cast(native, data_type), name) + + def sort(self, node: expr.Sort, frame: ArrowDataFrame, name: str) -> ArrowExpr2: + raise NotImplementedError + + def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowExpr2: + raise NotImplementedError + + def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowExpr2: + raise NotImplementedError + + def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def arg_max(self, node: ArgMax, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def sum(self, node: Sum, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def n_unique(self, node: NUnique, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def std(self, node: Std, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def quantile(self, node: Quantile, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def max(self, node: Max, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def mean(self, node: Mean, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def median(self, node: Median, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: + raise NotImplementedError + + +class ArrowScalar(EagerScalar["ArrowDataFrame", ArrowSeries]): + _name: str + _version: Version + _evaluated: NativeScalar + + @property + def name(self) -> str: + return self._name + + @classmethod + def from_native( + cls, + scalar: NativeScalar, + name: str = "literal", + /, + version: Version = Version.MAIN, + ) -> Self: + obj = cls.__new__(cls) + obj._evaluated = scalar + obj._name = name + obj._version = version + return obj + + @classmethod + def from_python( + cls, + value: PythonLiteral, + name: str = "literal", + /, + *, + dtype: IntoDType | None = None, + version: Version = Version.MAIN, + ) -> Self: + dtype_pa: pa.DataType | None = None + if dtype: + dtype = into_dtype(dtype) 
+ if not isinstance(dtype, version.dtypes.Unknown): + dtype_pa = narwhals_to_native_dtype(dtype, version) + # NOTE: PR that fixed this was closed + # https://github.com/zen-xu/pyarrow-stubs/pull/208 + lit: Incomplete = pa.scalar + return cls.from_native(lit(value, dtype_pa), name, version) + + @classmethod + def from_series(cls, series: ArrowSeries) -> Self: + if len(series) == 1: + return cls.from_native(series.native[0], series.name, series.version) + elif len(series) == 0: + return cls.from_python( + None, series.name, dtype=series.dtype, version=series.version + ) + else: + msg = f"Too long {len(series)!r}" + raise InvalidOperationError(msg) + + @property + def native(self) -> NativeScalar: + return self._evaluated + + def to_series(self) -> ArrowSeries: + return self.broadcast(1) + + def broadcast(self, length: int) -> ArrowSeries: + scalar = self.native + if length == 1: + chunked = chunked_array([[scalar]]) + else: + # NOTE: Same issue as `pa.scalar` overlapping overloads + # https://github.com/zen-xu/pyarrow-stubs/pull/209 + pa_repeat: Incomplete = pa.repeat + chunked = chunked_array(pa_repeat(scalar, length)) + return ArrowSeries.from_native(chunked, self.name, version=self.version) + + # NOTE: Dispatch is on `ExprIR`, which is recursive + # There is only a top-level `NamedIR` per column + def evaluate(self, named_ir: NamedIR[ExprIR], frame: ArrowDataFrame) -> ArrowScalar: + return self._evaluate_inner(named_ir.expr, frame, named_ir.name) + + @singledispatchmethod + def _evaluate_inner( + self, node: ExprIR, frame: ArrowDataFrame, name: str + ) -> ArrowScalar: + raise NotImplementedError(type(node)) + + @_evaluate_inner.register(expr.Cast) + def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> ArrowScalar: + data_type = narwhals_to_native_dtype(node.dtype, frame.version) + native = self._evaluate_inner(node.expr, frame, name).native + return self._with_native(pc.cast(native, data_type), name) + + def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> Any: + raise NotImplementedError + + def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: + return self._with_native(pa.scalar(0), name) + + def arg_max(self, node: ArgMax, frame: ArrowDataFrame, name: str) -> ArrowScalar: + return self._with_native(pa.scalar(0), name) + + def n_unique(self, node: NUnique, frame: ArrowDataFrame, name: str) -> ArrowScalar: + return self._with_native(pa.scalar(1), name) + + def std(self, node: Std, frame: ArrowDataFrame, name: str) -> ArrowScalar: + return self._with_native(pa.scalar(None, pa.null()), name) + + def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: + return self._with_native(pa.scalar(None, pa.null()), name) + + def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: + native = self._evaluate_inner(node.expr, frame, name).native + return self._with_native(pa.scalar(1 if native.is_valid else 0), name) + # NOTE: General expression result # Mostly elementwise -class ArrowExpr(SupportsBroadcast[ArrowSeries]): +class ArrowExpr(EagerBroadcast[ArrowSeries]): _compliant: ArrowSeries @classmethod @@ -61,7 +314,7 @@ def broadcast(self, length: int, /) -> ArrowSeries: # NOTE: Aggregation result or scalar # Should handle broadcasting, without exposing it -class ArrowLiteral(SupportsBroadcast[ArrowSeries]): +class ArrowLiteral(EagerBroadcast[ArrowSeries]): _native_scalar: ScalarAny _name: str diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 68299f1ab4..3fd364a42f 100644 --- 
a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -34,7 +34,14 @@ DummySelector, DummySeries, ) - from narwhals._plan.expr import Agg, BinaryExpr, Column, FunctionExpr, WindowExpr + from narwhals._plan.expr import ( + Agg, + BinaryExpr, + Cast, + Column, + FunctionExpr, + WindowExpr, + ) from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -268,6 +275,11 @@ def meta(self) -> IRMetaNamespace: return IRMetaNamespace(_ir=self) + def cast(self, dtype: DType) -> Cast: + from narwhals._plan.expr import Cast + + return Cast(expr=self, dtype=dtype) + def _repr_html_(self) -> str: return self.__repr__() diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index fe4760232e..5b7e9f2918 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -102,7 +102,7 @@ def alias(self, name: str) -> Self: def cast(self, dtype: DType | type[DType]) -> Self: dtype = dtype if isinstance(dtype, DType) else self.version.dtypes.Unknown() - return self._from_ir(expr.Cast(expr=self._ir, dtype=dtype)) + return self._from_ir(self._ir.cast(dtype)) def exclude(self, *names: str | t.Iterable[str]) -> Self: return self._from_ir(expr.Exclude.from_names(self._ir, *names)) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 971f502f64..3713808e89 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,41 +1,90 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator, Sized +from collections.abc import Iterable, Iterator, Sequence, Sized from typing import TYPE_CHECKING, Any, Protocol -from narwhals._plan.common import flatten_hash_safe +from narwhals._plan.common import ExprIR, flatten_hash_safe from narwhals._typing_compat import TypeVar +from narwhals._utils import Version if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias + from narwhals._plan import aggregation as agg, expr + from narwhals.typing import IntoDType, PythonLiteral + +T = TypeVar("T") SeriesT = TypeVar("SeriesT") +SeriesT_co = TypeVar("SeriesT_co", covariant=True) +FrameT_contra = TypeVar("FrameT_contra", contravariant=True) +OneOrIterable: TypeAlias = "T | Iterable[T]" +LengthT = TypeVar("LengthT") +NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) -class SupportsBroadcast(Sized, Protocol[SeriesT]): +class SupportsBroadcast(Protocol[SeriesT, LengthT]): """Minimal broadcasting for `Expr` results.""" @classmethod def from_series(cls, series: SeriesT, /) -> Self: ... def to_series(self) -> SeriesT: ... - def broadcast(self, length: int, /) -> SeriesT: ... + def broadcast(self, length: LengthT, /) -> SeriesT: ... + def _length(self) -> LengthT: + """Return the length of the current expression.""" + ... + + @classmethod + def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT: + """Return the maximum length among `exprs`.""" + ... 
+ + @classmethod + def _length_required( + cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / + ) -> LengthT | None: + """Return the broadcast length, if all lengths do not equal the maximum.""" + + @classmethod + def _length_all( + cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / + ) -> Sequence[LengthT]: + return [e._length() for e in exprs] + @classmethod def align( - cls, *exprs: SupportsBroadcast[SeriesT] | Iterable[SupportsBroadcast[SeriesT]] + cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]] ) -> Iterator[SeriesT]: - exprs = tuple[SupportsBroadcast[SeriesT], ...](flatten_hash_safe(exprs)) - lengths = [len(e) for e in exprs] - max_length = max(lengths) - fast_path = all(len_ == max_length for len_ in lengths) - if fast_path: + exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs)) + length = cls._length_required(exprs) + if length is None: for e in exprs: yield e.to_series() else: for e in exprs: - yield e.broadcast(max_length) + yield e.broadcast(length) + + +class EagerBroadcast(Sized, SupportsBroadcast[SeriesT, int], Protocol[SeriesT]): + """Determines expression length via the size of the container.""" + + def _length(self) -> int: + return len(self) + @classmethod + def _length_max(cls, lengths: Sequence[int], /) -> int: + return max(lengths) + + @classmethod + def _length_required( + cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], / + ) -> int | None: + lengths = cls._length_all(exprs) + max_length = cls._length_max(lengths) + required = any(len_ != max_length for len_ in lengths) + return max_length if required else None -class CompliantExpr(Protocol): + +class CompliantExpr(Protocol[FrameT_contra, SeriesT_co, NativeT_co]): """Getting a bit tricky, just storing notes. - Separating series/scalar makes a lot of sense @@ -47,23 +96,203 @@ class CompliantExpr(Protocol): - `polars` noops on aggregating a scalar, which we might be able to support this way """ - # scalar allowed - def cast(self, *args: Any, **kwds: Any) -> Any: ... - # array only (section 3) - def sort(self, *args: Any, **kwds: Any) -> Any: ... - def sort_by(self, *args: Any, **kwds: Any) -> Any: ... - def filter(self, *args: Any, **kwds: Any) -> Any: ... - def first(self, *args: Any, **kwds: Any) -> Any: ... - def last(self, *args: Any, **kwds: Any) -> Any: ... - def arg_min(self, *args: Any, **kwds: Any) -> Any: ... - def arg_max(self, *args: Any, **kwds: Any) -> Any: ... - def sum(self, *args: Any, **kwds: Any) -> Any: ... - def n_unique(self, *args: Any, **kwds: Any) -> Any: ... - def std(self, *args: Any, **kwds: Any) -> Any: ... - def var(self, *args: Any, **kwds: Any) -> Any: ... - def quantile(self, *args: Any, **kwds: Any) -> Any: ... - def count(self, *args: Any, **kwds: Any) -> Any: ... - def max(self, *args: Any, **kwds: Any) -> Any: ... - def mean(self, *args: Any, **kwds: Any) -> Any: ... - def median(self, *args: Any, **kwds: Any) -> Any: ... - def min(self, *args: Any, **kwds: Any) -> Any: ... + _evaluated: Any + """Compliant or native value.""" + + @property + def version(self) -> Version: ... + @property + def name(self) -> str: ... + + @property + def native(self) -> NativeT_co: ... + + @classmethod + def from_native( + cls, native: Any, name: str = "", /, version: Version = Version.MAIN + ) -> Self: ... 
+ + def _with_native(self, native: Any, name: str = "", /) -> Self: + return self.from_native(native, name or self.name, self.version) + + # series & scalar + def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... + # series only (section 3) + def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: ... + def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: ... + def filter(self, node: expr.Filter, frame: FrameT_contra, name: str) -> Self: ... + # series -> scalar + def first( + self, node: agg.First, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def last( + self, node: agg.Last, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def arg_min( + self, node: agg.ArgMin, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def arg_max( + self, node: agg.ArgMax, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def sum( + self, node: agg.Sum, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def n_unique( + self, node: agg.NUnique, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def std( + self, node: agg.Std, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def var( + self, node: agg.Var, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def quantile( + self, node: agg.Quantile, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def count( + self, node: agg.Count, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def max( + self, node: agg.Max, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def mean( + self, node: agg.Mean, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def median( + self, node: agg.Median, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def min( + self, node: agg.Min, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + + +class CompliantScalar( + CompliantExpr[FrameT_contra, SeriesT_co, NativeT_co], + Protocol[FrameT_contra, SeriesT_co, NativeT_co], +): + _name: str + _version: Version + + @property + def name(self) -> str: + return self._name + + @property + def version(self) -> Version: + return self._version + + @classmethod + def from_python( + cls, + value: PythonLiteral, + name: str = "literal", + /, + *, + dtype: IntoDType | None, + version: Version, + ) -> Self: ... 
+ + def _with_evaluated(self, evaluated: Any, name: str = "") -> Self: + """Expr is based on a series having these via accessors, but a scalar needs to keep passing through.""" + cls = type(self) + obj = cls.__new__(cls) + obj._evaluated = evaluated + obj._name = name or self.name + obj._version = self.version + return obj + + def max(self, node: agg.Max, frame: FrameT_contra, name: str) -> Self: + """Returns self.""" + return self._with_evaluated(self._evaluated, name) + + def min(self, node: agg.Min, frame: FrameT_contra, name: str) -> Self: + """Returns self.""" + return self._with_evaluated(self._evaluated, name) + + def sum(self, node: agg.Sum, frame: FrameT_contra, name: str) -> Self: + """Returns self.""" + return self._with_evaluated(self._evaluated, name) + + def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self: + """Returns self.""" + return self._with_evaluated(self._evaluated, name) + + def last(self, node: agg.Last, frame: FrameT_contra, name: str) -> Self: + """Returns self.""" + return self._with_evaluated(self._evaluated, name) + + def _cast_float(self, node: ExprIR, frame: FrameT_contra, name: str) -> Self: + """`polars` interpolates a single scalar as a float.""" + dtype = self.version.dtypes.Float64() + return self.cast(node.cast(dtype), frame, name) + + def mean(self, node: agg.Mean, frame: FrameT_contra, name: str) -> Self: + return self._cast_float(node.expr, frame, name) + + def median(self, node: agg.Median, frame: FrameT_contra, name: str) -> Self: + return self._cast_float(node.expr, frame, name) + + def quantile(self, node: agg.Quantile, frame: FrameT_contra, name: str) -> Self: + return self._cast_float(node.expr, frame, name) + + def n_unique(self, node: agg.NUnique, frame: FrameT_contra, name: str) -> Self: + """Returns 1.""" + ... + + def std(self, node: agg.Std, frame: FrameT_contra, name: str) -> Self: + """Returns null.""" + ... + + def var(self, node: agg.Var, frame: FrameT_contra, name: str) -> Self: + """Returns null.""" + ... + + def arg_min(self, node: agg.ArgMin, frame: FrameT_contra, name: str) -> Self: + """Returns 0.""" + ... + + def arg_max(self, node: agg.ArgMax, frame: FrameT_contra, name: str) -> Self: + """Returns 0.""" + ... + + def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: + """Returns 0 if null, else 1.""" + ... + + def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated) + + def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated) + + # NOTE: `Filter` behaves the same, (maybe) no need to override + + +class EagerExpr( + EagerBroadcast[SeriesT], + CompliantExpr[FrameT_contra, SeriesT, NativeT_co], + Protocol[FrameT_contra, SeriesT, NativeT_co], +): ... + + +class LazyExpr( + SupportsBroadcast[SeriesT, LengthT], + CompliantExpr[FrameT_contra, SeriesT, NativeT_co], + Protocol[FrameT_contra, SeriesT, LengthT, NativeT_co], +): ... + + +class EagerScalar( + CompliantScalar[FrameT_contra, SeriesT, NativeT_co], + EagerExpr[FrameT_contra, SeriesT, NativeT_co], + Protocol[FrameT_contra, SeriesT, NativeT_co], +): + def __len__(self) -> int: + return 1 + + +class LazyScalar( + CompliantScalar[FrameT_contra, SeriesT, NativeT_co], + LazyExpr[FrameT_contra, SeriesT, LengthT, NativeT_co], + Protocol[FrameT_contra, SeriesT, LengthT, NativeT_co], +): ... 
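For reference, the length handling that `SupportsBroadcast.align` now delegates boils down to: broadcast only when some result is shorter than the longest one, and skip the work entirely when every length already matches. A minimal sketch of that decision with plain ints (illustration only, not the protocol itself):

    from __future__ import annotations

    from collections.abc import Sequence


    def length_required(lengths: Sequence[int]) -> int | None:
        # Mirrors `EagerBroadcast._length_required`: `None` means the fast
        # path, i.e. every result already has the target length.
        max_length = max(lengths)
        return max_length if any(n != max_length for n in lengths) else None


    assert length_required([3, 3, 3]) is None  # yield each series as-is
    assert length_required([3, 1, 3]) == 3     # broadcast the length-1 (scalar) results

Keeping the length generic as `LengthT` (pinned to `int` only in `EagerBroadcast`) presumably leaves room for lazy backends, where a concrete row count may not be available up front.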
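The `CompliantScalar` defaults are meant to line up with how `polars` handles aggregating a scalar/unit-length expression, as the docstrings above note: `sum`/`max`/`min`/`first`/`last` are no-ops, `mean`/`median`/`quantile` come back as floats, `std`/`var` are null, `n_unique` is 1, and `count` is 0 or 1 depending on validity. A rough cross-check against `polars` (assuming it is installed; the expected values in the comments are my reading of its behaviour, not verified here):

    import polars as pl

    out = pl.select(
        sum=pl.lit(2).sum(),            # expected: 2 (no-op)
        mean=pl.lit(2).mean(),          # expected: 2.0 (cast to float)
        std=pl.lit(2).std(),            # expected: null
        n_unique=pl.lit(2).n_unique(),  # expected: 1
    )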
From f2c6566e4fa3c8ec963d0a1a90a306b15bcceb5c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 11:54:15 +0100 Subject: [PATCH 277/368] fix: Fill in missing type params https://github.com/narwhals-dev/narwhals/actions/runs/16092633791/job/45411333375?pr=2572 --- narwhals/_plan/arrow/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index e914c6dbec..dbf721ba14 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -41,7 +41,7 @@ NativeScalar: TypeAlias = "pa.Scalar[Any]" -class ArrowExpr2(EagerExpr["ArrowDataFrame", ArrowSeries]): +class ArrowExpr2(EagerExpr["ArrowDataFrame", ArrowSeries, "ChunkedArrayAny"]): _evaluated: ArrowSeries @property @@ -160,7 +160,7 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: raise NotImplementedError -class ArrowScalar(EagerScalar["ArrowDataFrame", ArrowSeries]): +class ArrowScalar(EagerScalar["ArrowDataFrame", ArrowSeries, NativeScalar]): _name: str _version: Version _evaluated: NativeScalar From 423ea9a20c511a25c7f244eb306af5766160be47 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 14:30:28 +0100 Subject: [PATCH 278/368] refactor: Move `native` out of higher protocol --- narwhals/_plan/arrow/expr.py | 10 +++++++--- narwhals/_plan/protocols.py | 28 ++++++++++++---------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index dbf721ba14..4339aba2a6 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -11,7 +11,7 @@ from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, NamedIR, into_dtype from narwhals._plan.protocols import EagerBroadcast, EagerExpr, EagerScalar -from narwhals._utils import Version +from narwhals._utils import Version, _StoresNative from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: @@ -41,7 +41,9 @@ NativeScalar: TypeAlias = "pa.Scalar[Any]" -class ArrowExpr2(EagerExpr["ArrowDataFrame", ArrowSeries, "ChunkedArrayAny"]): +class ArrowExpr2( + _StoresNative["ChunkedArrayAny"], EagerExpr["ArrowDataFrame", ArrowSeries] +): _evaluated: ArrowSeries @property @@ -160,7 +162,9 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: raise NotImplementedError -class ArrowScalar(EagerScalar["ArrowDataFrame", ArrowSeries, NativeScalar]): +class ArrowScalar( + _StoresNative[NativeScalar], EagerScalar["ArrowDataFrame", ArrowSeries] +): _name: str _version: Version _evaluated: NativeScalar diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 3713808e89..02860eeccf 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -84,7 +84,7 @@ def _length_required( return max_length if required else None -class CompliantExpr(Protocol[FrameT_contra, SeriesT_co, NativeT_co]): +class CompliantExpr(Protocol[FrameT_contra, SeriesT_co]): """Getting a bit tricky, just storing notes. - Separating series/scalar makes a lot of sense @@ -104,9 +104,6 @@ def version(self) -> Version: ... @property def name(self) -> str: ... - @property - def native(self) -> NativeT_co: ... 
- @classmethod def from_native( cls, native: Any, name: str = "", /, version: Version = Version.MAIN @@ -167,8 +164,7 @@ def min( class CompliantScalar( - CompliantExpr[FrameT_contra, SeriesT_co, NativeT_co], - Protocol[FrameT_contra, SeriesT_co, NativeT_co], + CompliantExpr[FrameT_contra, SeriesT_co], Protocol[FrameT_contra, SeriesT_co] ): _name: str _version: Version @@ -270,29 +266,29 @@ def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: class EagerExpr( EagerBroadcast[SeriesT], - CompliantExpr[FrameT_contra, SeriesT, NativeT_co], - Protocol[FrameT_contra, SeriesT, NativeT_co], + CompliantExpr[FrameT_contra, SeriesT], + Protocol[FrameT_contra, SeriesT], ): ... class LazyExpr( SupportsBroadcast[SeriesT, LengthT], - CompliantExpr[FrameT_contra, SeriesT, NativeT_co], - Protocol[FrameT_contra, SeriesT, LengthT, NativeT_co], + CompliantExpr[FrameT_contra, SeriesT], + Protocol[FrameT_contra, SeriesT, LengthT], ): ... class EagerScalar( - CompliantScalar[FrameT_contra, SeriesT, NativeT_co], - EagerExpr[FrameT_contra, SeriesT, NativeT_co], - Protocol[FrameT_contra, SeriesT, NativeT_co], + CompliantScalar[FrameT_contra, SeriesT], + EagerExpr[FrameT_contra, SeriesT], + Protocol[FrameT_contra, SeriesT], ): def __len__(self) -> int: return 1 class LazyScalar( - CompliantScalar[FrameT_contra, SeriesT, NativeT_co], - LazyExpr[FrameT_contra, SeriesT, LengthT, NativeT_co], - Protocol[FrameT_contra, SeriesT, LengthT, NativeT_co], + CompliantScalar[FrameT_contra, SeriesT], + LazyExpr[FrameT_contra, SeriesT, LengthT], + Protocol[FrameT_contra, SeriesT, LengthT], ): ... From f47fed2da743b879ecf0ba388148dcd300e68caf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 14:38:33 +0100 Subject: [PATCH 279/368] refactor: Impl dispatch only once? - `singledispatchmethod` would need to be repeated across every class - Doing either this, or using `lambda`s (saw in `sqlglot`) should be less repetitive - typing isn't ideal, but marking as a `ClassVar` is more important --- narwhals/_plan/protocols.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 02860eeccf..77105e09ed 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,16 +1,16 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator, Sequence, Sized -from typing import TYPE_CHECKING, Any, Protocol +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, ClassVar, Protocol -from narwhals._plan.common import ExprIR, flatten_hash_safe +from narwhals._plan import aggregation as agg, expr +from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe from narwhals._typing_compat import TypeVar from narwhals._utils import Version if TYPE_CHECKING: from typing_extensions import Self, TypeAlias - from narwhals._plan import aggregation as agg, expr from narwhals.typing import IntoDType, PythonLiteral T = TypeVar("T") @@ -20,6 +20,7 @@ OneOrIterable: TypeAlias = "T | Iterable[T]" LengthT = TypeVar("LengthT") NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) +ExprAny: TypeAlias = "CompliantExpr[Any, Any]" class SupportsBroadcast(Protocol[SeriesT, LengthT]): @@ -162,6 +163,34 @@ def min( self, node: agg.Min, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... 
+ _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[..., ExprAny]]] = { + expr.Cast: cast, + expr.Sort: sort, + expr.SortBy: sort_by, + expr.Filter: filter, + agg.First: first, + agg.Last: last, + agg.ArgMin: arg_min, + agg.ArgMax: arg_max, + agg.Sum: sum, + agg.NUnique: n_unique, + agg.Std: std, + agg.Var: var, + agg.Quantile: quantile, + agg.Count: count, + agg.Max: max, + agg.Mean: mean, + agg.Median: median, + agg.Min: min, + } + + def _dispatch(self, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> ExprAny: + return self._dispatch_inner(named_ir.expr, frame, named_ir.name) + + def _dispatch_inner(self, node: ExprIR, frame: FrameT_contra, name: str) -> ExprAny: + method = self._DISPATCH[node.__class__] + return method(self, node, frame, name) + class CompliantScalar( CompliantExpr[FrameT_contra, SeriesT_co], Protocol[FrameT_contra, SeriesT_co] From a4b1a02ca6533e938d437eb8ede7f8c844c9b765 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 14:39:20 +0100 Subject: [PATCH 280/368] chore: planning `CompliantNamespace` --- narwhals/_plan/protocols.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 77105e09ed..b8237d4a56 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -16,11 +16,17 @@ T = TypeVar("T") SeriesT = TypeVar("SeriesT") SeriesT_co = TypeVar("SeriesT_co", covariant=True) +FrameT = TypeVar("FrameT") +FrameT_co = TypeVar("FrameT_co", covariant=True) FrameT_contra = TypeVar("FrameT_contra", contravariant=True) OneOrIterable: TypeAlias = "T | Iterable[T]" LengthT = TypeVar("LengthT") NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) ExprAny: TypeAlias = "CompliantExpr[Any, Any]" +ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" +ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) +ScalarT = TypeVar("ScalarT", bound="CompliantScalar[Any, Any]") +ScalarT_co = TypeVar("ScalarT_co", bound="CompliantScalar[Any, Any]", covariant=True) class SupportsBroadcast(Protocol[SeriesT, LengthT]): @@ -321,3 +327,30 @@ class LazyScalar( LazyExpr[FrameT_contra, SeriesT, LengthT], Protocol[FrameT_contra, SeriesT, LengthT], ): ... + + +class CompliantNamespace(Protocol[FrameT_co, SeriesT_co, ExprT_co, ScalarT_co]): + """Need to hold `Expr` and `Scalar` types outside of their defs. + + Likely, re-wrapping the output types will work like: + + + ns = DataFrame().__narwhals_namespace__() + if ns._expr.is_native(out): + ns._expr.from_native(out, ...) + elif ns._scalar.is_native(out): + ns._scalar.from_native(out, ...) + else: + assert_never(out) + + Currently that is causing issues in `ArrowExpr2._with_native` + """ + + @property + def _expr(self) -> type[ExprT_co]: ... + @property + def _scalar(self) -> type[ScalarT_co]: ... + @property + def _series(self) -> type[SeriesT_co]: ... + @property + def _dataframe(self) -> type[FrameT_co]: ... From 1aaca2af89e2aa8b198f39192f574c17a2f50e20 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 16:11:22 +0100 Subject: [PATCH 281/368] feat: `col`, `lit` classmethods? 
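`Column` and `Literal` are the leaf nodes a plan bottoms out in, so the backends need constructors for them rather than methods that transform an already-evaluated result. In raw `pyarrow` terms the two entry points reduce to roughly this (illustrative names only):

    import pyarrow as pa

    table = pa.table({"a": [1, 2, 3]})

    # `col`: pull the native column straight off the frame.
    chunked = table.column("a")

    # `lit`: wrap a plain Python value as an (optionally typed) scalar,
    # which is what the scalar wrapper then carries around.
    scalar = pa.scalar(1.5, pa.float64())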
--- narwhals/_plan/arrow/expr.py | 25 +++++++++++++++++++++++++ narwhals/_plan/protocols.py | 17 ++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 4339aba2a6..f6977603ce 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -10,6 +10,7 @@ from narwhals._plan import expr from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, NamedIR, into_dtype +from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import EagerBroadcast, EagerExpr, EagerScalar from narwhals._utils import Version, _StoresNative from narwhals.exceptions import InvalidOperationError, ShapeError @@ -66,6 +67,26 @@ def from_native( ) -> Self: return cls.from_series(ArrowSeries.from_native(native, name, version=version)) + @classmethod + def from_ir(cls, value: expr.Literal[DummySeries[ChunkedArrayAny]], /) -> Self: + nw_ser = value.unwrap() + return cls.from_native(nw_ser.to_native(), value.name, nw_ser.version) + + @classmethod + def col(cls, node: expr.Column, frame: ArrowDataFrame, name: str) -> Self: + return cls.from_native(frame.native.column(node.name), name) + + @classmethod + def lit( + cls, + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], + frame: ArrowDataFrame, # noqa: ARG003 + name: str, # noqa: ARG003 + ) -> ArrowScalar | Self: + if is_literal_scalar(node): + return ArrowScalar.from_ir(node) + return cls.from_ir(node) + def _with_native( self, result: ChunkedArrayAny | NativeScalar, name: str = "", / ) -> Self: @@ -219,6 +240,10 @@ def from_series(cls, series: ArrowSeries) -> Self: msg = f"Too long {len(series)!r}" raise InvalidOperationError(msg) + @classmethod + def from_ir(cls, value: expr.Literal[NonNestedLiteral], /) -> Self: + return cls.from_python(value.unwrap(), value.name, dtype=value.dtype) + @property def native(self) -> NativeScalar: return self._evaluated diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index b8237d4a56..1a339428ac 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -11,7 +11,8 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias - from narwhals.typing import IntoDType, PythonLiteral + from narwhals._plan.dummy import DummySeries + from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral T = TypeVar("T") SeriesT = TypeVar("SeriesT") @@ -119,6 +120,18 @@ def from_native( def _with_native(self, native: Any, name: str = "", /) -> Self: return self.from_native(native, name or self.name, self.version) + # entry points + @classmethod + def col(cls, node: expr.Column, frame: FrameT_contra, name: str) -> Self: ... + + @classmethod + def lit( + cls, + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[Any]], + frame: FrameT_contra, + name: str, + ) -> CompliantScalar[FrameT_contra, SeriesT_co] | Self: ... + # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... # series only (section 3) @@ -170,6 +183,8 @@ def min( ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... 
_DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[..., ExprAny]]] = { + expr.Column: col, + expr.Literal: lit, expr.Cast: cast, expr.Sort: sort, expr.SortBy: sort_by, From 1060f0e3c058fb1493cc74155ed205052b20429f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 21:36:14 +0100 Subject: [PATCH 282/368] =?UTF-8?q?feat(DRAFT):=20Dispatch=20take=20?= =?UTF-8?q?=E2=9C=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Still gotta copy over from `evaluate`, but tests are passing! --- narwhals/_plan/arrow/dataframe.py | 21 ++- narwhals/_plan/arrow/evaluate.py | 6 +- narwhals/_plan/arrow/expr.py | 228 +++++++----------------------- narwhals/_plan/arrow/namespace.py | 49 +++++++ narwhals/_plan/dummy.py | 18 ++- narwhals/_plan/protocols.py | 113 +++++++++------ tests/plan/to_compliant_test.py | 19 +-- 7 files changed, 204 insertions(+), 250 deletions(-) create mode 100644 narwhals/_plan/arrow/namespace.py diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 49e375f53e..337ac3e299 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -6,7 +6,6 @@ import pyarrow as pa # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype -from narwhals._plan.arrow.expr import ArrowExpr, ArrowLiteral from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.dummy import DummyCompliantFrame, DummyFrame from narwhals._utils import Version @@ -17,6 +16,7 @@ from typing_extensions import Self, TypeAlias, TypeIs from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny + from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR from narwhals.dtypes import DType from narwhals.schema import Schema @@ -30,17 +30,14 @@ def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: class ArrowDataFrame(DummyCompliantFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): - @property - def _series(self) -> type[ArrowSeries]: - return ArrowSeries + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._plan.arrow.namespace import ArrowNamespace - @property - def _expr(self) -> type[ArrowExpr]: - return ArrowExpr + return ArrowNamespace(self._version) @property - def _lit(self) -> type[ArrowLiteral]: - return ArrowLiteral + def _series(self) -> type[ArrowSeries]: + return ArrowSeries @property def columns(self) -> list[str]: @@ -103,6 +100,6 @@ def to_dict( return {ser.name: ser.to_list() for ser in it} def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSeries]: - from narwhals._plan.arrow.evaluate import evaluate - - yield from self._expr.align(evaluate(e, self) for e in nodes) + ns = self.__narwhals_namespace__() + from_named_ir = ns._expr.from_named_ir + yield from ns._expr.align(from_named_ir(e, self) for e in nodes) diff --git a/narwhals/_plan/arrow/evaluate.py b/narwhals/_plan/arrow/evaluate.py index a408db7158..d57de4c0b6 100644 --- a/narwhals/_plan/arrow/evaluate.py +++ b/narwhals/_plan/arrow/evaluate.py @@ -1,4 +1,4 @@ -"""Translating `ExprIR` nodes for pyarrow.""" +"""TODO: Move all the impls `ArrowExpr`/`ArrowScalar`, then delete.""" from __future__ import annotations @@ -39,8 +39,8 @@ def is_scalar(obj: t.Any) -> TypeIs[ScalarAny]: def evaluate(node: NamedIR[ExprIR], frame: ArrowDataFrame) -> EagerBroadcast[ArrowSeries]: result = _evaluate_inner(node.expr, frame) if is_scalar(result): - return 
frame._lit.from_scalar(result, node.name) - return frame._expr.from_native(result, node.name) + return frame.__narwhals_namespace__()._scalar.from_native(result, node.name) + return frame.__narwhals_namespace__()._expr.from_native(result, node.name) # NOTE: Should mean we produce 1x CompliantSeries for the entire expression diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index f6977603ce..5bb1bd77a1 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,24 +1,23 @@ from __future__ import annotations -from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import chunked_array, narwhals_to_native_dtype -from narwhals._plan import expr from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.common import ExprIR, NamedIR, into_dtype +from narwhals._plan.common import into_dtype from narwhals._plan.literal import is_literal_scalar -from narwhals._plan.protocols import EagerBroadcast, EagerExpr, EagerScalar +from narwhals._plan.protocols import Dispatch, EagerExpr, EagerScalar from narwhals._utils import Version, _StoresNative from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: from typing_extensions import Self, TypeAlias - from narwhals._arrow.typing import ChunkedArrayAny, Incomplete, ScalarAny + from narwhals._arrow.typing import ChunkedArrayAny, Incomplete + from narwhals._plan import expr from narwhals._plan.aggregation import ( ArgMax, ArgMin, @@ -42,8 +41,10 @@ NativeScalar: TypeAlias = "pa.Scalar[Any]" -class ArrowExpr2( - _StoresNative["ChunkedArrayAny"], EagerExpr["ArrowDataFrame", ArrowSeries] +class ArrowExpr( + Dispatch["ArrowDataFrame", "ArrowExpr | ArrowScalar"], + _StoresNative["ChunkedArrayAny"], + EagerExpr["ArrowDataFrame", ArrowSeries], ): _evaluated: ArrowSeries @@ -51,10 +52,6 @@ class ArrowExpr2( def name(self) -> str: return self._evaluated.name - @property - def version(self) -> Version: - return self._evaluated.version - @classmethod def from_series(cls, series: ArrowSeries, /) -> Self: obj = cls.__new__(cls) @@ -68,32 +65,37 @@ def from_native( return cls.from_series(ArrowSeries.from_native(native, name, version=version)) @classmethod - def from_ir(cls, value: expr.Literal[DummySeries[ChunkedArrayAny]], /) -> Self: + def from_ir( + cls, value: expr.Literal[DummySeries[ChunkedArrayAny]], name: str = "", / + ) -> Self: nw_ser = value.unwrap() - return cls.from_native(nw_ser.to_native(), value.name, nw_ser.version) + return cls.from_native(nw_ser.to_native(), name or value.name, nw_ser.version) - @classmethod - def col(cls, node: expr.Column, frame: ArrowDataFrame, name: str) -> Self: - return cls.from_native(frame.native.column(node.name), name) + def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> Self: + return self.from_native(frame.native.column(node.name), name) - @classmethod def lit( - cls, + self, node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], - frame: ArrowDataFrame, # noqa: ARG003 - name: str, # noqa: ARG003 + name: str, ) -> ArrowScalar | Self: if is_literal_scalar(node): - return ArrowScalar.from_ir(node) - return cls.from_ir(node) - + return ArrowScalar.from_ir(node, name) + return self.from_ir(node, name) + + @overload + def _with_native(self, result: ChunkedArrayAny, name: str = ..., /) -> Self: ... 
+ @overload + def _with_native(self, result: NativeScalar, name: str = ..., /) -> ArrowScalar: ... + @overload + def _with_native( + self, result: ChunkedArrayAny | NativeScalar, name: str = ..., / + ) -> ArrowScalar | Self: ... def _with_native( self, result: ChunkedArrayAny | NativeScalar, name: str = "", / - ) -> Self: + ) -> ArrowScalar | Self: if isinstance(result, pa.Scalar): - # NOTE: Will need to resolve this eventually - # Currently the *least bad* option is the single ignore here - return ArrowScalar.from_native(result, name, version=self.version) # type: ignore[return-value] + return ArrowScalar.from_native(result, name, version=self.version) return super()._with_native(result, name) @property @@ -112,36 +114,28 @@ def broadcast(self, length: int, /) -> ArrowSeries: def __len__(self) -> int: return len(self._evaluated) - # NOTE: Dispatch is on `ExprIR`, which is recursive - # There is only a top-level `NamedIR` per column - def evaluate(self, named_ir: NamedIR[ExprIR], frame: ArrowDataFrame) -> ArrowExpr2: - return self._evaluate_inner(named_ir.expr, frame, named_ir.name) - - # NOTE: Don't use `Self`, it breaks the descriptor typing - # The implementations *can* use `Self`, just not here - @singledispatchmethod - def _evaluate_inner( - self, node: ExprIR, frame: ArrowDataFrame, name: str - ) -> ArrowExpr2: - raise NotImplementedError(type(node)) - - @_evaluate_inner.register(expr.Cast) - def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> Self: + def cast( # type: ignore[override] + self, node: expr.Cast, frame: ArrowDataFrame, name: str + ) -> ArrowScalar | Self: data_type = narwhals_to_native_dtype(node.dtype, frame.version) - native = self._evaluate_inner(node.expr, frame, name).native + native = self._dispatch(node.expr, frame, name).native return self._with_native(pc.cast(native, data_type), name) - def sort(self, node: expr.Sort, frame: ArrowDataFrame, name: str) -> ArrowExpr2: + def sort(self, node: expr.Sort, frame: ArrowDataFrame, name: str) -> ArrowExpr: raise NotImplementedError - def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowExpr2: + def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowExpr: raise NotImplementedError - def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowExpr2: + def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowExpr: raise NotImplementedError def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + native = self._dispatch(node.expr, frame, name).to_series().native + result: NativeScalar = ( + native[0] if (len(native)) else pa.scalar(None, native.type) + ) + return self._with_native(result, name) def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: raise NotImplementedError @@ -171,7 +165,10 @@ def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: raise NotImplementedError def max(self, node: Max, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result: NativeScalar = pc.max( + self._dispatch(node.expr, frame, name).to_series().native + ) + return self._with_native(result, name) def mean(self, node: Mean, frame: ArrowDataFrame, name: str) -> ArrowScalar: raise NotImplementedError @@ -184,10 +181,11 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: class ArrowScalar( - _StoresNative[NativeScalar], EagerScalar["ArrowDataFrame", ArrowSeries] + 
Dispatch["ArrowDataFrame", "ArrowScalar"], + _StoresNative[NativeScalar], + EagerScalar["ArrowDataFrame", ArrowSeries], ): _name: str - _version: Version _evaluated: NativeScalar @property @@ -241,8 +239,8 @@ def from_series(cls, series: ArrowSeries) -> Self: raise InvalidOperationError(msg) @classmethod - def from_ir(cls, value: expr.Literal[NonNestedLiteral], /) -> Self: - return cls.from_python(value.unwrap(), value.name, dtype=value.dtype) + def from_ir(cls, value: expr.Literal[NonNestedLiteral], name: str, /) -> Self: + return cls.from_python(value.unwrap(), name, dtype=value.dtype) @property def native(self) -> NativeScalar: @@ -262,21 +260,9 @@ def broadcast(self, length: int) -> ArrowSeries: chunked = chunked_array(pa_repeat(scalar, length)) return ArrowSeries.from_native(chunked, self.name, version=self.version) - # NOTE: Dispatch is on `ExprIR`, which is recursive - # There is only a top-level `NamedIR` per column - def evaluate(self, named_ir: NamedIR[ExprIR], frame: ArrowDataFrame) -> ArrowScalar: - return self._evaluate_inner(named_ir.expr, frame, named_ir.name) - - @singledispatchmethod - def _evaluate_inner( - self, node: ExprIR, frame: ArrowDataFrame, name: str - ) -> ArrowScalar: - raise NotImplementedError(type(node)) - - @_evaluate_inner.register(expr.Cast) def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> ArrowScalar: data_type = narwhals_to_native_dtype(node.dtype, frame.version) - native = self._evaluate_inner(node.expr, frame, name).native + native = self._dispatch(node.expr, frame, name).native return self._with_native(pc.cast(native, data_type), name) def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> Any: @@ -298,113 +284,5 @@ def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: return self._with_native(pa.scalar(None, pa.null()), name) def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: - native = self._evaluate_inner(node.expr, frame, name).native + native = self._dispatch(node.expr, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) - - -# NOTE: General expression result -# Mostly elementwise -class ArrowExpr(EagerBroadcast[ArrowSeries]): - _compliant: ArrowSeries - - @classmethod - def from_series(cls, series: ArrowSeries) -> Self: - obj = cls.__new__(cls) - obj._compliant = series - return obj - - @classmethod - def from_native( - cls, - native: ChunkedArrayAny, - name: str = "", - /, - *, - version: Version = Version.MAIN, - ) -> Self: - return cls.from_series(ArrowSeries.from_native(native, name, version=version)) - - @classmethod - def from_ir(cls, value: expr.Literal[DummySeries[ChunkedArrayAny]], /) -> Self: - return cls.from_native(value.unwrap().to_native(), value.name) - - def to_series(self) -> ArrowSeries: - return self._compliant - - def __len__(self) -> int: - return len(self._compliant) - - def broadcast(self, length: int, /) -> ArrowSeries: - if (actual_len := len(self)) != length: - msg = f"Expected object of length {length}, got {actual_len}." 
- raise ShapeError(msg) - return self._compliant - - -# NOTE: Aggregation result or scalar -# Should handle broadcasting, without exposing it -class ArrowLiteral(EagerBroadcast[ArrowSeries]): - _native_scalar: ScalarAny - _name: str - - @property - def name(self) -> str: - return self._name - - def __len__(self) -> int: - return 1 - - def broadcast(self, length: int, /) -> ArrowSeries: - if length == 1: - chunked = chunked_array([[self._native_scalar]]) - else: - # NOTE: Same issue as `pa.scalar` overlapping overloads - # https://github.com/zen-xu/pyarrow-stubs/pull/209 - pa_repeat: Incomplete = pa.repeat - arr = pa_repeat(self._native_scalar, length) - chunked = chunked_array(arr) - return ArrowSeries.from_native(chunked, self.name) - - @classmethod - def from_series(cls, series: ArrowSeries) -> Self: - if len(series) == 1: - return cls.from_scalar(series.native[0], series.name) - elif len(series) == 0: - return cls.from_python(None, series.name, dtype=series.dtype) - else: - msg = f"Too long {len(series)!r}" - raise InvalidOperationError(msg) - - def to_series(self) -> ArrowSeries: - return self.broadcast(1) - - @classmethod - def from_python( - cls, - value: PythonLiteral, - name: str = "literal", - /, - *, - dtype: IntoDType | None = None, - ) -> Self: - version = Version.MAIN - dtype_pa: pa.DataType | None = None - if dtype: - dtype = into_dtype(dtype) - if not isinstance(dtype, version.dtypes.Unknown): - dtype_pa = narwhals_to_native_dtype(dtype, version) - # NOTE: PR that fixed this was closed - # https://github.com/zen-xu/pyarrow-stubs/pull/208 - lit: Incomplete = pa.scalar - return cls.from_scalar(lit(value, dtype_pa), name) - - @classmethod - def from_scalar(cls, scalar: ScalarAny, name: str = "literal", /) -> Self: - obj = cls.__new__(cls) - obj._native_scalar = scalar - obj._name = name - return obj - - @classmethod - def from_ir(cls, value: expr.Literal[NonNestedLiteral], /) -> Self: - return cls.from_python(value.unwrap(), value.name) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py new file mode 100644 index 0000000000..9a212c2179 --- /dev/null +++ b/narwhals/_plan/arrow/namespace.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from narwhals._plan.protocols import EagerNamespace +from narwhals._utils import Version + +if TYPE_CHECKING: + from narwhals._plan.arrow.dataframe import ArrowDataFrame + from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar + from narwhals._plan.arrow.series import ArrowSeries + from narwhals._plan.common import ExprIR, NamedIR + + +class ArrowNamespace( + EagerNamespace["ArrowDataFrame", "ArrowSeries", "ArrowExpr", "ArrowScalar"] +): + def __init__(self, version: Version = Version.MAIN) -> None: + self._version = version + + @property + def _expr(self) -> type[ArrowExpr]: + from narwhals._plan.arrow.expr import ArrowExpr + + return ArrowExpr + + @property + def _scalar(self) -> type[ArrowScalar]: + from narwhals._plan.arrow.expr import ArrowScalar + + return ArrowScalar + + @property + def _series(self) -> type[ArrowSeries]: + from narwhals._plan.arrow.series import ArrowSeries + + return ArrowSeries + + @property + def _dataframe(self) -> type[ArrowDataFrame]: + from narwhals._plan.arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame + + def dispatch_expr(self, named_ir: NamedIR[ExprIR], frame: ArrowDataFrame) -> Any: + return self._expr.from_named_ir(named_ir, frame) + + +ArrowNamespace() diff --git a/narwhals/_plan/dummy.py 
b/narwhals/_plan/dummy.py index 5b7e9f2918..bcf5b6eee0 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -851,12 +851,24 @@ def to_dict( def __len__(self) -> int: return len(self._compliant) - def select(self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any) -> Self: + def _project( + self, + exprs: tuple[IntoExpr | Iterable[IntoExpr], ...], + named_exprs: dict[str, t.Any], + context: ExprContext, + /, + ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: + """Temp, while these parts aren't connected, this is easier for testing.""" irs, schema_frozen, output_names = expr_expansion.prepare_projection( parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema ) named_irs = expr_expansion.into_named_irs(irs, output_names) - named_irs, schema_projected = schema_frozen.project(named_irs, ExprContext.SELECT) + return schema_frozen.project(named_irs, context) + + def select(self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any) -> Self: + named_irs, schema_projected = self._project( + exprs, named_exprs, ExprContext.SELECT + ) return self._from_compliant(self._compliant.select(named_irs, schema_projected)) @@ -864,6 +876,8 @@ class DummyCompliantFrame(Generic[CompliantSeriesT, NativeFrameT, NativeSeriesT] _native: NativeFrameT _version: Version + def __narwhals_namespace__(self) -> t.Any: ... + @property def native(self) -> NativeFrameT: return self._native diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 1a339428ac..f90df417af 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -6,7 +6,7 @@ from narwhals._plan import aggregation as agg, expr from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe from narwhals._typing_compat import TypeVar -from narwhals._utils import Version +from narwhals._utils import Version, _StoresVersion if TYPE_CHECKING: from typing_extensions import Self, TypeAlias @@ -15,11 +15,12 @@ from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral T = TypeVar("T") +R_co = TypeVar("R_co", covariant=True) SeriesT = TypeVar("SeriesT") SeriesT_co = TypeVar("SeriesT_co", covariant=True) FrameT = TypeVar("FrameT") FrameT_co = TypeVar("FrameT_co", covariant=True) -FrameT_contra = TypeVar("FrameT_contra", contravariant=True) +FrameT_contra = TypeVar("FrameT_contra", bound="_StoresVersion", contravariant=True) OneOrIterable: TypeAlias = "T | Iterable[T]" LengthT = TypeVar("LengthT") NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) @@ -28,6 +29,12 @@ ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) ScalarT = TypeVar("ScalarT", bound="CompliantScalar[Any, Any]") ScalarT_co = TypeVar("ScalarT_co", bound="CompliantScalar[Any, Any]", covariant=True) +IntoSeriesT_co = TypeVar("IntoSeriesT_co", bound="ExprAny | ScalarAny", covariant=True) + +EagerExprT_co = TypeVar("EagerExprT_co", bound="EagerExpr[Any, Any]", covariant=True) +EagerScalarT_co = TypeVar( + "EagerScalarT_co", bound="EagerScalar[Any, Any]", covariant=True +) class SupportsBroadcast(Protocol[SeriesT, LengthT]): @@ -107,8 +114,12 @@ class CompliantExpr(Protocol[FrameT_contra, SeriesT_co]): _evaluated: Any """Compliant or native value.""" + _version: Version + @property - def version(self) -> Version: ... + def version(self) -> Version: + return self._version + @property def name(self) -> str: ... 
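Worth spelling out the dispatch shape this patch settles on: instead of registering a `singledispatchmethod` on every backend class, one class-level table maps each `ExprIR` node type to a small lambda that forwards to the like-named method, and `from_named_ir` is the single entry point that unwraps a `NamedIR`. A stripped-down sketch of the pattern (toy names, not the real classes):

    from __future__ import annotations

    from collections.abc import Callable, Mapping
    from typing import Any, ClassVar


    class Node: ...
    class Cast(Node): ...
    class Sum(Node): ...


    class Backend:
        # One shared, type-keyed table instead of re-registering a
        # `singledispatchmethod` on every subclass.
        _DISPATCH: ClassVar[Mapping[type[Node], Callable[[Any, Node], Any]]] = {
            Cast: lambda self, node: self.cast(node),
            Sum: lambda self, node: self.sum(node),
        }

        def dispatch(self, node: Node) -> Any:
            return self._DISPATCH[type(node)](self, node)

        def cast(self, node: Cast) -> str:
            return "cast"

        def sum(self, node: Sum) -> str:
            return "sum"


    assert Backend().dispatch(Sum()) == "sum"

The lambdas presumably sidestep the unbound-method typing issues of storing the methods directly, at the cost of looking nodes up by exact class rather than via `isinstance`.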
@@ -120,15 +131,10 @@ def from_native( def _with_native(self, native: Any, name: str = "", /) -> Self: return self.from_native(native, name or self.name, self.version) - # entry points - @classmethod - def col(cls, node: expr.Column, frame: FrameT_contra, name: str) -> Self: ... - - @classmethod + def col(self, node: expr.Column, frame: FrameT_contra, name: str) -> Self: ... def lit( - cls, + self, node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[Any]], - frame: FrameT_contra, name: str, ) -> CompliantScalar[FrameT_contra, SeriesT_co] | Self: ... @@ -182,35 +188,44 @@ def min( self, node: agg.Min, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[..., ExprAny]]] = { - expr.Column: col, - expr.Literal: lit, - expr.Cast: cast, - expr.Sort: sort, - expr.SortBy: sort_by, - expr.Filter: filter, - agg.First: first, - agg.Last: last, - agg.ArgMin: arg_min, - agg.ArgMax: arg_max, - agg.Sum: sum, - agg.NUnique: n_unique, - agg.Std: std, - agg.Var: var, - agg.Quantile: quantile, - agg.Count: count, - agg.Max: max, - agg.Mean: mean, - agg.Median: median, - agg.Min: min, - } - def _dispatch(self, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> ExprAny: - return self._dispatch_inner(named_ir.expr, frame, named_ir.name) +class Dispatch(Protocol[FrameT_contra, R_co]): + _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = { + expr.Column: lambda self, node, frame, name: self.col(node, frame, name), + expr.Literal: lambda self, node, _, name: self.lit(node, name), + expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name), + expr.Sort: lambda self, node, frame, name: self.sort(node, frame, name), + expr.SortBy: lambda self, node, frame, name: self.sort_by(node, frame, name), + expr.Filter: lambda self, node, frame, name: self.filter(node, frame, name), + agg.First: lambda self, node, frame, name: self.first(node, frame, name), + agg.Last: lambda self, node, frame, name: self.last(node, frame, name), + agg.ArgMin: lambda self, node, frame, name: self.arg_min(node, frame, name), + agg.ArgMax: lambda self, node, frame, name: self.arg_max(node, frame, name), + agg.Sum: lambda self, node, frame, name: self.sum(node, frame, name), + agg.NUnique: lambda self, node, frame, name: self.n_unique(node, frame, name), + agg.Std: lambda self, node, frame, name: self.std(node, frame, name), + agg.Var: lambda self, node, frame, name: self.var(node, frame, name), + agg.Quantile: lambda self, node, frame, name: self.quantile(node, frame, name), + agg.Count: lambda self, node, frame, name: self.count(node, frame, name), + agg.Max: lambda self, node, frame, name: self.max(node, frame, name), + agg.Mean: lambda self, node, frame, name: self.mean(node, frame, name), + agg.Median: lambda self, node, frame, name: self.median(node, frame, name), + agg.Min: lambda self, node, frame, name: self.min(node, frame, name), + } + _version: Version - def _dispatch_inner(self, node: ExprIR, frame: FrameT_contra, name: str) -> ExprAny: + def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: method = self._DISPATCH[node.__class__] - return method(self, node, frame, name) + return method(self, node, frame, name) # type: ignore[no-any-return] + + @classmethod + def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: + node = named_ir.expr + name = named_ir.name + method = cls._DISPATCH[node.__class__] + obj = cls.__new__(cls) + obj._version = 
frame._version + return method(obj, node, frame, name) # type: ignore[no-any-return] class CompliantScalar( @@ -238,7 +253,7 @@ def from_python( version: Version, ) -> Self: ... - def _with_evaluated(self, evaluated: Any, name: str = "") -> Self: + def _with_evaluated(self, evaluated: Any, name: str) -> Self: """Expr is based on a series having these via accessors, but a scalar needs to keep passing through.""" cls = type(self) obj = cls.__new__(cls) @@ -306,10 +321,10 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: ... def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated) + return self._with_evaluated(self._evaluated, name) def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated) + return self._with_evaluated(self._evaluated, name) # NOTE: `Filter` behaves the same, (maybe) no need to override @@ -357,15 +372,25 @@ class CompliantNamespace(Protocol[FrameT_co, SeriesT_co, ExprT_co, ScalarT_co]): ns._scalar.from_native(out, ...) else: assert_never(out) - - Currently that is causing issues in `ArrowExpr2._with_native` """ + _version: Version + + @property + def _dataframe(self) -> type[FrameT_co]: ... + @property + def _series(self) -> type[SeriesT_co]: ... @property def _expr(self) -> type[ExprT_co]: ... @property def _scalar(self) -> type[ScalarT_co]: ... + @property - def _series(self) -> type[SeriesT_co]: ... - @property - def _dataframe(self) -> type[FrameT_co]: ... + def version(self) -> Version: + return self._version + + +class EagerNamespace( + CompliantNamespace[FrameT_co, SeriesT_co, EagerExprT_co, EagerScalarT_co], + Protocol[FrameT_co, SeriesT_co, EagerExprT_co, EagerScalarT_co], +): ... diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 438b7970a0..dd3c4f0cea 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -45,25 +45,16 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: assert isinstance(compliant_expr, namespace._expr) -XFAIL_BROADCAST = pytest.mark.xfail( - reason="Shouldn't broadcast when all Series are length 1." 
-) - - @pytest.mark.parametrize( ("expr", "expected"), [ (nwd.col("a"), {"a": ["A", "B", "A"]}), (nwd.col("a", "b"), {"a": ["A", "B", "A"], "b": [1, 2, 3]}), - pytest.param(nwd.lit(1), {"literal": [1]}, marks=XFAIL_BROADCAST), - pytest.param(nwd.lit(2.0), {"literal": [2.0]}, marks=XFAIL_BROADCAST), - pytest.param( - nwd.lit(None, nw.String()), {"literal": [None]}, marks=XFAIL_BROADCAST - ), - pytest.param( - nwd.col("a", "b").first(), {"a": ["A"], "b": [1]}, marks=XFAIL_BROADCAST - ), - pytest.param(nwd.col("d").max(), {"d": [8]}, marks=XFAIL_BROADCAST), + (nwd.lit(1), {"literal": [1]}), + (nwd.lit(2.0), {"literal": [2.0]}), + (nwd.lit(None, nw.String()), {"literal": [None]}), + (nwd.col("a", "b").first(), {"a": ["A"], "b": [1]}), + (nwd.col("d").max(), {"d": [8]}), ], ids=_ids_ir, ) From d6ebf9b346ae1bffe3f6ed3f05dce2dc0d2266c5 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 21:38:57 +0100 Subject: [PATCH 283/368] Update narwhals/_plan/dummy.py --- narwhals/_plan/dummy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index bcf5b6eee0..8cf118983f 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -1050,4 +1050,4 @@ def __len__(self) -> int: return len(self.native) def to_list(self) -> list[t.Any]: - raise NotADirectoryError + raise NotImplementedError From 33ddb8d92bb286fe1498fd0276d5e4a5d9635c22 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 22:07:38 +0100 Subject: [PATCH 284/368] maybe `pyarrow` backcompat? https://github.com/narwhals-dev/narwhals/actions/runs/16102965010/job/45434495395?pr=2572 --- narwhals/_plan/arrow/expr.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 5bb1bd77a1..5db428ca23 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -132,9 +132,7 @@ def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowEx def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: native = self._dispatch(node.expr, frame, name).to_series().native - result: NativeScalar = ( - native[0] if (len(native)) else pa.scalar(None, native.type) - ) + result = lit(native[0]) if len(native) else lit(None, native.type) return self._with_native(result, name) def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: @@ -180,6 +178,15 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: raise NotImplementedError +def lit(value: Any, dtype: pa.DataType | None = None) -> NativeScalar: + # NOTE: Needed for `pyarrow<13` + if isinstance(value, pa.Scalar): + return value + # NOTE: PR that fixed this the overloads was closed + # https://github.com/zen-xu/pyarrow-stubs/pull/208 + return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) + + class ArrowScalar( Dispatch["ArrowDataFrame", "ArrowScalar"], _StoresNative[NativeScalar], @@ -221,9 +228,6 @@ def from_python( dtype = into_dtype(dtype) if not isinstance(dtype, version.dtypes.Unknown): dtype_pa = narwhals_to_native_dtype(dtype, version) - # NOTE: PR that fixed this was closed - # https://github.com/zen-xu/pyarrow-stubs/pull/208 - lit: Incomplete = pa.scalar return cls.from_native(lit(value, dtype_pa), name, version) @classmethod From 7aa7d1dbfe6bbea2c352d1f676021692541088da Mon Sep 17 00:00:00 2001 From: 
dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 22:18:01 +0100 Subject: [PATCH 285/368] is `len` the issue? https://github.com/narwhals-dev/narwhals/actions/runs/16103181063/job/45434983508?pr=2572 --- narwhals/_plan/arrow/expr.py | 5 +++-- narwhals/_plan/arrow/series.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 5db428ca23..6f25baafed 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -131,8 +131,9 @@ def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowEx raise NotImplementedError def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: - native = self._dispatch(node.expr, frame, name).to_series().native - result = lit(native[0]) if len(native) else lit(None, native.type) + prev = self._dispatch(node.expr, frame, name) + native = prev.to_series().native + result = lit(native[0]) if len(prev) else lit(None, native.type) return self._with_native(result, name) def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index f8189b848d..1692c459ce 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -14,6 +14,9 @@ class ArrowSeries(DummyCompliantSeries["ChunkedArrayAny"]): def to_list(self) -> list[Any]: return self.native.to_pylist() + def __len__(self) -> int: + return self.native.length() + @property def dtype(self) -> DType: return native_to_narwhals_dtype(self.native.type, self._version) From eebef7a4d8c0448fa1868ecb0f80f256e02d3edb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 6 Jul 2025 22:26:32 +0100 Subject: [PATCH 286/368] plz https://github.com/narwhals-dev/narwhals/actions/runs/16103277767/job/45435186586?pr=2572 --- narwhals/_plan/arrow/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 6f25baafed..cae070d4e8 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -257,7 +257,7 @@ def to_series(self) -> ArrowSeries: def broadcast(self, length: int) -> ArrowSeries: scalar = self.native if length == 1: - chunked = chunked_array([[scalar]]) + chunked = chunked_array(pa.array([scalar])) else: # NOTE: Same issue as `pa.scalar` overlapping overloads # https://github.com/zen-xu/pyarrow-stubs/pull/209 From acdbf5e193525a570af2e5a4b225024ae5650a39 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 11:37:22 +0100 Subject: [PATCH 287/368] revert: remove typing check --- narwhals/_plan/arrow/namespace.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 9a212c2179..90c05d124f 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -44,6 +44,3 @@ def _dataframe(self) -> type[ArrowDataFrame]: def dispatch_expr(self, named_ir: NamedIR[ExprIR], frame: ArrowDataFrame) -> Any: return self._expr.from_named_ir(named_ir, frame) - - -ArrowNamespace() From 443786100f3ac3b78c8910d6a38036e5b4020192 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 12:28:16 +0100 Subject: [PATCH 288/368] fix: Unwrap scalar on old pyarrow 
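The incompatibility being worked around: on older `pyarrow` (pre-13, going by the version gate added below), `pa.array([...])` doesn't accept `pa.Scalar` elements, so the scalar has to be unwrapped to a plain Python value first. Roughly (illustrative only):

    import pyarrow as pa

    value = pa.scalar(1, pa.int64())

    # pyarrow >= 13 can build an array from a list of scalars directly ...
    arr = pa.array([value], value.type)

    # ... while older releases need the unwrapped Python value instead.
    arr_compat = pa.array([value.as_py()], value.type)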
https://github.com/narwhals-dev/narwhals/pull/2572#issuecomment-3042716206 --- narwhals/_plan/arrow/expr.py | 39 +++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index cae070d4e8..b1ad8b4a17 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,22 +1,32 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import -from narwhals._arrow.utils import chunked_array, narwhals_to_native_dtype +from narwhals._arrow.utils import ( + chunked_array as _chunked_array, + narwhals_to_native_dtype, +) from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import into_dtype from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import Dispatch, EagerExpr, EagerScalar -from narwhals._utils import Version, _StoresNative +from narwhals._utils import Implementation, Version, _StoresNative from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: + from collections.abc import Iterable + from typing_extensions import Self, TypeAlias - from narwhals._arrow.typing import ChunkedArrayAny, Incomplete + from narwhals._arrow.typing import ( + ArrayAny, + ArrayOrScalar, + ChunkedArrayAny, + Incomplete, + ) from narwhals._plan import expr from narwhals._plan.aggregation import ( ArgMax, @@ -40,6 +50,8 @@ NativeScalar: TypeAlias = "pa.Scalar[Any]" +BACKEND_VERSION = Implementation.PYARROW._backend_version() + class ArrowExpr( Dispatch["ArrowDataFrame", "ArrowExpr | ArrowScalar"], @@ -188,6 +200,23 @@ def lit(value: Any, dtype: pa.DataType | None = None) -> NativeScalar: return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) +# NOTE: https://github.com/apache/arrow/issues/21761 +# fmt: off +if BACKEND_VERSION >= (13,): + def array(value: NativeScalar) -> ArrayAny: + return pa.array([value], value.type) +else: + def array(value: NativeScalar) -> ArrayAny: + return cast("ArrayAny", pa.array([value.as_py()], value.type)) +# fmt: on + + +def chunked_array( + arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, / +) -> ChunkedArrayAny: + return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) + + class ArrowScalar( Dispatch["ArrowDataFrame", "ArrowScalar"], _StoresNative[NativeScalar], @@ -257,7 +286,7 @@ def to_series(self) -> ArrowSeries: def broadcast(self, length: int) -> ArrowSeries: scalar = self.native if length == 1: - chunked = chunked_array(pa.array([scalar])) + chunked = chunked_array(scalar) else: # NOTE: Same issue as `pa.scalar` overlapping overloads # https://github.com/zen-xu/pyarrow-stubs/pull/209 From bcdca6ebce602a77d7400bde06037873ba1a8ef4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 12:38:37 +0100 Subject: [PATCH 289/368] fix: Use new `Interval` helper in `truncate` Added in #2733 --- narwhals/_plan/temporal.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index d0b4a2cd1c..6069d95a6a 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from typing_extensions import TypeAlias, TypeIs - from narwhals._duration import IntervalUnit + from 
narwhals._duration import Interval, IntervalUnit from narwhals._plan.dummy import DummyExpr from narwhals.typing import TimeUnit @@ -152,10 +152,13 @@ class Truncate(TemporalFunction): @staticmethod def from_string(every: str, /) -> Truncate: - from narwhals._duration import parse_interval_string + from narwhals._duration import Interval - multiple, unit = parse_interval_string(every) - return Truncate(multiple=multiple, unit=unit) + return Truncate.from_interval(Interval.parse(every)) + + @staticmethod + def from_interval(every: Interval, /) -> Truncate: + return Truncate(multiple=every.multiple, unit=every.unit) class IRDateTimeNamespace(IRNamespace): From 96b0c9c9df6cfa1e1d34ff2784dbfc241bb301ee Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 16:47:22 +0100 Subject: [PATCH 290/368] feat(pyarrow): Impl `len` --- narwhals/_plan/arrow/expr.py | 66 +++++++++++++++++++++--------------- narwhals/_plan/protocols.py | 26 ++++++++------ 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index b1ad8b4a17..d73c63248e 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -13,7 +13,7 @@ from narwhals._plan.common import into_dtype from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import Dispatch, EagerExpr, EagerScalar -from narwhals._utils import Implementation, Version, _StoresNative +from narwhals._utils import Implementation, Version, _StoresNative, not_implemented from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: @@ -46,6 +46,7 @@ ) from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.dummy import DummySeries + from narwhals._plan.expr import Len from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral NativeScalar: TypeAlias = "pa.Scalar[Any]" @@ -59,6 +60,7 @@ class ArrowExpr( EagerExpr["ArrowDataFrame", ArrowSeries], ): _evaluated: ArrowSeries + _version: Version @property def name(self) -> str: @@ -68,6 +70,7 @@ def name(self) -> str: def from_series(cls, series: ArrowSeries, /) -> Self: obj = cls.__new__(cls) obj._evaluated = series + obj._version = series.version return obj @classmethod @@ -76,25 +79,6 @@ def from_native( ) -> Self: return cls.from_series(ArrowSeries.from_native(native, name, version=version)) - @classmethod - def from_ir( - cls, value: expr.Literal[DummySeries[ChunkedArrayAny]], name: str = "", / - ) -> Self: - nw_ser = value.unwrap() - return cls.from_native(nw_ser.to_native(), name or value.name, nw_ser.version) - - def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> Self: - return self.from_native(frame.native.column(node.name), name) - - def lit( - self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], - name: str, - ) -> ArrowScalar | Self: - if is_literal_scalar(node): - return ArrowScalar.from_ir(node, name) - return self.from_ir(node, name) - @overload def _with_native(self, result: ChunkedArrayAny, name: str = ..., /) -> Self: ... 
@overload @@ -126,6 +110,23 @@ def broadcast(self, length: int, /) -> ArrowSeries: def __len__(self) -> int: return len(self._evaluated) + def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> Self: + return self.from_native(frame.native.column(node.name), name) + + def lit( + self, + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], + frame: ArrowDataFrame, + name: str, + ) -> ArrowScalar | Self: + if is_literal_scalar(node): + return ArrowScalar.from_ir(node, frame, name) + nw_ser = node.unwrap() + return self.from_native(nw_ser.to_native(), name or node.name, nw_ser.version) + + def len(self, node: Len, frame: ArrowDataFrame, name: str) -> ArrowScalar: + return ArrowScalar.from_ir(node, frame, name) + def cast( # type: ignore[override] self, node: expr.Cast, frame: ArrowDataFrame, name: str ) -> ArrowScalar | Self: @@ -224,6 +225,7 @@ class ArrowScalar( ): _name: str _evaluated: NativeScalar + _version: Version @property def name(self) -> str: @@ -272,10 +274,6 @@ def from_series(cls, series: ArrowSeries) -> Self: msg = f"Too long {len(series)!r}" raise InvalidOperationError(msg) - @classmethod - def from_ir(cls, value: expr.Literal[NonNestedLiteral], name: str, /) -> Self: - return cls.from_python(value.unwrap(), name, dtype=value.dtype) - @property def native(self) -> NativeScalar: return self._evaluated @@ -294,14 +292,26 @@ def broadcast(self, length: int) -> ArrowSeries: chunked = chunked_array(pa_repeat(scalar, length)) return ArrowSeries.from_native(chunked, self.name, version=self.version) + # NOTE: These need to move into a single def in `Namespace` + # - col + # - len + # - lit + def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> Self: + return ArrowExpr.from_ir(node, frame, name) # type: ignore[return-value] + + def lit( + self, node: expr.Literal[NonNestedLiteral], frame: ArrowDataFrame, name: str + ) -> Self: + return self.from_python(node.unwrap(), name, dtype=node.dtype) + + def len(self, node: Len, frame: ArrowDataFrame, name: str) -> Self: + return self.from_python(len(frame), name or node.name, version=frame.version) + def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> ArrowScalar: data_type = narwhals_to_native_dtype(node.dtype, frame.version) native = self._dispatch(node.expr, frame, name).native return self._with_native(pc.cast(native, data_type), name) - def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> Any: - raise NotImplementedError - def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: return self._with_native(pa.scalar(0), name) @@ -320,3 +330,5 @@ def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: native = self._dispatch(node.expr, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) + + filter = not_implemented() diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index f90df417af..9cd743d7b8 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias - from narwhals._plan.dummy import DummySeries from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral T = TypeVar("T") @@ -133,10 +132,11 @@ def _with_native(self, native: Any, name: str = "", /) -> Self: def col(self, node: expr.Column, frame: FrameT_contra, name: str) -> Self: ... 
def lit( - self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[Any]], - name: str, + self, node: expr.Literal[Any], frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co] | Self: ... + def len( + self, node: expr.Len, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... @@ -192,7 +192,8 @@ def min( class Dispatch(Protocol[FrameT_contra, R_co]): _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = { expr.Column: lambda self, node, frame, name: self.col(node, frame, name), - expr.Literal: lambda self, node, _, name: self.lit(node, name), + expr.Literal: lambda self, node, frame, name: self.lit(node, frame, name), + expr.Len: lambda self, node, frame, name: self.len(node, frame, name), expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name), expr.Sort: lambda self, node, frame, name: self.sort(node, frame, name), expr.SortBy: lambda self, node, frame, name: self.sort_by(node, frame, name), @@ -219,13 +220,14 @@ def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: return method(self, node, frame, name) # type: ignore[no-any-return] @classmethod - def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: - node = named_ir.expr - name = named_ir.name - method = cls._DISPATCH[node.__class__] + def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: obj = cls.__new__(cls) obj._version = frame._version - return method(obj, node, frame, name) # type: ignore[no-any-return] + return obj._dispatch(node, frame, name) + + @classmethod + def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: + return cls.from_ir(named_ir.expr, frame, named_ir.name) class CompliantScalar( @@ -262,6 +264,10 @@ def _with_evaluated(self, evaluated: Any, name: str) -> Self: obj._version = self.version return obj + def lit( + self, node: expr.Literal[NonNestedLiteral], frame: FrameT_contra, name: str + ) -> Self: ... 
+ def max(self, node: agg.Max, frame: FrameT_contra, name: str) -> Self: """Returns self.""" return self._with_evaluated(self._evaluated, name) From 95833838833f8c799582a6804cab7629507c6876 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 17:14:02 +0100 Subject: [PATCH 291/368] feat(pyarrow): More impls --- narwhals/_plan/arrow/expr.py | 45 ++++++++++++++++++++++++++++-------- narwhals/_plan/protocols.py | 6 +++-- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index d73c63248e..7503d2184d 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -10,7 +10,7 @@ narwhals_to_native_dtype, ) from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.common import into_dtype +from narwhals._plan.common import ExprIR, into_dtype from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import Dispatch, EagerExpr, EagerScalar from narwhals._utils import Implementation, Version, _StoresNative, not_implemented @@ -94,6 +94,18 @@ def _with_native( return ArrowScalar.from_native(result, name, version=self.version) return super()._with_native(result, name) + def _dispatch_expr( + self, node: ExprIR, frame: ArrowDataFrame, name: str + ) -> ArrowSeries: + """Use instead of `_dispatch` *iff* an operation isn't natively supported on `ChunkedArray`. + + There is no need to broadcast, as they may have a cheaper impl elsewhere (`CompliantScalar` or `ArrowScalar`). + + Mainly for the benefit of a type checker, but the equivalent `ArrowScalar._dispatch_expr` will raise if + the assumption fails. + """ + return self._dispatch(node, frame, name).to_series() + @property def native(self) -> ChunkedArrayAny: return self._evaluated.native @@ -135,22 +147,33 @@ def cast( # type: ignore[override] return self._with_native(pc.cast(native, data_type), name) def sort(self, node: expr.Sort, frame: ArrowDataFrame, name: str) -> ArrowExpr: - raise NotImplementedError + native = self._dispatch_expr(node.expr, frame, name).native + sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) + return self._with_native(native.take(sorted_indices), name) def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowExpr: raise NotImplementedError def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowExpr: - raise NotImplementedError + return self._with_native( + self._dispatch_expr(node.expr, frame, name).native.filter( + self._dispatch_expr(node.by, frame, name).native + ) + ) def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: - prev = self._dispatch(node.expr, frame, name) - native = prev.to_series().native + prev = self._dispatch_expr(node.expr, frame, name) + native = prev.native result = lit(native[0]) if len(prev) else lit(None, native.type) return self._with_native(result, name) def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + prev = self._dispatch_expr(node.expr, frame, name) + native = prev.native + result = ( + lit(native[height - 1]) if (height := len(prev)) else lit(None, native.type) + ) + return self._with_native(result, name) def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: raise NotImplementedError @@ -177,9 +200,7 @@ def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: raise NotImplementedError def max(self, 
node: Max, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result: NativeScalar = pc.max( - self._dispatch(node.expr, frame, name).to_series().native - ) + result: NativeScalar = pc.max(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def mean(self, node: Mean, frame: ArrowDataFrame, name: str) -> ArrowScalar: @@ -274,6 +295,12 @@ def from_series(cls, series: ArrowSeries) -> Self: msg = f"Too long {len(series)!r}" raise InvalidOperationError(msg) + def _dispatch_expr( + self, node: ExprIR, frame: ArrowDataFrame, name: str + ) -> ArrowSeries: + msg = f"Expected unreachable, but hit at: {node!r}" + raise InvalidOperationError(msg) + @property def native(self) -> NativeScalar: return self._evaluated diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 9cd743d7b8..0da4913ee9 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -216,8 +216,10 @@ class Dispatch(Protocol[FrameT_contra, R_co]): _version: Version def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: - method = self._DISPATCH[node.__class__] - return method(self, node, frame, name) # type: ignore[no-any-return] + if method := self._DISPATCH.get(node.__class__): + return method(self, node, frame, name) # type: ignore[no-any-return] + msg = f"Support for {node.__class__.__name__!r} is not yet implemented, got:\n{node!r}" + raise NotImplementedError(msg) @classmethod def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: From 5457b30b0ad5352bbe142dd3384d69b33560c8f5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 17:16:26 +0100 Subject: [PATCH 292/368] test(pyarrow): Coverage for most of the current impl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit oh look - we have have selectors for free 😎 --- tests/plan/to_compliant_test.py | 34 +++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index dd3c4f0cea..457ee4232a 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -5,12 +5,15 @@ import pytest import narwhals as nw -import narwhals._plan.demo as nwd +from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.common import is_expr +from narwhals.exceptions import ComputeError from narwhals.utils import Version from tests.namespace_test import backends if TYPE_CHECKING: + from collections.abc import Sequence + from narwhals._namespace import BackendName from narwhals._plan.dummy import DummyExpr @@ -45,6 +48,14 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: assert isinstance(compliant_expr, namespace._expr) +XFAIL_REQUIRES_BINARY_EXPR = pytest.mark.xfail( + reason="Requires `BinaryExpr` implementation.", raises=NotImplementedError +) +XFAIL_REWRITE_SPECIAL_ALIASES = pytest.mark.xfail( + reason="Bug in `meta` namespace impl", raises=ComputeError +) + + @pytest.mark.parametrize( ("expr", "expected"), [ @@ -55,11 +66,30 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: (nwd.lit(None, nw.String()), {"literal": [None]}), (nwd.col("a", "b").first(), {"a": ["A"], "b": [1]}), (nwd.col("d").max(), {"d": [8]}), + ([nwd.len(), nwd.nth(3).last()], {"len": [3], "d": [8]}), + ( + [nwd.len().alias("e"), nwd.nth(3).last(), nwd.nth(2)], + {"e": [3, 3, 3], "d": [8, 8, 8], "c": [9, 2, 4]}, + ), + 
(nwd.col("b").sort(descending=True).alias("b_desc"), {"b_desc": [3, 2, 1]}), + pytest.param( + nwd.col("c").filter(a="B"), {"c": [2]}, marks=XFAIL_REQUIRES_BINARY_EXPR + ), + (nwd.col("b").cast(nw.Float64()), {"b": [1.0, 2.0, 3.0]}), + (nwd.lit(1).cast(nw.Float64()).alias("literal_cast"), {"literal_cast": [1.0]}), + pytest.param( + nwd.lit(1).cast(nw.Float64()).name.suffix("_cast"), + {"literal_cast": [1.0]}, + marks=XFAIL_REWRITE_SPECIAL_ALIASES, + ), + ([ndcs.string().first(), nwd.col("b")], {"a": ["A", "A", "A"], "b": [1, 2, 3]}), ], ids=_ids_ir, ) def test_select( - expr: DummyExpr, expected: dict[str, Any], data_small: dict[str, Any] + expr: DummyExpr | Sequence[DummyExpr], + expected: dict[str, Any], + data_small: dict[str, Any], ) -> None: pytest.importorskip("pyarrow") import pyarrow as pa From c0c6b7dfa44241ee22f43d3c6fc767353e2145d3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 17:48:21 +0100 Subject: [PATCH 293/368] refactor(pyarrow): Migrate most of `evaluate` into `ArrowExpr` Will take another swing at `sort_by` after adding `ArrowDataFrame.sort` --- narwhals/_plan/arrow/evaluate.py | 140 +------------------------------ narwhals/_plan/arrow/expr.py | 45 +++++++--- 2 files changed, 35 insertions(+), 150 deletions(-) diff --git a/narwhals/_plan/arrow/evaluate.py b/narwhals/_plan/arrow/evaluate.py index d57de4c0b6..63d6828b2c 100644 --- a/narwhals/_plan/arrow/evaluate.py +++ b/narwhals/_plan/arrow/evaluate.py @@ -3,14 +3,11 @@ from __future__ import annotations import typing as t - -# ruff: noqa: ARG001 from functools import singledispatch from itertools import repeat -from narwhals._plan import aggregation as agg, expr +from narwhals._plan import expr from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.literal import is_literal_scalar if t.TYPE_CHECKING: from typing_extensions import TypeAlias, TypeIs @@ -22,9 +19,7 @@ ) from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.dummy import DummySeries from narwhals._plan.protocols import EagerBroadcast - from narwhals.typing import NonNestedLiteral, PythonLiteral UnaryFn: TypeAlias = "t.Callable[[ChunkedArrayAny], ScalarAny]" @@ -56,50 +51,6 @@ def col(node: expr.Column, frame: ArrowDataFrame) -> ChunkedArrayAny: return frame.native.column(node.name) -# NOTE: Using a very naïve approach to broadcasting **for now** -# - We already have something that works in main -# - Another approach would be to keep everything wrapped (or aggregated into) `expr.Literal` -def _lit_native( - value: PythonLiteral | ScalarAny, frame: ArrowDataFrame -) -> ChunkedArrayAny: - """Will need to support returning a native scalar as well.""" - import pyarrow as pa # ignore-banned-import - - from narwhals._arrow.utils import chunked_array - - lit: t.Any = pa.scalar - scalar: t.Any = value if isinstance(value, pa.Scalar) else lit(value) - array = pa.repeat(scalar, len(frame)) - return chunked_array(array) - - -@_evaluate_inner.register(expr.Literal) -def lit_( - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], - frame: ArrowDataFrame, -) -> ChunkedArrayAny: - if is_literal_scalar(node): - return _lit_native(node.unwrap(), frame) - return node.unwrap().to_native() - - -@_evaluate_inner.register(expr.Cast) -def cast_(node: expr.Cast, frame: ArrowDataFrame) -> ChunkedArrayAny: - from narwhals._arrow.utils import narwhals_to_native_dtype - - data_type = 
narwhals_to_native_dtype(node.dtype, frame.version) - return _evaluate_inner(node.expr, frame).cast(data_type) - - -@_evaluate_inner.register(expr.Sort) -def sort(node: expr.Sort, frame: ArrowDataFrame) -> ChunkedArrayAny: - import pyarrow.compute as pc - - native = _evaluate_inner(node.expr, frame) - sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) - return native.take(sorted_indices) - - @_evaluate_inner.register(expr.SortBy) def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> ChunkedArrayAny: opts = node.options @@ -124,100 +75,11 @@ def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> ChunkedArrayAny: return df.native.sort_by(sorting, null_placement=placement).column(0) -@_evaluate_inner.register(expr.Filter) -def filter_(node: expr.Filter, frame: ArrowDataFrame) -> ChunkedArrayAny: - return _evaluate_inner(node.expr, frame).filter(_evaluate_inner(node.by, frame)) - - -@_evaluate_inner.register(expr.Len) -def len_(node: expr.Len, frame: ArrowDataFrame) -> ChunkedArrayAny: - return _lit_native(len(frame), frame) - - @_evaluate_inner.register(expr.Ternary) def ternary(node: expr.Ternary, frame: ArrowDataFrame) -> ChunkedArrayAny: raise NotImplementedError(type(node)) -@_evaluate_inner.register(agg.Last) -@_evaluate_inner.register(agg.First) -def agg_first_last(node: agg.First | agg.Last, frame: ArrowDataFrame) -> ChunkedArrayAny: - native = _evaluate_inner(node.expr, frame) - if height := len(native): - result = native[height - 1 if isinstance(node, agg.Last) else 0] - else: - result = None - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.ArgMax) -@_evaluate_inner.register(agg.ArgMin) -def agg_arg_min_max( - node: agg.ArgMin | agg.ArgMax, frame: ArrowDataFrame -) -> ChunkedArrayAny: - import pyarrow.compute as pc - - native = _evaluate_inner(node.expr, frame) - fn = pc.min if isinstance(node, agg.ArgMin) else pc.max - result = pc.index(native, fn(native)) - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.Sum) -def agg_sum(node: agg.Sum, frame: ArrowDataFrame) -> ChunkedArrayAny: - import pyarrow.compute as pc - - result = pc.sum(_evaluate_inner(node.expr, frame), min_count=0) - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.NUnique) -def agg_n_unique(node: agg.NUnique, frame: ArrowDataFrame) -> ChunkedArrayAny: - import pyarrow.compute as pc - - result = pc.count(_evaluate_inner(node.expr, frame).unique(), mode="all") - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.Var) -@_evaluate_inner.register(agg.Std) -def agg_std_var(node: agg.Std | agg.Var, frame: ArrowDataFrame) -> ChunkedArrayAny: - import pyarrow.compute as pc - - fn = pc.stddev if isinstance(node, agg.Std) else pc.variance - result = fn(_evaluate_inner(node.expr, frame), ddof=node.ddof) - return _lit_native(result, frame) - - -@_evaluate_inner.register(agg.Quantile) -def agg_quantile(node: agg.Quantile, frame: ArrowDataFrame) -> ChunkedArrayAny: - import pyarrow.compute as pc - - result = pc.quantile( - _evaluate_inner(node.expr, frame), - q=node.quantile, - interpolation=node.interpolation, - )[0] - return _lit_native(result, frame) - - -@_evaluate_inner.register(expr.Agg) -def agg_expr(node: expr.Agg, frame: ArrowDataFrame) -> ChunkedArrayAny: - import pyarrow.compute as pc - - mapping: dict[type[expr.Agg], UnaryFn] = { - agg.Count: pc.count, - agg.Max: pc.max, - agg.Mean: pc.mean, - agg.Median: pc.approximate_median, - agg.Min: pc.min, - } - if fn := mapping.get(type(node)): - result = 
fn(_evaluate_inner(node.expr, frame)) - return _lit_native(result, frame) - raise NotImplementedError(type(node)) - - @_evaluate_inner.register(expr.BinaryExpr) def binary_expr(node: expr.BinaryExpr, frame: ArrowDataFrame) -> ChunkedArrayAny: raise NotImplementedError(type(node)) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 7503d2184d..f8ef7b3455 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -176,41 +176,64 @@ def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: return self._with_native(result, name) def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + native = self._dispatch_expr(node.expr, frame, name).native + result = pc.index(native, pc.min(native)) + return self._with_native(result, name) def arg_max(self, node: ArgMax, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + native = self._dispatch_expr(node.expr, frame, name).native + result: NativeScalar = pc.index(native, pc.max(native)) + return self._with_native(result, name) def sum(self, node: Sum, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result: NativeScalar = pc.sum( + self._dispatch_expr(node.expr, frame, name).native, min_count=0 + ) + return self._with_native(result, name) def n_unique(self, node: NUnique, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result = pc.count(self._dispatch_expr(node.expr, frame, name).native, mode="all") + return self._with_native(result, name) def std(self, node: Std, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result = pc.stddev( + self._dispatch_expr(node.expr, frame, name).native, ddof=node.ddof + ) + return self._with_native(result, name) def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result = pc.variance( + self._dispatch_expr(node.expr, frame, name).native, ddof=node.ddof + ) + return self._with_native(result, name) def quantile(self, node: Quantile, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result = pc.quantile( + self._dispatch_expr(node.expr, frame, name).native, + q=node.quantile, + interpolation=node.interpolation, + )[0] + return self._with_native(result, name) def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result = pc.count(self._dispatch_expr(node.expr, frame, name).native) + return self._with_native(result, name) def max(self, node: Max, frame: ArrowDataFrame, name: str) -> ArrowScalar: result: NativeScalar = pc.max(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def mean(self, node: Mean, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result = pc.mean(self._dispatch_expr(node.expr, frame, name).native) + return self._with_native(result, name) def median(self, node: Median, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result = pc.approximate_median(self._dispatch_expr(node.expr, frame, name).native) + return self._with_native(result, name) def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: - raise NotImplementedError + result: NativeScalar = pc.min(self._dispatch_expr(node.expr, frame, name).native) + return self._with_native(result, name) def lit(value: Any, dtype: pa.DataType | None = None) -> NativeScalar: 
From 5039a0992c360f7d626930b50ce1ce6dddb8bb76 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 19:16:25 +0100 Subject: [PATCH 294/368] feat(pyarrow): Add `ArrowDataFrame.sort` Should be much easier to handle `ArrowExpr.sort_by` now --- narwhals/_plan/arrow/dataframe.py | 14 ++++++++ narwhals/_plan/dummy.py | 53 +++++++++++++++++++++++++------ narwhals/_plan/options.py | 32 ++++++++++++++++++- 3 files changed, 89 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 337ac3e299..eab3ac7f37 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -4,9 +4,11 @@ from itertools import chain import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow.series import ArrowSeries +from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyCompliantFrame, DummyFrame from narwhals._utils import Version @@ -18,6 +20,9 @@ from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR + from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.schema import FrozenSchema + from narwhals._plan.typing import Seq from narwhals.dtypes import DType from narwhals.schema import Schema @@ -103,3 +108,12 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSe ns = self.__narwhals_namespace__() from_named_ir = ns._expr.from_named_ir yield from ns._expr.align(from_named_ir(e, self) for e in nodes) + + # NOTE: Not handling actual expressions yet + # `DummyFrame` is typed for just `str` names + def sort( + self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema + ) -> Self: + df_by = self.select(by, projected) + indices = pc.sort_indices(df_by.native, options=options.to_arrow(df_by.columns)) + return self._with_native(self.native.take(indices)) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 8cf118983f..0c9b303b13 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -36,7 +36,7 @@ from narwhals.schema import Schema if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping + from collections.abc import Iterable, Iterator, Mapping, Sequence import pyarrow as pa from typing_extensions import Never, Self, TypeAlias @@ -68,6 +68,21 @@ CompliantFrame: TypeAlias = "DummyCompliantFrame[t.Any, NativeFrameT, NativeSeriesT]" +# NOTE: Trying to keep consistent logic between `DataFrame.sort` and `Expr.sort_by` +def _parse_sort_by( + by: IntoExpr | Iterable[IntoExpr] = (), + *more_by: IntoExpr, + descending: bool | t.Iterable[bool] = False, + nulls_last: bool | t.Iterable[bool] = False, +) -> tuple[Seq[ExprIR], SortMultipleOptions]: + sort_by = parse.parse_into_seq_of_expr_ir(by, *more_by) + if length_changing := next((e for e in sort_by if e.is_scalar), None): + msg = f"All expressions sort keys must preserve length, but got:\n{length_changing!r}" + raise InvalidOperationError(msg) + options = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last) + return sort_by, options + + # NOTE: Overly simplified placeholders for mocking typing # Entirely ignoring namespace + function binding class DummyExpr: @@ -186,14 +201,10 @@ def sort_by( descending: bool | t.Iterable[bool] = False, 
nulls_last: bool | t.Iterable[bool] = False, ) -> Self: - sort_by = parse.parse_into_seq_of_expr_ir(by, *more_by) - if length_changing := next((e for e in sort_by if e.is_scalar), None): - msg = f"All expressions passed to `sort_by` must preserve length, but got:\n{length_changing!r}" - raise InvalidOperationError(msg) - desc = (descending,) if isinstance(descending, bool) else tuple(descending) - nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) - options = SortMultipleOptions(descending=desc, nulls_last=nulls) - return self._from_ir(expr.SortBy(expr=self._ir, by=sort_by, options=options)) + keys, opts = _parse_sort_by( + by, *more_by, descending=descending, nulls_last=nulls_last + ) + return self._from_ir(expr.SortBy(expr=self._ir, by=keys, options=opts)) def filter( self, @@ -871,6 +882,22 @@ def select(self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any) -> ) return self._from_compliant(self._compliant.select(named_irs, schema_projected)) + def sort( + self, + by: str | Iterable[str], + *more_by: str, + descending: bool | Sequence[bool] = False, + nulls_last: bool | Sequence[bool] = False, + ) -> Self: + sort, opts = _parse_sort_by( + by, *more_by, descending=descending, nulls_last=nulls_last + ) + irs, schema_frozen, output_names = expr_expansion.prepare_projection( + sort, self.schema + ) + named_irs = expr_expansion.into_named_irs(irs, output_names) + return self._from_compliant(self._compliant.sort(named_irs, opts, schema_frozen)) + class DummyCompliantFrame(Generic[CompliantSeriesT, NativeFrameT, NativeSeriesT]): _native: NativeFrameT @@ -904,6 +931,9 @@ def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: obj._version = version return obj + def _with_native(self, native: NativeFrameT) -> Self: + return self.from_native(native, self.version) + @classmethod def from_series( cls, @@ -952,6 +982,11 @@ def _evaluate_irs( def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: return self.from_series(self._evaluate_irs(irs)) + def sort( + self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema + ) -> Self: + raise NotImplementedError + class DummySeries(Generic[NativeSeriesT]): _compliant: DummyCompliantSeries[NativeSeriesT] diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 1164358458..7e5229f4a1 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -1,11 +1,14 @@ from __future__ import annotations import enum -from typing import TYPE_CHECKING +from itertools import repeat +from typing import TYPE_CHECKING, Literal from narwhals._plan.common import Immutable if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + import pyarrow.compute as pc from narwhals._plan.typing import Seq @@ -170,6 +173,33 @@ def __repr__(self) -> str: ) return f"{type(self).__name__}({args})" + @staticmethod + def parse( + *, descending: bool | Iterable[bool], nulls_last: bool | Iterable[bool] + ) -> SortMultipleOptions: + desc = (descending,) if isinstance(descending, bool) else tuple(descending) + nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) + return SortMultipleOptions(descending=desc, nulls_last=nulls) + + def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: + import pyarrow.compute as pc + + if len(self.nulls_last) != 1: + msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" + raise NotImplementedError(msg) + placement: Literal["at_start", "at_end"] = ( + "at_end" if 
self.nulls_last[0] else "at_start" + ) + if len(self.descending) == 1: + descending: Iterable[bool] = repeat(self.descending[0], len(by)) + else: + descending = self.descending + sorting: list[tuple[str, Literal["ascending", "descending"]]] = [ + (key, "descending" if desc else "ascending") + for key, desc in zip(by, descending) + ] + return pc.SortOptions(sort_keys=sorting, null_placement=placement) + class RankOptions(Immutable): """https://github.com/narwhals-dev/narwhals/pull/2555.""" From c26c862262d37b45f56ebb88fe2fd2abb6f2bd00 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 21:45:40 +0100 Subject: [PATCH 295/368] feat(pyarrow): Impl `Expr.sort_by` `evaluate` is almost ready to remove --- narwhals/_plan/arrow/evaluate.py | 52 ++++++++------------------------ narwhals/_plan/arrow/expr.py | 11 ++++++- tests/plan/to_compliant_test.py | 7 +++++ 3 files changed, 29 insertions(+), 41 deletions(-) diff --git a/narwhals/_plan/arrow/evaluate.py b/narwhals/_plan/arrow/evaluate.py index 63d6828b2c..056404ad7d 100644 --- a/narwhals/_plan/arrow/evaluate.py +++ b/narwhals/_plan/arrow/evaluate.py @@ -1,30 +1,34 @@ -"""TODO: Move all the impls `ArrowExpr`/`ArrowScalar`, then delete.""" +"""TODO: Define remaining nodes in `Dispatch` protocol. + +- `Ternary` +- `BinaryExpr` +- `FunctionExpr` +- `RollingExpr` +- `WindowExpr` +- `OrderedWindowExpr` +- `AnonymousExpr` +""" from __future__ import annotations import typing as t from functools import singledispatch -from itertools import repeat from narwhals._plan import expr -from narwhals._plan.arrow.series import ArrowSeries if t.TYPE_CHECKING: - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import TypeIs from narwhals._arrow.typing import ( # type: ignore[attr-defined] ChunkedArrayAny, - Order, ScalarAny, ) from narwhals._plan.arrow.dataframe import ArrowDataFrame + from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.protocols import EagerBroadcast -UnaryFn: TypeAlias = "t.Callable[[ChunkedArrayAny], ScalarAny]" - - def is_scalar(obj: t.Any) -> TypeIs[ScalarAny]: import pyarrow as pa # ignore-banned-import @@ -38,43 +42,11 @@ def evaluate(node: NamedIR[ExprIR], frame: ArrowDataFrame) -> EagerBroadcast[Arr return frame.__narwhals_namespace__()._expr.from_native(result, node.name) -# NOTE: Should mean we produce 1x CompliantSeries for the entire expression -# Multi-output have already been separated -# No intermediate CompliantSeries need to be created, just assign a name to the final one @singledispatch def _evaluate_inner(node: ExprIR, frame: ArrowDataFrame) -> ChunkedArrayAny: raise NotImplementedError(type(node)) -@_evaluate_inner.register(expr.Column) -def col(node: expr.Column, frame: ArrowDataFrame) -> ChunkedArrayAny: - return frame.native.column(node.name) - - -@_evaluate_inner.register(expr.SortBy) -def sort_by(node: expr.SortBy, frame: ArrowDataFrame) -> ChunkedArrayAny: - opts = node.options - if len(opts.nulls_last) != 1: - msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {opts.nulls_last!r}" - raise NotImplementedError(msg) - placement = "at_end" if opts.nulls_last[0] else "at_start" - from_native = ArrowSeries.from_native - by = ( - from_native(_evaluate_inner(e, frame), str(idx)) for idx, e in enumerate(node.by) - ) - df = frame.from_series(from_native(_evaluate_inner(node.expr, frame), ""), *by) - names = df.columns[1:] - if len(opts.descending) == 1: - 
descending: t.Iterable[bool] = repeat(opts.descending[0], len(names)) - else: - descending = opts.descending - sorting: list[tuple[str, Order]] = [ - (key, "descending" if desc else "ascending") - for key, desc in zip(names, descending) - ] - return df.native.sort_by(sorting, null_placement=placement).column(0) - - @_evaluate_inner.register(expr.Ternary) def ternary(node: expr.Ternary, frame: ArrowDataFrame) -> ChunkedArrayAny: raise NotImplementedError(type(node)) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index f8ef7b3455..b400abe1d6 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -152,7 +152,16 @@ def sort(self, node: expr.Sort, frame: ArrowDataFrame, name: str) -> ArrowExpr: return self._with_native(native.take(sorted_indices), name) def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowExpr: - raise NotImplementedError + series = self._dispatch_expr(node.expr, frame, name) + by = ( + self._dispatch_expr(e, frame, f"_{idx}") + for idx, e in enumerate(node.by) + ) + df = frame.from_series(series, *by) + names = df.columns[1:] + indices = pc.sort_indices(df.native, options=node.options.to_arrow(names)) + result: ChunkedArrayAny = df.native.column(0).take(indices) + return self._with_native(result, name) def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowExpr: return self._with_native( diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 457ee4232a..21f978938b 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -83,6 +83,13 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: marks=XFAIL_REWRITE_SPECIAL_ALIASES, ), ([ndcs.string().first(), nwd.col("b")], {"a": ["A", "A", "A"], "b": [1, 2, 3]}), + ( + nwd.col("c", "d") + .sort_by("a", "b", descending=[True, False]) + .cast(nw.Float32()) + .name.to_uppercase(), + {"C": [2.0, 9.0, 4.0], "D": [7.0, 8.0, 8.0]}, + ), ], ids=_ids_ir, ) From c72074918c50dc02b295fe732335312c6ca69b5c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 7 Jul 2025 21:51:17 +0100 Subject: [PATCH 296/368] fix: unused-ignore https://github.com/narwhals-dev/narwhals/actions/runs/16127500855/job/45507810339 --- narwhals/_plan/arrow/evaluate.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/evaluate.py b/narwhals/_plan/arrow/evaluate.py index 056404ad7d..47abed768b 100644 --- a/narwhals/_plan/arrow/evaluate.py +++ b/narwhals/_plan/arrow/evaluate.py @@ -19,10 +19,7 @@ if t.TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._arrow.typing import ( # type: ignore[attr-defined] - ChunkedArrayAny, - ScalarAny, - ) + from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, NamedIR From 8dcc57b90f393fd76f93399bc559ccc36f701c12 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 8 Jul 2025 20:09:49 +0100 Subject: [PATCH 297/368] refactor: Move `lit`, `col`, `len` to namespace --- narwhals/_plan/arrow/expr.py | 54 +++++++----------------- narwhals/_plan/arrow/namespace.py | 53 +++++++++++++++++++++++- narwhals/_plan/protocols.py | 69 +++++++++++++++++++++---------- 3 files changed, 114 insertions(+), 62 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py 
b/narwhals/_plan/arrow/expr.py index b400abe1d6..71bc054803 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -11,8 +11,7 @@ ) from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, into_dtype -from narwhals._plan.literal import is_literal_scalar -from narwhals._plan.protocols import Dispatch, EagerExpr, EagerScalar +from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch from narwhals._utils import Implementation, Version, _StoresNative, not_implemented from narwhals.exceptions import InvalidOperationError, ShapeError @@ -45,9 +44,8 @@ Var, ) from narwhals._plan.arrow.dataframe import ArrowDataFrame - from narwhals._plan.dummy import DummySeries - from narwhals._plan.expr import Len - from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral + from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals.typing import IntoDType, PythonLiteral NativeScalar: TypeAlias = "pa.Scalar[Any]" @@ -55,13 +53,18 @@ class ArrowExpr( - Dispatch["ArrowDataFrame", "ArrowExpr | ArrowScalar"], + ExprDispatch["ArrowDataFrame", "ArrowExpr | ArrowScalar", "ArrowNamespace"], _StoresNative["ChunkedArrayAny"], EagerExpr["ArrowDataFrame", ArrowSeries], ): _evaluated: ArrowSeries _version: Version + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._plan.arrow.namespace import ArrowNamespace + + return ArrowNamespace(self._version) + @property def name(self) -> str: return self._evaluated.name @@ -122,23 +125,6 @@ def broadcast(self, length: int, /) -> ArrowSeries: def __len__(self) -> int: return len(self._evaluated) - def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> Self: - return self.from_native(frame.native.column(node.name), name) - - def lit( - self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], - frame: ArrowDataFrame, - name: str, - ) -> ArrowScalar | Self: - if is_literal_scalar(node): - return ArrowScalar.from_ir(node, frame, name) - nw_ser = node.unwrap() - return self.from_native(nw_ser.to_native(), name or node.name, nw_ser.version) - - def len(self, node: Len, frame: ArrowDataFrame, name: str) -> ArrowScalar: - return ArrowScalar.from_ir(node, frame, name) - def cast( # type: ignore[override] self, node: expr.Cast, frame: ArrowDataFrame, name: str ) -> ArrowScalar | Self: @@ -272,7 +258,7 @@ def chunked_array( class ArrowScalar( - Dispatch["ArrowDataFrame", "ArrowScalar"], + ExprDispatch["ArrowDataFrame", "ArrowScalar", "ArrowNamespace"], _StoresNative[NativeScalar], EagerScalar["ArrowDataFrame", ArrowSeries], ): @@ -280,6 +266,11 @@ class ArrowScalar( _evaluated: NativeScalar _version: Version + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._plan.arrow.namespace import ArrowNamespace + + return ArrowNamespace(self._version) + @property def name(self) -> str: return self._name @@ -351,21 +342,6 @@ def broadcast(self, length: int) -> ArrowSeries: chunked = chunked_array(pa_repeat(scalar, length)) return ArrowSeries.from_native(chunked, self.name, version=self.version) - # NOTE: These need to move into a single def in `Namespace` - # - col - # - len - # - lit - def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> Self: - return ArrowExpr.from_ir(node, frame, name) # type: ignore[return-value] - - def lit( - self, node: expr.Literal[NonNestedLiteral], frame: ArrowDataFrame, name: str - ) -> Self: - return self.from_python(node.unwrap(), name, dtype=node.dtype) - - def 
len(self, node: Len, frame: ArrowDataFrame, name: str) -> Self: - return self.from_python(len(frame), name or node.name, version=frame.version) - def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> ArrowScalar: data_type = narwhals_to_native_dtype(node.dtype, frame.version) native = self._dispatch(node.expr, frame, name).native diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 90c05d124f..a9ff8ef596 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -1,15 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload +from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version if TYPE_CHECKING: + from narwhals._arrow.typing import ChunkedArrayAny + from narwhals._plan import expr from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, NamedIR + from narwhals._plan.dummy import DummySeries + from narwhals.typing import NonNestedLiteral class ArrowNamespace( @@ -44,3 +49,49 @@ def _dataframe(self) -> type[ArrowDataFrame]: def dispatch_expr(self, named_ir: NamedIR[ExprIR], frame: ArrowDataFrame) -> Any: return self._expr.from_named_ir(named_ir, frame) + + def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> ArrowExpr: + return self._expr.from_native( + frame.native.column(node.name), name, version=frame.version + ) + + @overload + def lit( + self, node: expr.Literal[NonNestedLiteral], frame: ArrowDataFrame, name: str + ) -> ArrowScalar: ... + + @overload + def lit( + self, + node: expr.Literal[DummySeries[ChunkedArrayAny]], + frame: ArrowDataFrame, + name: str, + ) -> ArrowExpr: ... + + @overload + def lit( + self, + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], + frame: ArrowDataFrame, + name: str, + ) -> ArrowExpr | ArrowScalar: ... 
+ + def lit( + self, + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], + frame: ArrowDataFrame, + name: str, + ) -> ArrowExpr | ArrowScalar: + if is_literal_scalar(node): + return self._scalar.from_python( + node.unwrap(), name, dtype=node.dtype, version=frame.version + ) + nw_ser = node.unwrap() + return self._expr.from_native( + nw_ser.to_native(), name or node.name, nw_ser.version + ) + + def len(self, node: expr.Len, frame: ArrowDataFrame, name: str) -> ArrowScalar: + return self._scalar.from_python( + len(frame), name or node.name, version=frame.version + ) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 0da4913ee9..4c95c7a40c 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized -from typing import TYPE_CHECKING, Any, ClassVar, Protocol +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, overload from narwhals._plan import aggregation as agg, expr from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe @@ -11,6 +11,7 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias + from narwhals._plan.dummy import DummySeries from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral T = TypeVar("T") @@ -25,6 +26,7 @@ NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) ExprAny: TypeAlias = "CompliantExpr[Any, Any]" ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" +NamespaceAny: TypeAlias = "CompliantNamespace[Any, Any, Any, Any]" ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) ScalarT = TypeVar("ScalarT", bound="CompliantScalar[Any, Any]") ScalarT_co = TypeVar("ScalarT_co", bound="CompliantScalar[Any, Any]", covariant=True) @@ -35,6 +37,8 @@ "EagerScalarT_co", bound="EagerScalar[Any, Any]", covariant=True ) +NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) + class SupportsBroadcast(Protocol[SeriesT, LengthT]): """Minimal broadcasting for `Expr` results.""" @@ -130,14 +134,6 @@ def from_native( def _with_native(self, native: Any, name: str = "", /) -> Self: return self.from_native(native, name or self.name, self.version) - def col(self, node: expr.Column, frame: FrameT_contra, name: str) -> Self: ... - def lit( - self, node: expr.Literal[Any], frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co] | Self: ... - def len( - self, node: expr.Len, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... # series only (section 3) @@ -189,11 +185,17 @@ def min( ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... 
-class Dispatch(Protocol[FrameT_contra, R_co]): +class ExprDispatch(Protocol[FrameT_contra, R_co, NamespaceT_co]): _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = { - expr.Column: lambda self, node, frame, name: self.col(node, frame, name), - expr.Literal: lambda self, node, frame, name: self.lit(node, frame, name), - expr.Len: lambda self, node, frame, name: self.len(node, frame, name), + expr.Column: lambda self, node, frame, name: self.__narwhals_namespace__().col( + node, frame, name + ), + expr.Literal: lambda self, node, frame, name: self.__narwhals_namespace__().lit( + node, frame, name + ), + expr.Len: lambda self, node, frame, name: self.__narwhals_namespace__().len( + node, frame, name + ), expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name), expr.Sort: lambda self, node, frame, name: self.sort(node, frame, name), expr.SortBy: lambda self, node, frame, name: self.sort_by(node, frame, name), @@ -231,6 +233,9 @@ def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: return cls.from_ir(named_ir.expr, frame, named_ir.name) + # NOTE: Needs to stay `covariant` and never be used as a parameter + def __narwhals_namespace__(self) -> NamespaceT_co: ... + class CompliantScalar( CompliantExpr[FrameT_contra, SeriesT_co], Protocol[FrameT_contra, SeriesT_co] @@ -266,10 +271,6 @@ def _with_evaluated(self, evaluated: Any, name: str) -> Self: obj._version = self.version return obj - def lit( - self, node: expr.Literal[NonNestedLiteral], frame: FrameT_contra, name: str - ) -> Self: ... - def max(self, node: agg.Max, frame: FrameT_contra, name: str) -> Self: """Returns self.""" return self._with_evaluated(self._evaluated, name) @@ -367,7 +368,7 @@ class LazyScalar( ): ... -class CompliantNamespace(Protocol[FrameT_co, SeriesT_co, ExprT_co, ScalarT_co]): +class CompliantNamespace(Protocol[FrameT, SeriesT_co, ExprT_co, ScalarT_co]): """Need to hold `Expr` and `Scalar` types outside of their defs. Likely, re-wrapping the output types will work like: @@ -385,7 +386,7 @@ class CompliantNamespace(Protocol[FrameT_co, SeriesT_co, ExprT_co, ScalarT_co]): _version: Version @property - def _dataframe(self) -> type[FrameT_co]: ... + def _dataframe(self) -> type[FrameT]: ... @property def _series(self) -> type[SeriesT_co]: ... @property @@ -397,8 +398,32 @@ def _scalar(self) -> type[ScalarT_co]: ... def version(self) -> Version: return self._version + def col(self, node: expr.Column, frame: FrameT, name: str) -> ExprT_co: ... + def lit( + self, node: expr.Literal[Any], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + def len(self, node: expr.Len, frame: FrameT, name: str) -> ScalarT_co: ... + class EagerNamespace( - CompliantNamespace[FrameT_co, SeriesT_co, EagerExprT_co, EagerScalarT_co], - Protocol[FrameT_co, SeriesT_co, EagerExprT_co, EagerScalarT_co], -): ... + CompliantNamespace[FrameT, SeriesT_co, EagerExprT_co, EagerScalarT_co], + Protocol[FrameT, SeriesT_co, EagerExprT_co, EagerScalarT_co], +): + @overload + def lit( + self, node: expr.Literal[NonNestedLiteral], frame: FrameT, name: str + ) -> EagerScalarT_co: ... + @overload + def lit( + self, node: expr.Literal[DummySeries[Any]], frame: FrameT, name: str + ) -> EagerExprT_co: ... + @overload + def lit( + self, + node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[Any]], + frame: FrameT, + name: str, + ) -> EagerExprT_co | EagerScalarT_co: ... 
+ def lit( + self, node: expr.Literal[Any], frame: FrameT, name: str + ) -> EagerExprT_co | EagerScalarT_co: ... From 4d54ff46c8ed95f30f9c2ec3fcd04c63e6f75395 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 8 Jul 2025 20:44:10 +0100 Subject: [PATCH 298/368] refactor: Move `len` impl to `EagerNamespace` Dependent on `frame` being `Sized`, so only safe to implement for eager cases --- narwhals/_plan/arrow/namespace.py | 5 ----- narwhals/_plan/protocols.py | 18 ++++++++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index a9ff8ef596..e7acd2dda0 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -90,8 +90,3 @@ def lit( return self._expr.from_native( nw_ser.to_native(), name or node.name, nw_ser.version ) - - def len(self, node: expr.Len, frame: ArrowDataFrame, name: str) -> ArrowScalar: - return self._scalar.from_python( - len(frame), name or node.name, version=frame.version - ) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 4c95c7a40c..001b7e50bc 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -6,21 +6,22 @@ from narwhals._plan import aggregation as agg, expr from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe from narwhals._typing_compat import TypeVar -from narwhals._utils import Version, _StoresVersion +from narwhals._utils import Version if TYPE_CHECKING: from typing_extensions import Self, TypeAlias - from narwhals._plan.dummy import DummySeries + from narwhals._plan.dummy import DummyCompliantFrame, DummySeries from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral T = TypeVar("T") R_co = TypeVar("R_co", covariant=True) SeriesT = TypeVar("SeriesT") SeriesT_co = TypeVar("SeriesT_co", covariant=True) -FrameT = TypeVar("FrameT") -FrameT_co = TypeVar("FrameT_co", covariant=True) -FrameT_contra = TypeVar("FrameT_contra", bound="_StoresVersion", contravariant=True) +FrameAny: TypeAlias = "DummyCompliantFrame[Any, Any, Any]" +FrameT = TypeVar("FrameT", bound=FrameAny) +FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) +FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) OneOrIterable: TypeAlias = "T | Iterable[T]" LengthT = TypeVar("LengthT") NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) @@ -226,7 +227,7 @@ def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: @classmethod def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: obj = cls.__new__(cls) - obj._version = frame._version + obj._version = frame.version return obj._dispatch(node, frame, name) @classmethod @@ -427,3 +428,8 @@ def lit( def lit( self, node: expr.Literal[Any], frame: FrameT, name: str ) -> EagerExprT_co | EagerScalarT_co: ... 
+ + def len(self, node: expr.Len, frame: FrameT, name: str) -> EagerScalarT_co: + return self._scalar.from_python( + len(frame), name or node.name, dtype=None, version=frame.version + ) From 25a248b43926b96ac5f7a3ad4c1067a6f128e601 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 8 Jul 2025 20:58:40 +0100 Subject: [PATCH 299/368] revert: remove unused --- narwhals/_plan/arrow/dataframe.py | 4 ---- narwhals/_plan/arrow/namespace.py | 6 +----- narwhals/_plan/demo.py | 7 ------- narwhals/_plan/dummy.py | 4 ---- tests/plan/expr_parsing_test.py | 4 +++- 5 files changed, 4 insertions(+), 21 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index eab3ac7f37..ba1021e810 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -40,10 +40,6 @@ def __narwhals_namespace__(self) -> ArrowNamespace: return ArrowNamespace(self._version) - @property - def _series(self) -> type[ArrowSeries]: - return ArrowSeries - @property def columns(self) -> list[str]: return self.native.column_names diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index e7acd2dda0..eef9f987da 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, overload from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace @@ -12,7 +12,6 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.series import ArrowSeries - from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dummy import DummySeries from narwhals.typing import NonNestedLiteral @@ -47,9 +46,6 @@ def _dataframe(self) -> type[ArrowDataFrame]: return ArrowDataFrame - def dispatch_expr(self, named_ir: NamedIR[ExprIR], frame: ArrowDataFrame) -> Any: - return self._expr.from_named_ir(named_ir, frame) - def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> ArrowExpr: return self._expr.from_native( frame.native.column(node.name), name, version=frame.version diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index abed55bf2a..eaec1a096a 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -10,7 +10,6 @@ functions as F, # noqa: N812 ) from narwhals._plan.common import ( - ExprIR, into_dtype, is_non_nested_literal, is_series, @@ -228,9 +227,3 @@ def ensure_orderable_rules(*exprs: DummyExpr) -> tuple[DummyExpr, ...]: if not _is_order_enforcing_previous(previous): raise _order_dependent_error(node) return exprs - - -def select_context( - *exprs: IntoExpr | t.Iterable[IntoExpr], **named_exprs: IntoExpr -) -> tuple[ExprIR, ...]: - return parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 0c9b303b13..3b96bfe7a2 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -917,10 +917,6 @@ def version(self) -> Version: def columns(self) -> list[str]: raise NotImplementedError - @property - def _series(self) -> type[CompliantSeriesT]: - raise NotImplementedError - def to_narwhals(self) -> DummyFrame[NativeFrameT, NativeSeriesT]: raise NotImplementedError diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 85badfaeb7..955e09575b 100644 --- 
a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -18,6 +18,7 @@ from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr +from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals._plan.literal import SeriesLiteral from narwhals.exceptions import ( InvalidIntoExprError, @@ -58,7 +59,8 @@ def test_parsing( exprs: Seq[IntoExpr | Iterable[IntoExpr]], named_exprs: dict[str, IntoExpr] ) -> None: assert all( - isinstance(node, ExprIR) for node in nwd.select_context(*exprs, **named_exprs) + isinstance(node, ExprIR) + for node in parse_into_seq_of_expr_ir(*exprs, **named_exprs) ) From d68239050506eb5fba5a866eee68318bbcf93084 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 8 Jul 2025 22:16:02 +0100 Subject: [PATCH 300/368] feat: Add the remaining top-level nodes --- narwhals/_plan/arrow/evaluate.py | 76 -------------------------------- narwhals/_plan/protocols.py | 44 +++++++++++++++++- tests/plan/to_compliant_test.py | 17 +++++++ 3 files changed, 59 insertions(+), 78 deletions(-) delete mode 100644 narwhals/_plan/arrow/evaluate.py diff --git a/narwhals/_plan/arrow/evaluate.py b/narwhals/_plan/arrow/evaluate.py deleted file mode 100644 index 47abed768b..0000000000 --- a/narwhals/_plan/arrow/evaluate.py +++ /dev/null @@ -1,76 +0,0 @@ -"""TODO: Define remaining nodes in `Dispatch` protocol. - -- `Ternary` -- `BinaryExpr` -- `FunctionExpr` -- `RollingExpr` -- `WindowExpr` -- `OrderedWindowExpr` -- `AnonymousExpr` -""" - -from __future__ import annotations - -import typing as t -from functools import singledispatch - -from narwhals._plan import expr - -if t.TYPE_CHECKING: - from typing_extensions import TypeIs - - from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny - from narwhals._plan.arrow.dataframe import ArrowDataFrame - from narwhals._plan.arrow.series import ArrowSeries - from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.protocols import EagerBroadcast - - -def is_scalar(obj: t.Any) -> TypeIs[ScalarAny]: - import pyarrow as pa # ignore-banned-import - - return isinstance(obj, pa.Scalar) - - -def evaluate(node: NamedIR[ExprIR], frame: ArrowDataFrame) -> EagerBroadcast[ArrowSeries]: - result = _evaluate_inner(node.expr, frame) - if is_scalar(result): - return frame.__narwhals_namespace__()._scalar.from_native(result, node.name) - return frame.__narwhals_namespace__()._expr.from_native(result, node.name) - - -@singledispatch -def _evaluate_inner(node: ExprIR, frame: ArrowDataFrame) -> ChunkedArrayAny: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.Ternary) -def ternary(node: expr.Ternary, frame: ArrowDataFrame) -> ChunkedArrayAny: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.BinaryExpr) -def binary_expr(node: expr.BinaryExpr, frame: ArrowDataFrame) -> ChunkedArrayAny: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.FunctionExpr) -def function_expr( - node: expr.FunctionExpr[t.Any], frame: ArrowDataFrame -) -> ChunkedArrayAny: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.RollingExpr) -def rolling_expr(node: expr.RollingExpr[t.Any], frame: ArrowDataFrame) -> ChunkedArrayAny: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.WindowExpr) -def window_expr(node: expr.WindowExpr, frame: ArrowDataFrame) -> 
ChunkedArrayAny: - raise NotImplementedError(type(node)) - - -@_evaluate_inner.register(expr.AnonymousExpr) -def anonymous_expr(node: expr.AnonymousExpr, frame: ArrowDataFrame) -> ChunkedArrayAny: - raise NotImplementedError(type(node)) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 001b7e50bc..153f5e4d42 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -137,6 +137,25 @@ def _with_native(self, native: Any, name: str = "", /) -> Self: # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... + def binary_expr( + self, node: expr.BinaryExpr, frame: FrameT_contra, name: str + ) -> Self: ... + def ternary_expr( + self, node: expr.Ternary, frame: FrameT_contra, name: str + ) -> Self: ... + def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... + def over_ordered( + self, node: expr.OrderedWindowExpr, frame: FrameT_contra, name: str + ) -> Self: ... + def map_batches( + self, node: expr.AnonymousExpr, frame: FrameT_contra, name: str + ) -> Self: ... + def rolling_expr( + self, node: expr.RollingExpr, frame: FrameT_contra, name: str + ) -> Self: ... + def function_expr( + self, node: expr.FunctionExpr, frame: FrameT_contra, name: str + ) -> Self: ... # series only (section 3) def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: ... def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: ... @@ -215,12 +234,33 @@ class ExprDispatch(Protocol[FrameT_contra, R_co, NamespaceT_co]): agg.Mean: lambda self, node, frame, name: self.mean(node, frame, name), agg.Median: lambda self, node, frame, name: self.median(node, frame, name), agg.Min: lambda self, node, frame, name: self.min(node, frame, name), + expr.BinaryExpr: lambda self, node, frame, name: self.binary_expr( + node, frame, name + ), + expr.RollingExpr: lambda self, node, frame, name: self.rolling_expr( + node, frame, name + ), + expr.AnonymousExpr: lambda self, node, frame, name: self.map_batches( + node, frame, name + ), + expr.FunctionExpr: lambda self, node, frame, name: self.function_expr( + node, frame, name + ), + expr.OrderedWindowExpr: lambda self, node, frame, name: self.over_ordered( + node, frame, name + ), + expr.WindowExpr: lambda self, node, frame, name: self.over(node, frame, name), + expr.Ternary: lambda self, node, frame, name: self.ternary_expr( + node, frame, name + ), } _version: Version def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: - if method := self._DISPATCH.get(node.__class__): - return method(self, node, frame, name) # type: ignore[no-any-return] + if (method := self._DISPATCH.get(node.__class__)) and ( + result := method(self, node, frame, name) + ): + return result # type: ignore[no-any-return] msg = f"Support for {node.__class__.__name__!r} is not yet implemented, got:\n{node!r}" raise NotImplementedError(msg) diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 21f978938b..ae508b51f0 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -107,3 +107,20 @@ def test_select( df = DummyFrame.from_native(frame) result = df.select(expr).to_dict(as_series=False) assert result == expected + + +if TYPE_CHECKING: + + def test_protocol_expr() -> None: + """Static test for all members implemented. + + There's a lot left to implement, but only gets detected if we invoke `__init__`, which + doesn't happen elsewhere at the moment. 
+ """ + pytest.importorskip("pyarrow") + from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar + + expr = ArrowExpr() # type: ignore[abstract] + scalar = ArrowScalar() # type: ignore[abstract] + assert expr + assert scalar From 14033aed3cd772b1d70f1d5ec055ec67d095ce08 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 9 Jul 2025 19:00:35 +0100 Subject: [PATCH 301/368] refactor: Use `Protocol`s, move to `protocols` + some tidying up along the way --- narwhals/_plan/arrow/dataframe.py | 3 +- narwhals/_plan/arrow/series.py | 2 +- narwhals/_plan/common.py | 19 ++-- narwhals/_plan/dummy.py | 142 +------------------------- narwhals/_plan/protocols.py | 159 +++++++++++++++++++++++++++--- 5 files changed, 162 insertions(+), 163 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index ba1021e810..7926b0b5d6 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -9,7 +9,8 @@ from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR -from narwhals._plan.dummy import DummyCompliantFrame, DummyFrame +from narwhals._plan.dummy import DummyFrame +from narwhals._plan.protocols import DummyCompliantFrame from narwhals._utils import Version if t.TYPE_CHECKING: diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 1692c459ce..0b4abf0787 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any from narwhals._arrow.utils import native_to_narwhals_dtype -from narwhals._plan.dummy import DummyCompliantSeries +from narwhals._plan.protocols import DummyCompliantSeries if TYPE_CHECKING: from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401 diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 3fd364a42f..675b4b8f2a 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -18,6 +18,7 @@ Ns, Seq, ) +from narwhals._utils import _hasattr_static from narwhals.dtypes import DType from narwhals.utils import Version @@ -28,12 +29,7 @@ from typing_extensions import Never, Self, TypeIs, dataclass_transform from narwhals._plan import expr - from narwhals._plan.dummy import ( - DummyCompliantSeries, - DummyExpr, - DummySelector, - DummySeries, - ) + from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries from narwhals._plan.expr import ( Agg, BinaryExpr, @@ -44,6 +40,7 @@ ) from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions + from narwhals._plan.protocols import DummyCompliantSeries from narwhals.typing import NonNestedDType, NonNestedLiteral else: @@ -456,12 +453,18 @@ def is_series( return isinstance(obj, DummySeries) +def is_compliant_series( + obj: DummyCompliantSeries[NativeSeriesT] | Any, +) -> TypeIs[DummyCompliantSeries[NativeSeriesT]]: + return _hasattr_static(obj, "__narwhals_series__") + + def is_iterable_reject( obj: Any, ) -> TypeIs[str | bytes | DummySeries | DummyCompliantSeries]: - from narwhals._plan.dummy import DummyCompliantSeries, DummySeries + from narwhals._plan.dummy import DummySeries - return isinstance(obj, (str, bytes, DummySeries, DummyCompliantSeries)) + return isinstance(obj, (str, bytes, DummySeries)) or is_compliant_series(obj) def is_regex_projection(name: str) -> bool: diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 
3b96bfe7a2..3f2ffb50e4 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -28,15 +28,14 @@ from narwhals._plan.selectors import by_name from narwhals._plan.typing import NativeFrameT, NativeSeriesT from narwhals._plan.window import Over -from narwhals._typing_compat import TypeVar -from narwhals._utils import Version, _hasattr_static +from narwhals._utils import Version from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table from narwhals.dtypes import DType from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.schema import Schema if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Iterable, Sequence import pyarrow as pa from typing_extensions import Never, Self, TypeAlias @@ -46,6 +45,7 @@ from narwhals._plan.lists import ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace + from narwhals._plan.protocols import DummyCompliantFrame, DummyCompliantSeries from narwhals._plan.schema import FrozenSchema from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace @@ -64,7 +64,6 @@ ) -CompliantSeriesT = TypeVar("CompliantSeriesT", bound="DummyCompliantSeries[t.Any]") CompliantFrame: TypeAlias = "DummyCompliantFrame[t.Any, NativeFrameT, NativeSeriesT]" @@ -899,91 +898,6 @@ def sort( return self._from_compliant(self._compliant.sort(named_irs, opts, schema_frozen)) -class DummyCompliantFrame(Generic[CompliantSeriesT, NativeFrameT, NativeSeriesT]): - _native: NativeFrameT - _version: Version - - def __narwhals_namespace__(self) -> t.Any: ... - - @property - def native(self) -> NativeFrameT: - return self._native - - @property - def version(self) -> Version: - return self._version - - @property - def columns(self) -> list[str]: - raise NotImplementedError - - def to_narwhals(self) -> DummyFrame[NativeFrameT, NativeSeriesT]: - raise NotImplementedError - - @classmethod - def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: - obj = cls.__new__(cls) - obj._native = native - obj._version = version - return obj - - def _with_native(self, native: NativeFrameT) -> Self: - return self.from_native(native, self.version) - - @classmethod - def from_series( - cls, - series: Iterable[CompliantSeriesT] | CompliantSeriesT, - *more_series: CompliantSeriesT, - ) -> Self: - """Return a new DataFrame, horizontally concatenating multiple Series.""" - raise NotImplementedError - - @classmethod - def from_dict( - cls, - data: Mapping[str, t.Any], - /, - *, - schema: Mapping[str, DType] | Schema | None = None, - ) -> Self: - raise NotImplementedError - - @t.overload - def to_dict(self, *, as_series: t.Literal[True]) -> dict[str, CompliantSeriesT]: ... - @t.overload - def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... - @t.overload - def to_dict( - self, *, as_series: bool - ) -> dict[str, CompliantSeriesT] | dict[str, list[t.Any]]: ... 
- - def to_dict( - self, *, as_series: bool - ) -> dict[str, CompliantSeriesT] | dict[str, list[t.Any]]: - raise NotImplementedError - - def __len__(self) -> int: - raise NotImplementedError - - @property - def schema(self) -> Mapping[str, DType]: - raise NotImplementedError - - def _evaluate_irs( - self, nodes: Iterable[NamedIR[ExprIR]], / - ) -> Iterator[CompliantSeriesT]: - raise NotImplementedError - - def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: - return self.from_series(self._evaluate_irs(irs)) - - def sort( - self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema - ) -> Self: - raise NotImplementedError - - class DummySeries(Generic[NativeSeriesT]): _compliant: DummyCompliantSeries[NativeSeriesT] _version: t.ClassVar[Version] = Version.MAIN @@ -1032,53 +946,3 @@ def __iter__(self) -> t.Iterator[t.Any]: class DummySeriesV1(DummySeries[NativeSeriesT]): _version: t.ClassVar[Version] = Version.V1 - - -class DummyCompliantSeries(Generic[NativeSeriesT]): - _native: NativeSeriesT - _name: str - _version: Version - - @property - def native(self) -> NativeSeriesT: - return self._native - - @property - def version(self) -> Version: - return self._version - - @property - def dtype(self) -> DType: - raise NotImplementedError - - @property - def name(self) -> str: - return self._name - - def to_narwhals(self) -> DummySeries[NativeSeriesT]: - return DummySeries[NativeSeriesT]._from_compliant(self) - - @classmethod - def from_native( - cls, native: NativeSeriesT, name: str = "", /, *, version: Version = Version.MAIN - ) -> Self: - name = name or ( - getattr(native, "name", name) if _hasattr_static(native, "name") else name - ) - obj = cls.__new__(cls) - obj._native = native - obj._name = name - obj._version = version - return obj - - def _with_native(self, native: NativeSeriesT) -> Self: - return self.from_native(native, self.name, version=self.version) - - def alias(self, name: str) -> Self: - return self.from_native(self.native, name, version=self.version) - - def __len__(self) -> int: - return len(self.native) - - def to_list(self) -> list[t.Any]: - raise NotImplementedError diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 153f5e4d42..bee0380a16 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,45 +1,50 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload from narwhals._plan import aggregation as agg, expr from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe +from narwhals._plan.typing import NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar -from narwhals._utils import Version +from narwhals._utils import Version, _hasattr_static if TYPE_CHECKING: from typing_extensions import Self, TypeAlias - from narwhals._plan.dummy import DummyCompliantFrame, DummySeries + from narwhals._plan.dummy import DummyFrame, DummySeries + from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.schema import FrozenSchema + from narwhals.dtypes import DType + from narwhals.schema import Schema from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral T = TypeVar("T") R_co = TypeVar("R_co", covariant=True) -SeriesT = TypeVar("SeriesT") -SeriesT_co = TypeVar("SeriesT_co", covariant=True) -FrameAny: TypeAlias = 
"DummyCompliantFrame[Any, Any, Any]" -FrameT = TypeVar("FrameT", bound=FrameAny) -FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) -FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) OneOrIterable: TypeAlias = "T | Iterable[T]" LengthT = TypeVar("LengthT") NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) + ExprAny: TypeAlias = "CompliantExpr[Any, Any]" ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" +SeriesAny: TypeAlias = "DummyCompliantSeries[Any]" +FrameAny: TypeAlias = "DummyCompliantFrame[Any, Any, Any]" NamespaceAny: TypeAlias = "CompliantNamespace[Any, Any, Any, Any]" + ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) -ScalarT = TypeVar("ScalarT", bound="CompliantScalar[Any, Any]") -ScalarT_co = TypeVar("ScalarT_co", bound="CompliantScalar[Any, Any]", covariant=True) -IntoSeriesT_co = TypeVar("IntoSeriesT_co", bound="ExprAny | ScalarAny", covariant=True) +ScalarT = TypeVar("ScalarT", bound=ScalarAny) +ScalarT_co = TypeVar("ScalarT_co", bound=ScalarAny, covariant=True) +SeriesT = TypeVar("SeriesT", bound=SeriesAny) +SeriesT_co = TypeVar("SeriesT_co", bound=SeriesAny, covariant=True) +FrameT = TypeVar("FrameT", bound=FrameAny) +FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) +NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) EagerExprT_co = TypeVar("EagerExprT_co", bound="EagerExpr[Any, Any]", covariant=True) EagerScalarT_co = TypeVar( "EagerScalarT_co", bound="EagerScalar[Any, Any]", covariant=True ) -NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) - class SupportsBroadcast(Protocol[SeriesT, LengthT]): """Minimal broadcasting for `Expr` results.""" @@ -473,3 +478,129 @@ def len(self, node: expr.Len, frame: FrameT, name: str) -> EagerScalarT_co: return self._scalar.from_python( len(frame), name or node.name, dtype=None, version=frame.version ) + + +class DummyCompliantFrame(Protocol[SeriesT, NativeFrameT, NativeSeriesT]): + _native: NativeFrameT + _version: Version + + def __narwhals_namespace__(self) -> Any: ... + + @property + def native(self) -> NativeFrameT: + return self._native + + @property + def version(self) -> Version: + return self._version + + @property + def columns(self) -> list[str]: ... + + def to_narwhals(self) -> DummyFrame[NativeFrameT, NativeSeriesT]: ... + + @classmethod + def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: + obj = cls.__new__(cls) + obj._native = native + obj._version = version + return obj + + def _with_native(self, native: NativeFrameT) -> Self: + return self.from_native(native, self.version) + + @classmethod + def from_series( + cls, series: Iterable[SeriesT] | SeriesT, *more_series: SeriesT + ) -> Self: + """Return a new DataFrame, horizontally concatenating multiple Series.""" + ... + + @classmethod + def from_dict( + cls, + data: Mapping[str, Any], + /, + *, + schema: Mapping[str, DType] | Schema | None = None, + ) -> Self: ... + + @overload + def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload + def to_dict( + self, *, as_series: bool + ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... + + def to_dict( + self, *, as_series: bool + ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... + + def __len__(self) -> int: ... + + @property + def schema(self) -> Mapping[str, DType]: ... 
+ + def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[SeriesT]: ... + + def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: + return self.from_series(self._evaluate_irs(irs)) + + def sort( + self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema + ) -> Self: ... + + +class DummyCompliantSeries(Protocol[NativeSeriesT]): + _native: NativeSeriesT + _name: str + _version: Version + + def __narwhals_series__(self) -> Self: + return self + + @property + def native(self) -> NativeSeriesT: + return self._native + + @property + def version(self) -> Version: + return self._version + + @property + def dtype(self) -> DType: ... + + @property + def name(self) -> str: + return self._name + + def to_narwhals(self) -> DummySeries[NativeSeriesT]: + from narwhals._plan.dummy import DummySeries + + return DummySeries[NativeSeriesT]._from_compliant(self) + + @classmethod + def from_native( + cls, native: NativeSeriesT, name: str = "", /, *, version: Version = Version.MAIN + ) -> Self: + name = name or ( + getattr(native, "name", name) if _hasattr_static(native, "name") else name + ) + obj = cls.__new__(cls) + obj._native = native + obj._name = name + obj._version = version + return obj + + def _with_native(self, native: NativeSeriesT) -> Self: + return self.from_native(native, self.name, version=self.version) + + def alias(self, name: str) -> Self: + return self.from_native(self.native, name, version=self.version) + + def __len__(self) -> int: + return len(self.native) + + def to_list(self) -> list[Any]: ... From d72b2005e435854d44ef81b2e30d87c2fa3aab8e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 9 Jul 2025 20:10:19 +0100 Subject: [PATCH 302/368] feat(pyarrow): Impl `BinaryExpr` --- narwhals/_plan/arrow/expr.py | 56 ++++++++++++++++++++++++++++++--- narwhals/_plan/protocols.py | 2 +- tests/plan/to_compliant_test.py | 9 +++--- 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 71bc054803..f8a759097f 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,14 +1,18 @@ from __future__ import annotations +from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, cast, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import ( + cast_for_truediv, chunked_array as _chunked_array, + floordiv_compat, narwhals_to_native_dtype, ) +from narwhals._plan import operators as ops from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, into_dtype from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch @@ -45,13 +49,43 @@ ) from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals._plan.expr import BinaryExpr from narwhals.typing import IntoDType, PythonLiteral NativeScalar: TypeAlias = "pa.Scalar[Any]" +BinOp: TypeAlias = Callable[..., "ChunkedArrayAny | NativeScalar"] BACKEND_VERSION = Implementation.PYARROW._backend_version() +def truediv_compat(lhs: Any, rhs: Any) -> Any: + return pc.divide(*cast_for_truediv(lhs, rhs)) + + +def modulus(lhs: Any, rhs: Any) -> Any: + floor_div = floordiv_compat(lhs, rhs) + return pc.subtract(lhs, pc.multiply(floor_div, rhs)) + + +DISPATCH_BINARY: Mapping[type[ops.Operator], BinOp] = { + ops.Eq: pc.equal, + ops.NotEq: pc.not_equal, 
+ ops.Lt: pc.less, + ops.LtEq: pc.less_equal, + ops.Gt: pc.greater, + ops.GtEq: pc.greater_equal, + ops.Add: pc.add, + ops.Sub: pc.subtract, + ops.Multiply: pc.multiply, + ops.TrueDivide: truediv_compat, + ops.FloorDivide: floordiv_compat, + ops.Modulus: modulus, + ops.And: pc.and_kleene, + ops.Or: pc.or_kleene, + ops.ExclusiveOr: pc.xor, +} + + class ArrowExpr( ExprDispatch["ArrowDataFrame", "ArrowExpr | ArrowScalar", "ArrowNamespace"], _StoresNative["ChunkedArrayAny"], @@ -83,15 +117,15 @@ def from_native( return cls.from_series(ArrowSeries.from_native(native, name, version=version)) @overload - def _with_native(self, result: ChunkedArrayAny, name: str = ..., /) -> Self: ... + def _with_native(self, result: ChunkedArrayAny, name: str, /) -> Self: ... @overload - def _with_native(self, result: NativeScalar, name: str = ..., /) -> ArrowScalar: ... + def _with_native(self, result: NativeScalar, name: str, /) -> ArrowScalar: ... @overload def _with_native( - self, result: ChunkedArrayAny | NativeScalar, name: str = ..., / + self, result: ChunkedArrayAny | NativeScalar, name: str, / ) -> ArrowScalar | Self: ... def _with_native( - self, result: ChunkedArrayAny | NativeScalar, name: str = "", / + self, result: ChunkedArrayAny | NativeScalar, name: str, / ) -> ArrowScalar | Self: if isinstance(result, pa.Scalar): return ArrowScalar.from_native(result, name, version=self.version) @@ -153,7 +187,8 @@ def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowEx return self._with_native( self._dispatch_expr(node.expr, frame, name).native.filter( self._dispatch_expr(node.by, frame, name).native - ) + ), + name, ) def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: @@ -230,6 +265,17 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: result: NativeScalar = pc.min(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) + def binary_expr( # type: ignore[override] + self, node: BinaryExpr, frame: ArrowDataFrame, name: str + ) -> ArrowScalar | Self: + lhs, rhs = ( + self._dispatch(node.left, frame, name), + self._dispatch(node.right, frame, name), + ) + fn = DISPATCH_BINARY[node.op.__class__] + result = fn(lhs.native, rhs.native) + return self._with_native(result, name) + def lit(value: Any, dtype: pa.DataType | None = None) -> NativeScalar: # NOTE: Needed for `pyarrow<13` diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index bee0380a16..1ef4000ee8 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -137,7 +137,7 @@ def from_native( cls, native: Any, name: str = "", /, version: Version = Version.MAIN ) -> Self: ... 
- def _with_native(self, native: Any, name: str = "", /) -> Self: + def _with_native(self, native: Any, name: str, /) -> Self: return self.from_native(native, name or self.name, self.version) # series & scalar diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index ae508b51f0..de0702766e 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -48,9 +48,6 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: assert isinstance(compliant_expr, namespace._expr) -XFAIL_REQUIRES_BINARY_EXPR = pytest.mark.xfail( - reason="Requires `BinaryExpr` implementation.", raises=NotImplementedError -) XFAIL_REWRITE_SPECIAL_ALIASES = pytest.mark.xfail( reason="Bug in `meta` namespace impl", raises=ComputeError ) @@ -72,8 +69,10 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: {"e": [3, 3, 3], "d": [8, 8, 8], "c": [9, 2, 4]}, ), (nwd.col("b").sort(descending=True).alias("b_desc"), {"b_desc": [3, 2, 1]}), - pytest.param( - nwd.col("c").filter(a="B"), {"c": [2]}, marks=XFAIL_REQUIRES_BINARY_EXPR + (nwd.col("c").filter(a="B"), {"c": [2]}), + ( + [nwd.nth(0, 1).filter(nwd.col("c") >= 4), nwd.col("d").last() - 4], + {"a": ["A", "A"], "b": [1, 3], "d": [4, 4]}, ), (nwd.col("b").cast(nw.Float64()), {"b": [1.0, 2.0, 3.0]}), (nwd.lit(1).cast(nw.Float64()).alias("literal_cast"), {"literal_cast": [1.0]}), From d5ccd31540ae561b70a2b5c1e278811404de3b71 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 9 Jul 2025 22:32:07 +0100 Subject: [PATCH 303/368] =?UTF-8?q?=F0=9F=A7=B9=F0=9F=A7=B9=F0=9F=A7=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- narwhals/_plan/aggregation.py | 3 --- narwhals/_plan/boolean.py | 5 ----- narwhals/_plan/common.py | 2 -- narwhals/_plan/expr.py | 20 -------------------- narwhals/_plan/functions.py | 19 ------------------- narwhals/_plan/literal.py | 2 -- narwhals/_plan/name.py | 4 ---- narwhals/_plan/options.py | 10 ---------- narwhals/_plan/ranges.py | 1 - narwhals/_plan/selectors.py | 3 --- narwhals/_plan/strings.py | 12 ------------ narwhals/_plan/struct.py | 1 - narwhals/_plan/temporal.py | 5 ----- narwhals/_plan/when_then.py | 4 ---- 14 files changed, 91 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 0e5665fcb0..33d5ed8849 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -16,7 +16,6 @@ class Agg(ExprIR): __slots__ = ("expr",) - expr: ExprIR @property @@ -84,7 +83,6 @@ class Quantile(Agg): class Std(Agg): __slots__ = (*Agg.__slots__, "ddof") - ddof: int @@ -93,7 +91,6 @@ class Sum(Agg): ... class Var(Agg): __slots__ = (*Agg.__slots__, "ddof") - ddof: int diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 1779d6e885..cbad3e0025 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -72,17 +72,13 @@ def function_options(self) -> FunctionOptions: ) -# NOTE: `lower_bound`, `upper_bound` aren't spec'd in the function enum. class IsBetween(BooleanFunction): """`lower_bound`, `upper_bound` aren't spec'd in the function enum. 
- Assuming the `FunctionExpr.input` becomes `s` in the impl - https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/boolean.rs#L225-L237 """ __slots__ = ("closed",) - closed: ClosedInterval @property @@ -110,7 +106,6 @@ def function_options(self) -> FunctionOptions: class IsIn(BooleanFunction, t.Generic[OtherT]): __slots__ = ("other",) - other: OtherT @property diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 675b4b8f2a..9ca7e03b07 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -363,7 +363,6 @@ def is_elementwise_top_level(self) -> bool: class IRNamespace(Immutable): __slots__ = ("_ir",) - _ir: ExprIR @classmethod @@ -373,7 +372,6 @@ def from_expr(cls, expr: DummyExpr, /) -> Self: class ExprNamespace(Immutable, Generic[IRNamespaceT]): __slots__ = ("_expr",) - _expr: DummyExpr @property diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index e2f4db2000..d3595e5a85 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -86,7 +86,6 @@ def col(name: str, /) -> Column: class Alias(ExprIR): __slots__ = ("expr", "name") - expr: ExprIR name: str @@ -114,7 +113,6 @@ def with_expr(self, expr: ExprIR, /) -> Self: class Column(ExprIR): __slots__ = ("name",) - name: str def __repr__(self) -> str: @@ -139,7 +137,6 @@ def map_ir(self, function: MapIR, /) -> ExprIR: class Columns(_ColumnSelection): __slots__ = ("names",) - names: Seq[str] def __repr__(self) -> str: @@ -151,7 +148,6 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: class Nth(_ColumnSelection): __slots__ = ("index",) - index: int def __repr__(self) -> str: @@ -167,7 +163,6 @@ class IndexColumns(_ColumnSelection): """ __slots__ = ("indices",) - indices: Seq[int] def __repr__(self) -> str: @@ -186,7 +181,6 @@ def __repr__(self) -> str: class Exclude(_ColumnSelection): __slots__ = ("expr", "names") - expr: ExprIR """Default is `all()`.""" names: Seq[str] @@ -226,7 +220,6 @@ class Literal(ExprIR, t.Generic[LiteralT]): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" __slots__ = ("value",) - value: LiteralValue[LiteralT] @property @@ -259,7 +252,6 @@ def map_ir(self, function: MapIR, /) -> ExprIR: class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): __slots__ = ("left", "op", "right") - left: LeftT op: OperatorT right: RightT @@ -310,7 +302,6 @@ def map_ir(self, function: MapIR, /) -> ExprIR: class Cast(ExprIR): __slots__ = ("expr", "dtype") # noqa: RUF023 - expr: ExprIR dtype: DType @@ -341,7 +332,6 @@ def with_expr(self, expr: ExprIR, /) -> Self: class Sort(ExprIR): __slots__ = ("expr", "options") - expr: ExprIR options: SortOptions @@ -375,7 +365,6 @@ class SortBy(ExprIR): """https://github.com/narwhals-dev/narwhals/issues/2534.""" __slots__ = ("expr", "by", "options") # noqa: RUF023 - expr: ExprIR by: Seq[ExprIR] options: SortMultipleOptions @@ -426,7 +415,6 @@ class FunctionExpr(ExprIR, t.Generic[FunctionT]): """ __slots__ = ("function", "input", "options") - input: Seq[ExprIR] function: FunctionT """Operation applied to each element of `input`. @@ -545,7 +533,6 @@ def __repr__(self) -> str: class Filter(ExprIR): __slots__ = ("expr", "by") # noqa: RUF023 - expr: ExprIR by: ExprIR @@ -590,15 +577,12 @@ class WindowExpr(ExprIR): """ __slots__ = ("expr", "partition_by", "options") # noqa: RUF023 - expr: ExprIR """Renamed from `function`. For lazy backends, this should be the only place we allow `rolling_*`, `cum_*`. 
""" - partition_by: Seq[ExprIR] - options: Window """Currently **always** represents over. @@ -645,7 +629,6 @@ def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: # TODO @dangotbanned: Reduce repetition from `WindowExpr` class OrderedWindowExpr(WindowExpr): __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 - expr: ExprIR partition_by: Seq[ExprIR] order_by: Seq[ExprIR] @@ -750,7 +733,6 @@ class RootSelector(SelectorIR): """A single selector expression.""" __slots__ = ("selector",) - selector: Selector """by_dtype, matches, numeric, boolean, string, categorical, datetime, all.""" @@ -786,7 +768,6 @@ def map_ir(self, function: MapIR, /) -> ExprIR: class InvertSelector(SelectorIR, t.Generic[SelectorT]): __slots__ = ("selector",) - selector: SelectorT """`(Root|Binary)Selector`.""" @@ -804,7 +785,6 @@ class Ternary(ExprIR): """When-Then-Otherwise.""" __slots__ = ("predicate", "truthy", "falsy") # noqa: RUF023 - predicate: ExprIR truthy: ExprIR falsy: ExprIR diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 1eeb8e6f50..48a300ec7f 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -33,7 +33,6 @@ class Hist(Function): """Only supported for `Series` so far.""" __slots__ = ("include_breakpoint",) - include_breakpoint: bool @property @@ -45,10 +44,7 @@ def __repr__(self) -> str: class HistBins(Hist): - """Subclasses for each variant.""" - __slots__ = ("bins", *Hist.__slots__) - bins: Seq[float] def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None: @@ -61,7 +57,6 @@ def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None class HistBinCount(Hist): __slots__ = ("bin_count", *Hist.__slots__) - bin_count: int """Polars (v1.20) sets `bin_count=10` if neither `bins` or `bin_count` are provided.""" @@ -81,7 +76,6 @@ def __repr__(self) -> str: class Log(Function): __slots__ = ("base",) - base: float @property @@ -121,7 +115,6 @@ def __repr__(self) -> str: class Kurtosis(Function): __slots__ = ("bias", "fisher") - fisher: bool bias: bool @@ -149,7 +142,6 @@ class FillNullWithStrategy(Function): """ __slots__ = ("limit", "strategy") - strategy: FillNullStrategy limit: int | None @@ -169,9 +161,7 @@ def __repr__(self) -> str: class Shift(Function): __slots__ = ("n",) - n: int - """https://github.com/narwhals-dev/narwhals/pull/2555""" @property def function_options(self) -> FunctionOptions: @@ -210,7 +200,6 @@ def __repr__(self) -> str: class Rank(Function): __slots__ = ("options",) - options: RankOptions @property @@ -232,9 +221,7 @@ def __repr__(self) -> str: class CumAgg(Function): __slots__ = ("reverse",) - reverse: bool - """https://github.com/narwhals-dev/narwhals/pull/2555""" @property def function_options(self) -> FunctionOptions: @@ -256,7 +243,6 @@ def __repr__(self) -> str: class RollingWindow(Function): __slots__ = ("options",) - options: RollingOptionsFixedWindow @property @@ -330,7 +316,6 @@ def __repr__(self) -> str: class Round(Function): __slots__ = ("decimals",) - decimals: int @property @@ -387,7 +372,6 @@ def __repr__(self) -> str: class EwmMean(Function): __slots__ = ("options",) - options: EWMOptions @property @@ -400,7 +384,6 @@ def __repr__(self) -> str: class ReplaceStrict(Function): __slots__ = ("new", "old", "return_dtype") - old: Seq[Any] new: Seq[Any] return_dtype: IntoDType | None @@ -415,7 +398,6 @@ def __repr__(self) -> str: class GatherEvery(Function): __slots__ = ("n", "offset") - n: int offset: int @@ -429,7 
+411,6 @@ def __repr__(self) -> str: class MapBatches(Function): __slots__ = ("function", "is_elementwise", "return_dtype", "returns_scalar") - function: Udf return_dtype: IntoDType | None is_elementwise: bool diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 9a0cb13e6e..a2579b72fb 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -39,7 +39,6 @@ def unwrap(self) -> LiteralT: class ScalarLiteral(LiteralValue[NonNestedLiteralT]): __slots__ = ("dtype", "value") - value: NonNestedLiteralT dtype: DType @@ -63,7 +62,6 @@ class SeriesLiteral(LiteralValue["DummySeries[NativeSeriesT]"]): """ __slots__ = ("value",) - value: DummySeries[NativeSeriesT] @property diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index a393b4b63b..3566574b42 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -18,7 +18,6 @@ class KeepName(ExprIR): """Keep the original root name.""" __slots__ = ("expr",) - expr: ExprIR @property @@ -45,7 +44,6 @@ def with_expr(self, expr: ExprIR, /) -> Self: class RenameAlias(ExprIR): __slots__ = ("expr", "function") - expr: ExprIR function: AliasName @@ -75,7 +73,6 @@ def with_expr(self, expr: ExprIR, /) -> Self: class Prefix(Immutable): __slots__ = ("prefix",) - prefix: str def __call__(self, name: str, /) -> str: @@ -84,7 +81,6 @@ def __call__(self, name: str, /) -> str: class Suffix(Immutable): __slots__ = ("suffix",) - suffix: str def __call__(self, name: str, /) -> str: diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 7e5229f4a1..63dbd59250 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -77,7 +77,6 @@ class FunctionOptions(Immutable): """ __slots__ = ("flags",) - flags: FunctionFlags def __str__(self) -> str: @@ -140,7 +139,6 @@ def aggregation() -> FunctionOptions: class SortOptions(Immutable): __slots__ = ("descending", "nulls_last") - descending: bool nulls_last: bool @@ -163,7 +161,6 @@ def to_arrow(self) -> pc.ArraySortOptions: class SortMultipleOptions(Immutable): __slots__ = ("descending", "nulls_last") - descending: Seq[bool] nulls_last: Seq[bool] @@ -202,10 +199,7 @@ def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: class RankOptions(Immutable): - """https://github.com/narwhals-dev/narwhals/pull/2555.""" - __slots__ = ("descending", "method") - method: RankMethod descending: bool @@ -225,7 +219,6 @@ class EWMOptions(Immutable): "min_samples", "span", ) - com: float | None span: float | None half_life: float | None @@ -237,7 +230,6 @@ class EWMOptions(Immutable): class RollingVarParams(Immutable): __slots__ = ("ddof",) - ddof: int @@ -245,10 +237,8 @@ class RollingOptionsFixedWindow(Immutable): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-core/src/chunked_array/ops/rolling_window.rs#L10-L23.""" __slots__ = ("center", "fn_params", "min_samples", "window_size") - window_size: int min_samples: int """Renamed from `min_periods`, reuses `window_size` if null.""" - center: bool fn_params: RollingVarParams | None diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py index a00d6b0e53..a757d835fd 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/ranges.py @@ -42,7 +42,6 @@ class IntRange(RangeFunction): """ __slots__ = ("step", "dtype") # noqa: RUF023 - step: int dtype: IntegerType diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 8124cb0c9e..2d82bbeccd 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -47,7 +47,6 @@ def 
matches_column(self, name: str, dtype: DType) -> bool: class ByDType(Selector): __slots__ = ("dtypes",) - dtypes: frozenset[DType | type[DType]] @staticmethod @@ -91,7 +90,6 @@ class Datetime(Selector): """ __slots__ = ("time_units", "time_zones") - time_units: frozenset[TimeUnit] time_zones: frozenset[str | None] @@ -120,7 +118,6 @@ def matches_column(self, name: str, dtype: DType) -> bool: class Matches(Selector): __slots__ = ("pattern",) - pattern: re.Pattern[str] @staticmethod diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index bcf2bb4aee..c13cdf4486 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -22,7 +22,6 @@ class ConcatHorizontal(StringFunction): """`nw.functions.concat_str`.""" __slots__ = ("ignore_nulls", "separator") - separator: str ignore_nulls: bool @@ -36,7 +35,6 @@ def __repr__(self) -> str: class Contains(StringFunction): __slots__ = ("literal", "pattern") - pattern: str literal: bool @@ -46,7 +44,6 @@ def __repr__(self) -> str: class EndsWith(StringFunction): __slots__ = ("suffix",) - suffix: str def __repr__(self) -> str: @@ -60,7 +57,6 @@ def __repr__(self) -> str: class Replace(StringFunction): __slots__ = ("literal", "n", "pattern", "value") - pattern: str value: str literal: bool @@ -77,7 +73,6 @@ class ReplaceAll(StringFunction): """ __slots__ = ("literal", "pattern", "value") - pattern: str value: str literal: bool @@ -90,12 +85,9 @@ class Slice(StringFunction): """We're using for `Head`, `Tail` as well. https://github.com/dangotbanned/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L87-L89 - - I don't think it's likely we'll support `Expr` as inputs for this any time soon. """ __slots__ = ("length", "offset") - offset: int length: int | None @@ -105,7 +97,6 @@ def __repr__(self) -> str: class Split(StringFunction): __slots__ = ("by",) - by: str def __repr__(self) -> str: @@ -114,7 +105,6 @@ def __repr__(self) -> str: class StartsWith(StringFunction): __slots__ = ("prefix",) - prefix: str def __repr__(self) -> str: @@ -123,7 +113,6 @@ def __repr__(self) -> str: class StripChars(StringFunction): __slots__ = ("characters",) - characters: str | None def __repr__(self) -> str: @@ -139,7 +128,6 @@ class ToDatetime(StringFunction): """ __slots__ = ("format",) - format: str | None def __repr__(self) -> str: diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index 6ad3770cf4..d9dd42a730 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -16,7 +16,6 @@ class FieldByName(StructFunction): """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/struct_.rs#L11.""" __slots__ = ("name",) - name: str @property diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 6069d95a6a..301bfc95a4 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -110,25 +110,21 @@ class TotalNanoseconds(TemporalFunction): ... 
class ToString(TemporalFunction): __slots__ = ("format",) - format: str class ReplaceTimeZone(TemporalFunction): __slots__ = ("time_zone",) - time_zone: str | None class ConvertTimeZone(TemporalFunction): __slots__ = ("time_zone",) - time_zone: str class Timestamp(TemporalFunction): __slots__ = ("time_unit",) - time_unit: PolarsTimeUnit @staticmethod @@ -146,7 +142,6 @@ def from_time_unit(time_unit: TimeUnit, /) -> Timestamp: class Truncate(TemporalFunction): __slots__ = ("multiple", "unit") - multiple: int unit: IntervalUnit diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index ab9ae66933..ef7ab9f6b3 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -19,7 +19,6 @@ class When(Immutable): __slots__ = ("condition",) - condition: ExprIR def then(self, expr: IntoExpr, /) -> Then: @@ -36,7 +35,6 @@ def _from_ir(ir: ExprIR, /) -> When: class Then(Immutable, DummyExpr): __slots__ = ("condition", "statement") - condition: ExprIR statement: ExprIR @@ -70,7 +68,6 @@ def __eq__(self, value: object) -> DummyExpr | bool: # type: ignore[override] class ChainedWhen(Immutable): __slots__ = ("conditions", "statements") - conditions: Seq[ExprIR] statements: Seq[ExprIR] @@ -85,7 +82,6 @@ class ChainedThen(Immutable, DummyExpr): """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130.""" __slots__ = ("conditions", "statements") - conditions: Seq[ExprIR] statements: Seq[ExprIR] From a40b0448cdadca05805ede0776a554c95c0d45d2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 10 Jul 2025 14:14:43 +0100 Subject: [PATCH 304/368] fix: Align all binary ops with `polars` - Added missing right ops - Fixed incorrect right ops - Only use `str_as_lit` in comparisons and `__add__` - Remove unnecessary `alias` (`lit` handles that already) - Add lots of tests --- narwhals/_plan/dummy.py | 107 ++++++++++++++++++++++---------- tests/plan/expr_parsing_test.py | 71 ++++++++++++++++++++- 2 files changed, 142 insertions(+), 36 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 3f2ffb50e4..e30f02d156 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -538,63 +538,102 @@ def __add__(self, other: IntoExpr) -> Self: rhs = parse.parse_into_expr_ir(other, str_as_lit=True) return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __radd__(self, other: IntoExpr) -> Self: + op = ops.Add() + lhs = parse.parse_into_expr_ir(other, str_as_lit=True) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) + def __sub__(self, other: IntoExpr) -> Self: op = ops.Sub() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + rhs = parse.parse_into_expr_ir(other) return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __rsub__(self, other: IntoExpr) -> Self: + op = ops.Sub() + lhs = parse.parse_into_expr_ir(other) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) + def __mul__(self, other: IntoExpr) -> Self: op = ops.Multiply() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + rhs = parse.parse_into_expr_ir(other) return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __rmul__(self, other: IntoExpr) -> Self: + op = ops.Multiply() + lhs = parse.parse_into_expr_ir(other) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) + def __truediv__(self, other: IntoExpr) -> Self: op = ops.TrueDivide() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + rhs = parse.parse_into_expr_ir(other) 
return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __rtruediv__(self, other: IntoExpr) -> Self: + op = ops.TrueDivide() + lhs = parse.parse_into_expr_ir(other) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) + def __floordiv__(self, other: IntoExpr) -> Self: op = ops.FloorDivide() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + rhs = parse.parse_into_expr_ir(other) return self._from_ir(op.to_binary_expr(self._ir, rhs)) + def __rfloordiv__(self, other: IntoExpr) -> Self: + op = ops.FloorDivide() + lhs = parse.parse_into_expr_ir(other) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) + def __mod__(self, other: IntoExpr) -> Self: op = ops.Modulus() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + rhs = parse.parse_into_expr_ir(other) return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __and__(self, other: IntoExpr) -> Self: + def __rmod__(self, other: IntoExpr) -> Self: + op = ops.Modulus() + lhs = parse.parse_into_expr_ir(other) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) + + def __and__(self, other: IntoExprColumn | int | bool) -> Self: op = ops.And() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + rhs = parse.parse_into_expr_ir(other) return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __rand__(self, other: IntoExpr) -> Self: - return (self & other).alias("literal") + def __rand__(self, other: IntoExprColumn | int | bool) -> Self: + op = ops.And() + lhs = parse.parse_into_expr_ir(other) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) - def __or__(self, other: IntoExpr) -> Self: + def __or__(self, other: IntoExprColumn | int | bool) -> Self: op = ops.Or() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + rhs = parse.parse_into_expr_ir(other) return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __ror__(self, other: IntoExpr) -> Self: - return (self | other).alias("literal") + def __ror__(self, other: IntoExprColumn | int | bool) -> Self: + op = ops.Or() + lhs = parse.parse_into_expr_ir(other) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) - def __xor__(self, other: IntoExpr) -> Self: + def __xor__(self, other: IntoExprColumn | int | bool) -> Self: op = ops.ExclusiveOr() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) + rhs = parse.parse_into_expr_ir(other) return self._from_ir(op.to_binary_expr(self._ir, rhs)) - def __rxor__(self, other: IntoExpr) -> Self: - return (self ^ other).alias("literal") + def __rxor__(self, other: IntoExprColumn | int | bool) -> Self: + op = ops.ExclusiveOr() + lhs = parse.parse_into_expr_ir(other) + return self._from_ir(op.to_binary_expr(lhs, self._ir)) + + def __pow__(self, exponent: IntoExprColumn | float) -> Self: + exp = parse.parse_into_expr_ir(exponent) + return self._from_ir(F.Pow().to_function_expr(self._ir, exp)) + + def __rpow__(self, base: IntoExprColumn | float) -> Self: + base_ = parse.parse_into_expr_ir(base) + return self._from_ir(F.Pow().to_function_expr(base_, self._ir)) def __invert__(self) -> Self: return self._from_ir(boolean.Not().to_function_expr(self._ir)) - def __pow__(self, other: IntoExpr) -> Self: - exponent = parse.parse_into_expr_ir(other, str_as_lit=True) - base = self._ir - return self._from_ir(F.Pow().to_function_expr(base, exponent)) - @property def meta(self) -> IRMetaNamespace: from narwhals._plan.meta import IRMetaNamespace @@ -679,8 +718,8 @@ def _to_expr(self) -> DummyExpr: @t.overload # type: ignore[override] def __or__(self, other: Self) -> Self: ... 
@t.overload - def __or__(self, other: IntoExpr) -> DummyExpr: ... - def __or__(self, other: IntoExpr) -> Self | DummyExpr: + def __or__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... + def __or__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: if isinstance(other, type(self)): op = ops.Or() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) @@ -689,8 +728,8 @@ def __or__(self, other: IntoExpr) -> Self | DummyExpr: @t.overload # type: ignore[override] def __and__(self, other: Self) -> Self: ... @t.overload - def __and__(self, other: IntoExpr) -> DummyExpr: ... - def __and__(self, other: IntoExpr) -> Self | DummyExpr: + def __and__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... + def __and__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: if is_column(other) and (name := other.meta.output_name()): other = by_name(name) if isinstance(other, type(self)): @@ -711,8 +750,8 @@ def __sub__(self, other: IntoExpr) -> Self | DummyExpr: @t.overload # type: ignore[override] def __xor__(self, other: Self) -> Self: ... @t.overload - def __xor__(self, other: IntoExpr) -> DummyExpr: ... - def __xor__(self, other: IntoExpr) -> Self | DummyExpr: + def __xor__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... + def __xor__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: if isinstance(other, type(self)): op = ops.ExclusiveOr() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) @@ -738,8 +777,8 @@ def __rsub__(self, other: t.Any) -> Never: @t.overload # type: ignore[override] def __rand__(self, other: Self) -> Self: ... @t.overload - def __rand__(self, other: IntoExpr) -> DummyExpr: ... - def __rand__(self, other: IntoExpr) -> Self | DummyExpr: + def __rand__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... + def __rand__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) & self return self._to_expr().__rand__(other) @@ -747,8 +786,8 @@ def __rand__(self, other: IntoExpr) -> Self | DummyExpr: @t.overload # type: ignore[override] def __ror__(self, other: Self) -> Self: ... @t.overload - def __ror__(self, other: IntoExpr) -> DummyExpr: ... - def __ror__(self, other: IntoExpr) -> Self | DummyExpr: + def __ror__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... + def __ror__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) | self return self._to_expr().__ror__(other) @@ -756,8 +795,8 @@ def __ror__(self, other: IntoExpr) -> Self | DummyExpr: @t.overload # type: ignore[override] def __rxor__(self, other: Self) -> Self: ... @t.overload - def __rxor__(self, other: IntoExpr) -> DummyExpr: ... - def __rxor__(self, other: IntoExpr) -> Self | DummyExpr: + def __rxor__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... 
+ def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) ^ self return self._to_expr().__rxor__(other) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 955e09575b..2043959350 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -1,8 +1,9 @@ from __future__ import annotations +import operator import re from collections import deque -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Callable @@ -14,6 +15,7 @@ boolean, expr, functions as F, # noqa: N812 + operators as ops, ) from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr, DummySeries @@ -34,7 +36,7 @@ from typing_extensions import TypeAlias - from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq + from narwhals._plan.typing import IntoExpr, IntoExprColumn, OperatorFn, Seq IntoIterable: TypeAlias = Callable[[Sequence[Any]], Iterable[Any]] @@ -386,3 +388,68 @@ def test_lit_series_roundtrip() -> None: assert isinstance(unwrapped, DummySeries) assert isinstance(unwrapped.to_native(), pa.ChunkedArray) assert unwrapped.to_list() == data + + +@pytest.mark.parametrize( + ("arg_1", "arg_2", "function", "op"), + [ + (nwd.col("a"), 1, operator.eq, ops.Eq), + (nwd.col("a"), "b", operator.eq, ops.Eq), + (nwd.col("a"), 1, operator.ne, ops.NotEq), + (nwd.col("a"), "b", operator.ne, ops.NotEq), + (nwd.col("a"), "b", operator.ge, ops.GtEq), + (nwd.col("a"), "b", operator.gt, ops.Gt), + (nwd.col("a"), "b", operator.le, ops.LtEq), + (nwd.col("a"), "b", operator.lt, ops.Lt), + ((nwd.col("a") != 1), False, operator.and_, ops.And), + ((nwd.col("a") != 1), False, operator.or_, ops.Or), + ((nwd.col("a")), True, operator.xor, ops.ExclusiveOr), + (nwd.col("a"), 6, operator.add, ops.Add), + (nwd.col("a"), 2.1, operator.mul, ops.Multiply), + (nwd.col("a"), nwd.col("b"), operator.sub, ops.Sub), + (nwd.col("a"), 2, operator.pow, F.Pow), + (nwd.col("a"), 2, operator.mod, ops.Modulus), + (nwd.col("a"), 2, operator.floordiv, ops.FloorDivide), + (nwd.col("a"), 4, operator.truediv, ops.TrueDivide), + ], +) +def test_operators_left_right( + arg_1: IntoExpr, + arg_2: IntoExpr, + function: OperatorFn, + op: type[ops.Operator | Function], +) -> None: + inverse: Mapping[type[ops.Operator], type[ops.Operator]] = { + ops.Gt: ops.Lt, + ops.Lt: ops.Gt, + ops.GtEq: ops.LtEq, + ops.LtEq: ops.GtEq, + } + result_1 = function(arg_1, arg_2) + result_2 = function(arg_2, arg_1) + assert isinstance(result_1, DummyExpr) + assert isinstance(result_2, DummyExpr) + ir_1 = result_1._ir + ir_2 = result_2._ir + if op in {ops.Eq, ops.NotEq}: + assert ir_1 == ir_2 + else: + assert ir_1 != ir_2 + if issubclass(op, ops.Operator): + assert isinstance(ir_1, BinaryExpr) + assert isinstance(ir_1.op, op) + assert isinstance(ir_2, BinaryExpr) + op_inverse = inverse.get(op, op) + assert isinstance(ir_2.op, op_inverse) + if op in {ops.Eq, ops.NotEq, *inverse}: + assert ir_1.left == ir_2.left + assert ir_1.right == ir_2.right + else: + assert ir_1.left == ir_2.right + assert ir_1.right == ir_2.left + else: + assert isinstance(ir_1, FunctionExpr) + assert isinstance(ir_1.function, op) + assert isinstance(ir_2, FunctionExpr) + assert isinstance(ir_2.function, op) + assert tuple(reversed(ir_2.input)) == ir_1.input From 1880cefd7cb3acb975939514b5fd79d31fe29928 Mon Sep 
17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 10 Jul 2025 14:36:03 +0100 Subject: [PATCH 305/368] fix: Fill in some `DType` holes --- narwhals/_plan/dummy.py | 13 ++++++++----- narwhals/_plan/functions.py | 7 ++++--- tests/plan/expr_expansion_test.py | 8 ++++---- tests/plan/expr_parsing_test.py | 7 +++++-- tests/plan/expr_rewrites_test.py | 4 ++-- tests/plan/to_compliant_test.py | 4 ++-- 6 files changed, 25 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index e30f02d156..541c463c60 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -15,7 +15,7 @@ functions as F, # noqa: N812 operators as ops, ) -from narwhals._plan.common import NamedIR, is_column, is_expr, is_series +from narwhals._plan.common import NamedIR, into_dtype, is_column, is_expr, is_series from narwhals._plan.contexts import ExprContext from narwhals._plan.options import ( EWMOptions, @@ -30,7 +30,6 @@ from narwhals._plan.window import Over from narwhals._utils import Version from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table -from narwhals.dtypes import DType from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.schema import Schema @@ -51,6 +50,7 @@ from narwhals._plan.struct import ExprStructNamespace from narwhals._plan.temporal import ExprDateTimeNamespace from narwhals._plan.typing import ExprT, IntoExpr, IntoExprColumn, Ns, Seq, Udf + from narwhals.dtypes import DType from narwhals.typing import ( ClosedInterval, FillNullStrategy, @@ -114,9 +114,8 @@ def version(self) -> Version: def alias(self, name: str) -> Self: return self._from_ir(expr.Alias(expr=self._ir, name=name)) - def cast(self, dtype: DType | type[DType]) -> Self: - dtype = dtype if isinstance(dtype, DType) else self.version.dtypes.Unknown() - return self._from_ir(self._ir.cast(dtype)) + def cast(self, dtype: IntoDType) -> Self: + return self._from_ir(self._ir.cast(into_dtype(dtype))) def exclude(self, *names: str | t.Iterable[str]) -> Self: return self._from_ir(expr.Exclude.from_names(self._ir, *names)) @@ -429,6 +428,8 @@ def replace_strict( else: before = tuple(old) after = tuple(new) + if return_dtype is not None: + return_dtype = into_dtype(return_dtype) function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) return self._from_ir(function.to_function_expr(self._ir)) @@ -443,6 +444,8 @@ def map_batches( is_elementwise: bool = False, returns_scalar: bool = False, ) -> Self: + if return_dtype is not None: + return_dtype = into_dtype(return_dtype) return self._from_ir( F.MapBatches( function=function, diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 48a300ec7f..e3fed4953a 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -17,7 +17,8 @@ from narwhals._plan.expr import AnonymousExpr, RollingExpr from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow from narwhals._plan.typing import Seq, Udf - from narwhals.typing import FillNullStrategy, IntoDType + from narwhals.dtypes import DType + from narwhals.typing import FillNullStrategy class Abs(Function): @@ -386,7 +387,7 @@ class ReplaceStrict(Function): __slots__ = ("new", "old", "return_dtype") old: Seq[Any] new: Seq[Any] - return_dtype: IntoDType | None + return_dtype: DType | None @property def function_options(self) -> FunctionOptions: @@ -412,7 +413,7 @@ def __repr__(self) -> str: class MapBatches(Function): __slots__ = ("function", 
"is_elementwise", "return_dtype", "returns_scalar") function: Udf - return_dtype: IntoDType | None + return_dtype: DType | None is_elementwise: bool returns_scalar: bool diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 91ebe99752..93ef769dd7 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -298,7 +298,7 @@ def test_replace_selector( ), ( (ndcs.numeric() - ndcs.by_dtype(nw.Float32(), nw.Float64())) - .cast(nw.Int64()) + .cast(nw.Int64) .mean() .name.suffix("_mean"), [ @@ -337,18 +337,18 @@ def test_replace_selector( ), ( nwd.col("f", "g") - .cast(nw.String()) + .cast(nw.String) .str.starts_with("1") .all() .name.suffix("_all_starts_with_1"), [ nwd.col("f") - .cast(nw.String()) + .cast(nw.String) .str.starts_with("1") .all() .alias("f_all_starts_with_1"), nwd.col("g") - .cast(nw.String()) + .cast(nw.String) .str.starts_with("1") .all() .alias("g_all_starts_with_1"), diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 2043959350..e4182d40ae 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -247,9 +247,12 @@ def test_binary_expr_length_changing_agg() -> None: assert _is_expr_ir_binary_expr(a.drop_nulls().min() - b.mode()) assert _is_expr_ir_binary_expr(a.gather_every(2, 1) / b.drop_nulls().max()) assert _is_expr_ir_binary_expr( - b.gather_every(1, 0) / a.map_batches(lambda x: x, returns_scalar=True) + b.gather_every(1, 0) + / a.map_batches(lambda x: x, returns_scalar=True, return_dtype=nw.Float64) + ) + assert _is_expr_ir_binary_expr( + b.unique() * a.map_batches(lambda x: x, return_dtype=nw.Unknown).first() ) - assert _is_expr_ir_binary_expr(b.unique() * a.map_batches(lambda x: x).first()) def test_invalid_binary_expr_shape() -> None: diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 6d1d9576b0..9f67b800e2 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -90,7 +90,7 @@ def named_ir(name: str, expr: DummyExpr | ExprIR, /) -> NamedIR[ExprIR]: def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: expected = ( named_ir("a", nwd.col("a")), - named_ir("b", nwd.col("b").cast(nw.String())), + named_ir("b", nwd.col("b").cast(nw.String)), named_ir("x2", nwd.col("c").max().over("a").fill_null(50)), named_ir("d**", ~nwd.col("d").is_duplicated().over("b")), named_ir("f_some", nwd.col("f").str.contains("some")), @@ -102,7 +102,7 @@ def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: ) before = ( nwd.col("a"), - nwd.col("b").cast(nw.String()), + nwd.col("b").cast(nw.String), ( _to_window_expr(nwd.col("c").max().alias("x").fill_null(50), "a") .to_narwhals() diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index de0702766e..5026c52ab2 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -60,7 +60,7 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: (nwd.col("a", "b"), {"a": ["A", "B", "A"], "b": [1, 2, 3]}), (nwd.lit(1), {"literal": [1]}), (nwd.lit(2.0), {"literal": [2.0]}), - (nwd.lit(None, nw.String()), {"literal": [None]}), + (nwd.lit(None, nw.String), {"literal": [None]}), (nwd.col("a", "b").first(), {"a": ["A"], "b": [1]}), (nwd.col("d").max(), {"d": [8]}), ([nwd.len(), nwd.nth(3).last()], {"len": [3], "d": [8]}), @@ -75,7 +75,7 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: {"a": ["A", "A"], "b": [1, 3], "d": [4, 4]}, ), 
(nwd.col("b").cast(nw.Float64()), {"b": [1.0, 2.0, 3.0]}), - (nwd.lit(1).cast(nw.Float64()).alias("literal_cast"), {"literal_cast": [1.0]}), + (nwd.lit(1).cast(nw.Float64).alias("literal_cast"), {"literal_cast": [1.0]}), pytest.param( nwd.lit(1).cast(nw.Float64()).name.suffix("_cast"), {"literal_cast": [1.0]}, From 4b5b9abfdd68659581e7bf23ad59ab5dd538512c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 10 Jul 2025 16:34:55 +0100 Subject: [PATCH 306/368] test: use `assert_equal_data` --- tests/plan/to_compliant_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 5026c52ab2..4101e80e76 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -10,6 +10,7 @@ from narwhals.exceptions import ComputeError from narwhals.utils import Version from tests.namespace_test import backends +from tests.utils import assert_equal_data if TYPE_CHECKING: from collections.abc import Sequence @@ -105,7 +106,7 @@ def test_select( frame = pa.table(data_small) df = DummyFrame.from_native(frame) result = df.select(expr).to_dict(as_series=False) - assert result == expected + assert_equal_data(result, expected) if TYPE_CHECKING: From 2c0c0676ede0c47764b67f137f05128f91e9c49d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 10 Jul 2025 17:20:32 +0100 Subject: [PATCH 307/368] refactor: tighten up protocols --- narwhals/_plan/arrow/expr.py | 62 +++++++++++++++++------------------- narwhals/_plan/protocols.py | 54 +++++++++++-------------------- 2 files changed, 48 insertions(+), 68 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index f8a759097f..9af8979c73 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any, Protocol, cast, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -16,6 +16,7 @@ from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, into_dtype from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch +from narwhals._typing_compat import TypeVar from narwhals._utils import Implementation, Version, _StoresNative, not_implemented from narwhals.exceptions import InvalidOperationError, ShapeError @@ -56,6 +57,8 @@ BinOp: TypeAlias = Callable[..., "ChunkedArrayAny | NativeScalar"] BACKEND_VERSION = Implementation.PYARROW._backend_version() +_StoresNativeAny: TypeAlias = _StoresNative[Any] +_StoresNativeT_co = TypeVar("_StoresNativeT_co", bound=_StoresNativeAny, covariant=True) def truediv_compat(lhs: Any, rhs: Any) -> Any: @@ -86,18 +89,31 @@ def modulus(lhs: Any, rhs: Any) -> Any: } -class ArrowExpr( - ExprDispatch["ArrowDataFrame", "ArrowExpr | ArrowScalar", "ArrowNamespace"], - _StoresNative["ChunkedArrayAny"], - EagerExpr["ArrowDataFrame", ArrowSeries], +class _ArrowDispatch( + ExprDispatch["ArrowDataFrame", _StoresNativeT_co, "ArrowNamespace"], Protocol ): - _evaluated: ArrowSeries - _version: Version + """Common to `Expr`, `Scalar` + their dependencies.""" def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace import ArrowNamespace - return ArrowNamespace(self._version) + return 
ArrowNamespace(self.version) + + def _with_native(self, native: Any, name: str, /) -> _StoresNativeT_co: ... + def cast( + self, node: expr.Cast, frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + data_type = narwhals_to_native_dtype(node.dtype, frame.version) + native = self._dispatch(node.expr, frame, name).native + return self._with_native(pc.cast(native, data_type), name) + + +class ArrowExpr( # type: ignore[misc] + _ArrowDispatch["ArrowExpr | ArrowScalar"], + _StoresNative["ChunkedArrayAny"], + EagerExpr["ArrowDataFrame", ArrowSeries], +): + _evaluated: ArrowSeries @property def name(self) -> str: @@ -129,7 +145,7 @@ def _with_native( ) -> ArrowScalar | Self: if isinstance(result, pa.Scalar): return ArrowScalar.from_native(result, name, version=self.version) - return super()._with_native(result, name) + return self.from_native(result, name or self.name, self.version) def _dispatch_expr( self, node: ExprIR, frame: ArrowDataFrame, name: str @@ -159,13 +175,6 @@ def broadcast(self, length: int, /) -> ArrowSeries: def __len__(self) -> int: return len(self._evaluated) - def cast( # type: ignore[override] - self, node: expr.Cast, frame: ArrowDataFrame, name: str - ) -> ArrowScalar | Self: - data_type = narwhals_to_native_dtype(node.dtype, frame.version) - native = self._dispatch(node.expr, frame, name).native - return self._with_native(pc.cast(native, data_type), name) - def sort(self, node: expr.Sort, frame: ArrowDataFrame, name: str) -> ArrowExpr: native = self._dispatch_expr(node.expr, frame, name).native sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) @@ -304,22 +313,11 @@ def chunked_array( class ArrowScalar( - ExprDispatch["ArrowDataFrame", "ArrowScalar", "ArrowNamespace"], + _ArrowDispatch["ArrowScalar"], _StoresNative[NativeScalar], EagerScalar["ArrowDataFrame", ArrowSeries], ): - _name: str _evaluated: NativeScalar - _version: Version - - def __narwhals_namespace__(self) -> ArrowNamespace: - from narwhals._plan.arrow.namespace import ArrowNamespace - - return ArrowNamespace(self._version) - - @property - def name(self) -> str: - return self._name @classmethod def from_native( @@ -370,6 +368,9 @@ def _dispatch_expr( msg = f"Expected unreachable, but hit at: {node!r}" raise InvalidOperationError(msg) + def _with_native(self, native: Any, name: str, /) -> Self: + return self.from_native(native, name or self.name, self.version) + @property def native(self) -> NativeScalar: return self._evaluated @@ -388,11 +389,6 @@ def broadcast(self, length: int) -> ArrowSeries: chunked = chunked_array(pa_repeat(scalar, length)) return ArrowSeries.from_native(chunked, self.name, version=self.version) - def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> ArrowScalar: - data_type = narwhals_to_native_dtype(node.dtype, frame.version) - native = self._dispatch(node.expr, frame, name).native - return self._with_native(pc.cast(native, data_type), name) - def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: return self._with_native(pa.scalar(0), name) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 1ef4000ee8..d492a63ccb 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -46,6 +46,16 @@ ) +# NOTE: Unlike the version in `nw._utils`, here `.version` it is public +class StoresVersion(Protocol): + _version: Version + + @property + def version(self) -> Version: + """Narwhals API version (V1 or MAIN).""" + return self._version + + class SupportsBroadcast(Protocol[SeriesT, 
LengthT]): """Minimal broadcasting for `Expr` results.""" @@ -108,9 +118,10 @@ def _length_required( return max_length if required else None -class CompliantExpr(Protocol[FrameT_contra, SeriesT_co]): - """Getting a bit tricky, just storing notes. +class CompliantExpr(StoresVersion, Protocol[FrameT_contra, SeriesT_co]): + """Everything common to `Expr`/`Series` and `Scalar` literal values. + Early notes: - Separating series/scalar makes a lot of sense - Handling the recursive case *without* intermediate (non-pyarrow) objects seems unachievable - Everywhere would need to first check if it a scalar, which isn't ergonomic @@ -123,12 +134,6 @@ class CompliantExpr(Protocol[FrameT_contra, SeriesT_co]): _evaluated: Any """Compliant or native value.""" - _version: Version - - @property - def version(self) -> Version: - return self._version - @property def name(self) -> str: ... @@ -210,7 +215,7 @@ def min( ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... -class ExprDispatch(Protocol[FrameT_contra, R_co, NamespaceT_co]): +class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = { expr.Column: lambda self, node, frame, name: self.__narwhals_namespace__().col( node, frame, name @@ -259,7 +264,6 @@ class ExprDispatch(Protocol[FrameT_contra, R_co, NamespaceT_co]): node, frame, name ), } - _version: Version def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: if (method := self._DISPATCH.get(node.__class__)) and ( @@ -287,16 +291,11 @@ class CompliantScalar( CompliantExpr[FrameT_contra, SeriesT_co], Protocol[FrameT_contra, SeriesT_co] ): _name: str - _version: Version @property def name(self) -> str: return self._name - @property - def version(self) -> Version: - return self._version - @classmethod def from_python( cls, @@ -414,7 +413,9 @@ class LazyScalar( ): ... -class CompliantNamespace(Protocol[FrameT, SeriesT_co, ExprT_co, ScalarT_co]): +class CompliantNamespace( + StoresVersion, Protocol[FrameT, SeriesT_co, ExprT_co, ScalarT_co] +): """Need to hold `Expr` and `Scalar` types outside of their defs. Likely, re-wrapping the output types will work like: @@ -429,8 +430,6 @@ class CompliantNamespace(Protocol[FrameT, SeriesT_co, ExprT_co, ScalarT_co]): assert_never(out) """ - _version: Version - @property def _dataframe(self) -> type[FrameT]: ... @property @@ -439,11 +438,6 @@ def _series(self) -> type[SeriesT_co]: ... def _expr(self) -> type[ExprT_co]: ... @property def _scalar(self) -> type[ScalarT_co]: ... - - @property - def version(self) -> Version: - return self._version - def col(self, node: expr.Column, frame: FrameT, name: str) -> ExprT_co: ... def lit( self, node: expr.Literal[Any], frame: FrameT, name: str @@ -480,9 +474,8 @@ def len(self, node: expr.Len, frame: FrameT, name: str) -> EagerScalarT_co: ) -class DummyCompliantFrame(Protocol[SeriesT, NativeFrameT, NativeSeriesT]): +class DummyCompliantFrame(StoresVersion, Protocol[SeriesT, NativeFrameT, NativeSeriesT]): _native: NativeFrameT - _version: Version def __narwhals_namespace__(self) -> Any: ... @@ -490,10 +483,6 @@ def __narwhals_namespace__(self) -> Any: ... def native(self) -> NativeFrameT: return self._native - @property - def version(self) -> Version: - return self._version - @property def columns(self) -> list[str]: ... @@ -553,10 +542,9 @@ def sort( ) -> Self: ... 
-class DummyCompliantSeries(Protocol[NativeSeriesT]): +class DummyCompliantSeries(StoresVersion, Protocol[NativeSeriesT]): _native: NativeSeriesT _name: str - _version: Version def __narwhals_series__(self) -> Self: return self @@ -565,10 +553,6 @@ def __narwhals_series__(self) -> Self: def native(self) -> NativeSeriesT: return self._native - @property - def version(self) -> Version: - return self._version - @property def dtype(self) -> DType: ... From cb2e247be974301d518b8a7f4822730e96155754 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 11 Jul 2025 14:15:34 +0100 Subject: [PATCH 308/368] refactor: rename `Agg` -> `AggExpr` --- narwhals/_plan/aggregation.py | 40 +++++++++++++++++------------------ narwhals/_plan/common.py | 11 +++++----- narwhals/_plan/demo.py | 4 ++-- narwhals/_plan/exceptions.py | 4 ++-- narwhals/_plan/expr.py | 6 +++--- 5 files changed, 32 insertions(+), 33 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 33d5ed8849..17bdaa0c8d 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -14,7 +14,7 @@ from narwhals.typing import RollingInterpolationMethod -class Agg(ExprIR): +class AggExpr(ExprIR): __slots__ = ("expr",) expr: ExprIR @@ -24,7 +24,7 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: tp = type(self) - if tp in {Agg, OrderableAgg}: + if tp in {AggExpr, OrderableAggExpr}: return tp.__name__ m = {ArgMin: "arg_min", ArgMax: "arg_max", NUnique: "n_unique"} name = m.get(tp, tp.__name__.lower()) @@ -56,56 +56,56 @@ def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue] -class Count(Agg): ... +class Count(AggExpr): ... -class Max(Agg): ... +class Max(AggExpr): ... -class Mean(Agg): ... +class Mean(AggExpr): ... -class Median(Agg): ... +class Median(AggExpr): ... -class Min(Agg): ... +class Min(AggExpr): ... -class NUnique(Agg): ... +class NUnique(AggExpr): ... -class Quantile(Agg): - __slots__ = (*Agg.__slots__, "interpolation", "quantile") +class Quantile(AggExpr): + __slots__ = (*AggExpr.__slots__, "interpolation", "quantile") quantile: float interpolation: RollingInterpolationMethod -class Std(Agg): - __slots__ = (*Agg.__slots__, "ddof") +class Std(AggExpr): + __slots__ = (*AggExpr.__slots__, "ddof") ddof: int -class Sum(Agg): ... +class Sum(AggExpr): ... -class Var(Agg): - __slots__ = (*Agg.__slots__, "ddof") +class Var(AggExpr): + __slots__ = (*AggExpr.__slots__, "ddof") ddof: int -class OrderableAgg(Agg): ... +class OrderableAggExpr(AggExpr): ... -class First(OrderableAgg): +class First(OrderableAggExpr): """https://github.com/narwhals-dev/narwhals/issues/2526.""" -class Last(OrderableAgg): +class Last(OrderableAggExpr): """https://github.com/narwhals-dev/narwhals/issues/2526.""" -class ArgMin(OrderableAgg): ... +class ArgMin(OrderableAggExpr): ... -class ArgMax(OrderableAgg): ... +class ArgMax(OrderableAggExpr): ... 
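A minimal sketch of how the renamed aggregation nodes compose, assuming the keyword-only `expr=` constructor shown above and a `Column(name=...)` leaf whose exact signature is not visible in this hunk:

    # Sketch only: `Column(name=...)` is an assumption; the real constructor may differ.
    from narwhals._plan.aggregation import AggExpr, Min
    from narwhals._plan.expr import Column

    node = Min(expr=Column(name="a"))  # an AggExpr wraps the expression it reduces
    assert isinstance(node, AggExpr)   # matches the new name (formerly `Agg`)
    assert node.is_scalar              # every aggregation resolves to a scalar
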
diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 9ca7e03b07..557242f5d9 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -31,7 +31,7 @@ from narwhals._plan import expr from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries from narwhals._plan.expr import ( - Agg, + AggExpr, BinaryExpr, Cast, Column, @@ -487,14 +487,13 @@ def is_binary_expr(obj: Any) -> TypeIs[BinaryExpr]: return isinstance(obj, BinaryExpr) -# TODO @dangotbanned: Rename `Agg` -> `AggExpr` -def is_agg_expr(obj: Any) -> TypeIs[Agg]: - from narwhals._plan.expr import Agg +def is_agg_expr(obj: Any) -> TypeIs[AggExpr]: + from narwhals._plan.expr import AggExpr - return isinstance(obj, Agg) + return isinstance(obj, AggExpr) -def is_aggregation(obj: Any) -> TypeIs[Agg | FunctionExpr[Any]]: +def is_aggregation(obj: Any) -> TypeIs[AggExpr | FunctionExpr[Any]]: """Superset of `ExprIR.is_scalar`, excludes literals & len.""" return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index eaec1a096a..e0f3f106da 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -204,7 +204,7 @@ def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]: return isinstance(obj, allowed) -def _order_dependent_error(node: agg.OrderableAgg) -> OrderDependentExprError: +def _order_dependent_error(node: agg.OrderableAggExpr) -> OrderDependentExprError: previous = node.expr method = repr(node).removeprefix(f"{previous!r}.") msg = ( @@ -222,7 +222,7 @@ def _order_dependent_error(node: agg.OrderableAgg) -> OrderDependentExprError: def ensure_orderable_rules(*exprs: DummyExpr) -> tuple[DummyExpr, ...]: for expr in exprs: node = expr._ir - if isinstance(node, agg.OrderableAgg): + if isinstance(node, agg.OrderableAggExpr): previous = node.expr if not _is_order_enforcing_previous(previous): raise _order_dependent_error(node) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 0028e3f7a5..127471720d 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -24,7 +24,7 @@ import pandas as pd import polars as pl - from narwhals._plan.aggregation import Agg + from narwhals._plan.aggregation import AggExpr from narwhals._plan.common import ExprIR, Function from narwhals._plan.expr import FunctionExpr, WindowExpr from narwhals._plan.operators import Operator @@ -37,7 +37,7 @@ # TODO @dangotbanned: Use arguments in error message -def agg_scalar_error(agg: Agg, scalar: ExprIR, /) -> InvalidOperationError: # noqa: ARG001 +def agg_scalar_error(agg: AggExpr, scalar: ExprIR, /) -> InvalidOperationError: # noqa: ARG001 msg = "Can't apply aggregations to scalar-like expressions." 
return InvalidOperationError(msg) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index d3595e5a85..97657cea8e 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -6,7 +6,7 @@ # - Literal import typing as t -from narwhals._plan.aggregation import Agg, OrderableAgg +from narwhals._plan.aggregation import AggExpr, OrderableAggExpr from narwhals._plan.common import ( ExprIR, SelectorIR, @@ -49,7 +49,7 @@ from narwhals.dtypes import DType __all__ = [ - "Agg", + "AggExpr", "Alias", "All", "AnonymousExpr", @@ -66,7 +66,7 @@ "Len", "Literal", "Nth", - "OrderableAgg", + "OrderableAggExpr", "RenameAlias", "RollingExpr", "RootSelector", From 4947ec4cf8d9c564b98e902c6bd04bfc7959c06c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 11 Jul 2025 19:58:41 +0100 Subject: [PATCH 309/368] fix: Check for scalars in `int_range` `polars` applies rules like this when evaluating, but this is realistically a very cheap check to simplify things later --- narwhals/_plan/expr.py | 3 +++ tests/plan/expr_parsing_test.py | 18 ++++++++++++++++++ tests/plan/meta_test.py | 4 ++-- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 97657cea8e..cba6ed65b4 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -523,6 +523,9 @@ def __init__( if len(input) < 2: msg = f"Expected at least 2 inputs for `{function!r}()`, but got `{len(input)}`.\n`{input}`" raise InvalidOperationError(msg) + if not all(e.is_scalar for e in input): + msg = f"All inputs for `{function!r}()` must be scalar or aggregations, but got \n`{input}`" + raise InvalidOperationError(msg) super(ExprIR, self).__init__( **dict(input=input, function=function, options=options, **kwds) ) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index e4182d40ae..79b1ff0be2 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -165,6 +165,24 @@ def test_invalid_agg_non_elementwise() -> None: def test_agg_non_elementwise_range_special() -> None: e = nwd.int_range(0, 100) assert isinstance(e._ir, RangeExpr) + e = nwd.int_range(nwd.len(), dtype=nw.UInt32).alias("index") + ir = e._ir + assert isinstance(ir, expr.Alias) + assert isinstance(ir.expr, RangeExpr) + assert isinstance(ir.expr.input[0], expr.Literal) + assert isinstance(ir.expr.input[1], expr.Len) + + +def test_invalid_int_range() -> None: + pattern = re.compile(r"scalar.+agg", re.IGNORECASE) + with pytest.raises(InvalidOperationError, match=pattern): + nwd.int_range(nwd.col("a")) + with pytest.raises(InvalidOperationError, match=pattern): + nwd.int_range(nwd.nth(1), 10) + with pytest.raises(InvalidOperationError, match=pattern): + nwd.int_range(0, nwd.col("a").abs()) + with pytest.raises(InvalidOperationError, match=pattern): + nwd.int_range(nwd.col("a") + 1) # NOTE: Non-`polars`` rule diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 67f039a656..6d1f20fc68 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -174,8 +174,8 @@ def test_meta_root_names( nwd.int_range(0, 10), pl.int_range(0, 10), "literal", id="IntRange-Literal" ), pytest.param( - nwd.int_range(nwd.col("b"), 10), - pl.int_range(pl.col("b"), 10), + nwd.int_range(nwd.col("b").first(), 10), + pl.int_range(pl.col("b").first(), 10), "b", id="IntRange-Column", ), From 5c33fce10bb6bae79138af3b6249332af4791304 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: 
Fri, 11 Jul 2025 20:00:58 +0100 Subject: [PATCH 310/368] feat: Identify n-ary functions Besides these and the horizontal (variadic) functions, everything else is unary --- narwhals/_plan/boolean.py | 13 ++++++++----- narwhals/_plan/functions.py | 14 +++++++++++++- narwhals/_plan/ranges.py | 8 +++++++- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index cbad3e0025..5e24c7c329 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -9,9 +9,11 @@ from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: + from typing_extensions import Self + from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummySeries - from narwhals._plan.expr import Literal # noqa: F401 + from narwhals._plan.expr import FunctionExpr, Literal # noqa: F401 from narwhals._plan.typing import NativeSeriesT, Seq # noqa: F401 from narwhals.typing import ClosedInterval @@ -73,10 +75,7 @@ def function_options(self) -> FunctionOptions: class IsBetween(BooleanFunction): - """`lower_bound`, `upper_bound` aren't spec'd in the function enum. - - https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/boolean.rs#L225-L237 - """ + """N-ary (expr, lower_bound, upper_bound).""" __slots__ = ("closed",) closed: ClosedInterval @@ -85,6 +84,10 @@ class IsBetween(BooleanFunction): def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, ExprIR]: + expr, lower_bound, upper_bound = node.input + return expr, lower_bound, upper_bound + class IsDuplicated(BooleanFunction): @property diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index e3fed4953a..5021aaa2eb 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -14,7 +14,7 @@ from typing_extensions import Self from narwhals._plan.common import ExprIR - from narwhals._plan.expr import AnonymousExpr, RollingExpr + from narwhals._plan.expr import AnonymousExpr, FunctionExpr, RollingExpr from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow from narwhals._plan.typing import Seq, Udf from narwhals.dtypes import DType @@ -97,6 +97,8 @@ def __repr__(self) -> str: class Pow(Function): + """N-ary (base, exponent).""" + @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() @@ -104,6 +106,10 @@ def function_options(self) -> FunctionOptions: def __repr__(self) -> str: return "pow" + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + base, exponent = node.input + return base, exponent + class Sqrt(Function): @property @@ -128,6 +134,8 @@ def __repr__(self) -> str: class FillNull(Function): + """N-ary (expr, value).""" + @property def function_options(self) -> FunctionOptions: return FunctionOptions.elementwise() @@ -135,6 +143,10 @@ def function_options(self) -> FunctionOptions: def __repr__(self) -> str: return "fill_null" + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, value = node.input + return expr, value + class FillNullWithStrategy(Function): """We don't support this variant in a lot of backends, so worth keeping it split out. 
diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py index a757d835fd..cfa54df748 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/ranges.py @@ -27,7 +27,9 @@ def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: class IntRange(RangeFunction): - """Not implemented yet, but might push forward [#2722]. + """N-ary (start, end). + + Not implemented yet, but might push forward [#2722]. See [`rust` entrypoint], which is roughly: @@ -48,3 +50,7 @@ class IntRange(RangeFunction): @property def function_options(self) -> FunctionOptions: return FunctionOptions.row_separable() + + def unwrap_input(self, node: RangeExpr[Self], /) -> tuple[ExprIR, ExprIR]: + start, end = node.input + return start, end From 3c5663a73dbbb367e74074c2e46554c6a4276eec Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 11 Jul 2025 20:03:42 +0100 Subject: [PATCH 311/368] feat(DRAFT): Stub out namespace functions --- narwhals/_plan/arrow/namespace.py | 48 ++++++++++++++++++++++++++++++- narwhals/_plan/protocols.py | 32 ++++++++++++++++++++- 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index eef9f987da..c7572ed37b 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -7,12 +7,18 @@ from narwhals._utils import Version if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._arrow.typing import ChunkedArrayAny - from narwhals._plan import expr + from narwhals._plan import expr, functions as F # noqa: N812 from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.series import ArrowSeries + from narwhals._plan.boolean import AllHorizontal, AnyHorizontal from narwhals._plan.dummy import DummySeries + from narwhals._plan.expr import FunctionExpr, RangeExpr + from narwhals._plan.ranges import IntRange + from narwhals._plan.strings import ConcatHorizontal from narwhals.typing import NonNestedLiteral @@ -86,3 +92,43 @@ def lit( return self._expr.from_native( nw_ser.to_native(), name or node.name, nw_ser.version ) + + def any_horizontal( + self, node: FunctionExpr[AnyHorizontal], frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError + + def all_horizontal( + self, node: FunctionExpr[AllHorizontal], frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError + + def sum_horizontal( + self, node: FunctionExpr[F.SumHorizontal], frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError + + def min_horizontal( + self, node: FunctionExpr[F.MinHorizontal], frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError + + def max_horizontal( + self, node: FunctionExpr[F.MaxHorizontal], frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError + + def mean_horizontal( + self, node: FunctionExpr[F.MeanHorizontal], frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError + + def concat_str( + self, node: FunctionExpr[ConcatHorizontal], frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError + + def int_range( + self, node: RangeExpr[IntRange], frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index d492a63ccb..aa11149761 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -3,7 +3,13 @@ from collections.abc import 
Callable, Iterable, Iterator, Mapping, Sequence, Sized from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload -from narwhals._plan import aggregation as agg, expr +from narwhals._plan import ( # noqa: N812 + aggregation as agg, + boolean, + expr, + functions as F, + strings, +) from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe from narwhals._plan.typing import NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar @@ -13,7 +19,9 @@ from typing_extensions import Self, TypeAlias from narwhals._plan.dummy import DummyFrame, DummySeries + from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.ranges import IntRange from narwhals._plan.schema import FrozenSchema from narwhals.dtypes import DType from narwhals.schema import Schema @@ -443,6 +451,28 @@ def lit( self, node: expr.Literal[Any], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def len(self, node: expr.Len, frame: FrameT, name: str) -> ScalarT_co: ... + def any_horizontal( + self, node: FunctionExpr[boolean.AnyHorizontal], frame: FrameT, name: str + ) -> Self: ... + def all_horizontal( + self, node: FunctionExpr[boolean.AllHorizontal], frame: FrameT, name: str + ) -> Self: ... + def sum_horizontal( + self, node: FunctionExpr[F.SumHorizontal], frame: FrameT, name: str + ) -> Self: ... + def min_horizontal( + self, node: FunctionExpr[F.MinHorizontal], frame: FrameT, name: str + ) -> Self: ... + def max_horizontal( + self, node: FunctionExpr[F.MaxHorizontal], frame: FrameT, name: str + ) -> Self: ... + def mean_horizontal( + self, node: FunctionExpr[F.MeanHorizontal], frame: FrameT, name: str + ) -> Self: ... + def concat_str( + self, node: FunctionExpr[strings.ConcatHorizontal], frame: FrameT, name: str + ) -> Self: ... + def int_range(self, node: RangeExpr[IntRange], frame: FrameT, name: str) -> Self: ... 
class EagerNamespace( From f48d8685f031060c992d086b51843476874e83ba Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 11 Jul 2025 21:52:45 +0100 Subject: [PATCH 312/368] fix(typing): Oops they return expr --- narwhals/_plan/arrow/namespace.py | 18 ++++++++---------- narwhals/_plan/protocols.py | 18 ++++++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index c7572ed37b..d2208ef238 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -7,8 +7,6 @@ from narwhals._utils import Version if TYPE_CHECKING: - from typing_extensions import Self - from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan import expr, functions as F # noqa: N812 from narwhals._plan.arrow.dataframe import ArrowDataFrame @@ -95,40 +93,40 @@ def lit( def any_horizontal( self, node: FunctionExpr[AnyHorizontal], frame: ArrowDataFrame, name: str - ) -> Self: + ) -> ArrowExpr: raise NotImplementedError def all_horizontal( self, node: FunctionExpr[AllHorizontal], frame: ArrowDataFrame, name: str - ) -> Self: + ) -> ArrowExpr: raise NotImplementedError def sum_horizontal( self, node: FunctionExpr[F.SumHorizontal], frame: ArrowDataFrame, name: str - ) -> Self: + ) -> ArrowExpr: raise NotImplementedError def min_horizontal( self, node: FunctionExpr[F.MinHorizontal], frame: ArrowDataFrame, name: str - ) -> Self: + ) -> ArrowExpr: raise NotImplementedError def max_horizontal( self, node: FunctionExpr[F.MaxHorizontal], frame: ArrowDataFrame, name: str - ) -> Self: + ) -> ArrowExpr: raise NotImplementedError def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: ArrowDataFrame, name: str - ) -> Self: + ) -> ArrowExpr: raise NotImplementedError def concat_str( self, node: FunctionExpr[ConcatHorizontal], frame: ArrowDataFrame, name: str - ) -> Self: + ) -> ArrowExpr: raise NotImplementedError def int_range( self, node: RangeExpr[IntRange], frame: ArrowDataFrame, name: str - ) -> Self: + ) -> ArrowExpr: raise NotImplementedError diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index aa11149761..067b42f13e 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -453,26 +453,28 @@ def lit( def len(self, node: expr.Len, frame: FrameT, name: str) -> ScalarT_co: ... def any_horizontal( self, node: FunctionExpr[boolean.AnyHorizontal], frame: FrameT, name: str - ) -> Self: ... + ) -> ExprT_co: ... def all_horizontal( self, node: FunctionExpr[boolean.AllHorizontal], frame: FrameT, name: str - ) -> Self: ... + ) -> ExprT_co: ... def sum_horizontal( self, node: FunctionExpr[F.SumHorizontal], frame: FrameT, name: str - ) -> Self: ... + ) -> ExprT_co: ... def min_horizontal( self, node: FunctionExpr[F.MinHorizontal], frame: FrameT, name: str - ) -> Self: ... + ) -> ExprT_co: ... def max_horizontal( self, node: FunctionExpr[F.MaxHorizontal], frame: FrameT, name: str - ) -> Self: ... + ) -> ExprT_co: ... def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: FrameT, name: str - ) -> Self: ... + ) -> ExprT_co: ... def concat_str( self, node: FunctionExpr[strings.ConcatHorizontal], frame: FrameT, name: str - ) -> Self: ... - def int_range(self, node: RangeExpr[IntRange], frame: FrameT, name: str) -> Self: ... + ) -> ExprT_co: ... + def int_range( + self, node: RangeExpr[IntRange], frame: FrameT, name: str + ) -> ExprT_co: ... 
class EagerNamespace( From 08e2bd262a4badc69650a5d7a7e353a09a0d09ee Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 11 Jul 2025 22:41:50 +0100 Subject: [PATCH 313/368] feat(pyarrow): Impl `int_range` --- narwhals/_plan/arrow/expr.py | 3 ++ narwhals/_plan/arrow/namespace.py | 42 ++++++++++++++++++++-- narwhals/_plan/protocols.py | 58 ++++++++++++++++++++++++++++--- tests/plan/to_compliant_test.py | 4 +++ 4 files changed, 100 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 9af8979c73..5fff3f2cf0 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -378,6 +378,9 @@ def native(self) -> NativeScalar: def to_series(self) -> ArrowSeries: return self.broadcast(1) + def to_python(self) -> PythonLiteral: + return self.native.as_py() # type: ignore[no-any-return] + def broadcast(self, length: int) -> ArrowSeries: scalar = self.native if length == 1: diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index d2208ef238..778c48c109 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -2,9 +2,13 @@ from typing import TYPE_CHECKING, overload +import pyarrow as pa # ignore-banned-import + +from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from narwhals._arrow.typing import ChunkedArrayAny @@ -17,7 +21,7 @@ from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatHorizontal - from narwhals.typing import NonNestedLiteral + from narwhals.typing import NonNestedLiteral, PythonLiteral class ArrowNamespace( @@ -129,4 +133,38 @@ def concat_str( def int_range( self, node: RangeExpr[IntRange], frame: ArrowDataFrame, name: str ) -> ArrowExpr: - raise NotImplementedError + start_: PythonLiteral + end_: PythonLiteral + start, end = node.function.unwrap_input(node) + step = node.function.step + dtype = node.function.dtype + if is_literal_scalar(start) and is_literal_scalar(end): + start_, end_ = start.unwrap(), end.unwrap() + else: + scalar_start = self._expr.from_ir(start, frame, "start") + scalar_end = self._expr.from_ir(end, frame, "end") + if isinstance(scalar_start, self._scalar) and isinstance( + scalar_end, self._scalar + ): + start_, end_ = scalar_start.to_python(), scalar_end.to_python() + else: + msg = ( + f"All inputs for `int_range()` must be scalar or aggregations, but got \n" + f"{scalar_start.native!r}\n{scalar_end.native!r}" + ) + raise InvalidOperationError(msg) + if isinstance(start_, int) and isinstance(end_, int): + import numpy as np # ignore-banned-import + + from narwhals._plan.arrow.expr import chunked_array + + pa_dtype = narwhals_to_native_dtype(dtype, self.version) + native = chunked_array(pa.array(np.arange(start_, end_, step), pa_dtype)) + return self._expr.from_native(native, name, self.version) + + else: + msg = ( + f"All inputs for `int_range()` resolve to int, but got \n" + f"{start_!r}\n{start_!r}" + ) + raise InvalidOperationError(msg) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 067b42f13e..2ba3f04e2c 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -10,7 +10,7 @@ functions as F, strings, ) -from 
narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe +from narwhals._plan.common import ExprIR, Function, NamedIR, flatten_hash_safe from narwhals._plan.typing import NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar from narwhals._utils import Version, _hasattr_static @@ -171,9 +171,6 @@ def map_batches( def rolling_expr( self, node: expr.RollingExpr, frame: FrameT_contra, name: str ) -> Self: ... - def function_expr( - self, node: expr.FunctionExpr, frame: FrameT_contra, name: str - ) -> Self: ... # series only (section 3) def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: ... def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: ... @@ -261,9 +258,15 @@ class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): expr.AnonymousExpr: lambda self, node, frame, name: self.map_batches( node, frame, name ), - expr.FunctionExpr: lambda self, node, frame, name: self.function_expr( + expr.FunctionExpr: lambda self, node, frame, name: self._dispatch_function( node, frame, name ), + # NOTE: Keeping it simple for now + # When adding other `*_range` functions, this should instead map to `range_expr` + expr.RangeExpr: lambda self, + node, + frame, + name: self.__narwhals_namespace__().int_range(node, frame, name), expr.OrderedWindowExpr: lambda self, node, frame, name: self.over_ordered( node, frame, name ), @@ -272,6 +275,38 @@ class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): node, frame, name ), } + _DISPATCH_FUNCTION: ClassVar[ + Mapping[type[Function], Callable[[Any, FunctionExpr, Any, str], Any]] + ] = { + boolean.AnyHorizontal: lambda self, + node, + frame, + name: self.__narwhals_namespace__().any_horizontal(node, frame, name), + boolean.AllHorizontal: lambda self, + node, + frame, + name: self.__narwhals_namespace__().all_horizontal(node, frame, name), + F.SumHorizontal: lambda self, + node, + frame, + name: self.__narwhals_namespace__().sum_horizontal(node, frame, name), + F.MinHorizontal: lambda self, + node, + frame, + name: self.__narwhals_namespace__().min_horizontal(node, frame, name), + F.MaxHorizontal: lambda self, + node, + frame, + name: self.__narwhals_namespace__().max_horizontal(node, frame, name), + F.MeanHorizontal: lambda self, + node, + frame, + name: self.__narwhals_namespace__().mean_horizontal(node, frame, name), + strings.ConcatHorizontal: lambda self, + node, + frame, + name: self.__narwhals_namespace__().concat_str(node, frame, name), + } def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: if (method := self._DISPATCH.get(node.__class__)) and ( @@ -281,6 +316,17 @@ def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: msg = f"Support for {node.__class__.__name__!r} is not yet implemented, got:\n{node!r}" raise NotImplementedError(msg) + def _dispatch_function( + self, node: FunctionExpr, frame: FrameT_contra, name: str + ) -> R_co: + fn = node.function + if (method := self._DISPATCH_FUNCTION.get(fn.__class__)) and ( + result := method(self, node, frame, name) + ): + return result # type: ignore[no-any-return] + msg = f"Support for {fn.__class__.__name__!r} is not yet implemented, got:\n{node!r}" + raise NotImplementedError(msg) + @classmethod def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: obj = cls.__new__(cls) @@ -413,6 +459,8 @@ class EagerScalar( def __len__(self) -> int: return 1 + def to_python(self) -> PythonLiteral: ... 
+ class LazyScalar( CompliantScalar[FrameT_contra, SeriesT], diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 4101e80e76..811ee5f8ee 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -90,6 +90,10 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: .name.to_uppercase(), {"C": [2.0, 9.0, 4.0], "D": [7.0, 8.0, 8.0]}, ), + ([nwd.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), + ([nwd.int_range(nwd.len())], {"literal": [0, 1, 2]}), + (nwd.int_range(nwd.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), + (nwd.int_range(nwd.col("b").min() + 4, nwd.col("d").last()), {"b": [5, 6, 7]}), ], ids=_ids_ir, ) From 5635be2fd5e7e9153f9438f8a1dc2349b14e807a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 11 Jul 2025 22:44:19 +0100 Subject: [PATCH 314/368] fix: typo in error message --- narwhals/_plan/arrow/namespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 778c48c109..16dc0f6c0f 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -165,6 +165,6 @@ def int_range( else: msg = ( f"All inputs for `int_range()` resolve to int, but got \n" - f"{start_!r}\n{start_!r}" + f"{start_!r}\n{end_!r}" ) raise InvalidOperationError(msg) From 98fe85a1069c67e58fc8629a7528a6bba00cebf8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 12 Jul 2025 13:47:38 +0100 Subject: [PATCH 315/368] feat(pyarrow): Impl `pow` Remaining n-ary: `fill_null`, `is_between` --- narwhals/_plan/arrow/expr.py | 12 ++++++++++-- narwhals/_plan/protocols.py | 2 ++ tests/plan/to_compliant_test.py | 5 +++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 5fff3f2cf0..4b312981a8 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -12,7 +12,7 @@ floordiv_compat, narwhals_to_native_dtype, ) -from narwhals._plan import operators as ops +from narwhals._plan import functions as F, operators as ops # noqa: N812 from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, into_dtype from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch @@ -50,7 +50,7 @@ ) from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.expr import BinaryExpr + from narwhals._plan.expr import BinaryExpr, FunctionExpr from narwhals.typing import IntoDType, PythonLiteral NativeScalar: TypeAlias = "pa.Scalar[Any]" @@ -107,6 +107,14 @@ def cast( native = self._dispatch(node.expr, frame, name).native return self._with_native(pc.cast(native, data_type), name) + def pow( + self, node: FunctionExpr[F.Pow], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + base, exponent = node.function.unwrap_input(node) + base_ = self._dispatch(base, frame, "base").native + exponent_ = self._dispatch(exponent, frame, "exponent").native + return self._with_native(pc.power(base_, exponent_), name) + class ArrowExpr( # type: ignore[misc] _ArrowDispatch["ArrowExpr | ArrowScalar"], diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 2ba3f04e2c..e251a37913 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -155,6 +155,7 @@ def _with_native(self, native: Any, name: str, 
/) -> Self: # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... + def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... def binary_expr( self, node: expr.BinaryExpr, frame: FrameT_contra, name: str ) -> Self: ... @@ -306,6 +307,7 @@ class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): node, frame, name: self.__narwhals_namespace__().concat_str(node, frame, name), + F.Pow: lambda self, node, frame, name: self.pow(node, frame, name), } def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 811ee5f8ee..9600149d5a 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -94,6 +94,11 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: ([nwd.int_range(nwd.len())], {"literal": [0, 1, 2]}), (nwd.int_range(nwd.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), (nwd.int_range(nwd.col("b").min() + 4, nwd.col("d").last()), {"b": [5, 6, 7]}), + (nwd.col("b") ** 2, {"b": [1, 4, 9]}), + ( + [2 ** nwd.col("b"), (nwd.lit(2.0) ** nwd.nth(1)).alias("lit")], + {"literal": [2, 4, 8], "lit": [2, 4, 8]}, + ), ], ids=_ids_ir, ) From 63b63c8b844adbf29d8fae42e25dfcd2c7fb3089 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 12 Jul 2025 15:56:12 +0100 Subject: [PATCH 316/368] feat(pyarrow): Impl `fill_null`, `is_between` Eventually will want to make those less repetitive, but that should be all of the fixed n-ary expressions ready --- narwhals/_plan/arrow/expr.py | 32 +++++++++++++++++++++- narwhals/_plan/protocols.py | 10 +++++++ tests/plan/to_compliant_test.py | 47 ++++++++++++++++++++++++++++++++- 3 files changed, 87 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 4b312981a8..f1a443757d 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -50,11 +50,15 @@ ) from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals._plan.boolean import IsBetween from narwhals._plan.expr import BinaryExpr, FunctionExpr - from narwhals.typing import IntoDType, PythonLiteral + from narwhals.typing import ClosedInterval, IntoDType, PythonLiteral NativeScalar: TypeAlias = "pa.Scalar[Any]" BinOp: TypeAlias = Callable[..., "ChunkedArrayAny | NativeScalar"] +LogicalOp: TypeAlias = Callable[ + ..., "pa.ChunkedArray[pa.BooleanScalar] | pa.BooleanScalar" +] BACKEND_VERSION = Implementation.PYARROW._backend_version() _StoresNativeAny: TypeAlias = _StoresNative[Any] @@ -88,6 +92,13 @@ def modulus(lhs: Any, rhs: Any) -> Any: ops.ExclusiveOr: pc.xor, } +IS_BETWEEN: Mapping[ClosedInterval, tuple[LogicalOp, LogicalOp]] = { + "left": (pc.greater_equal, pc.less), + "right": (pc.greater, pc.less_equal), + "none": (pc.greater, pc.less), + "both": (pc.greater_equal, pc.less_equal), +} + class _ArrowDispatch( ExprDispatch["ArrowDataFrame", _StoresNativeT_co, "ArrowNamespace"], Protocol @@ -115,6 +126,25 @@ def pow( exponent_ = self._dispatch(exponent, frame, "exponent").native return self._with_native(pc.power(base_, exponent_), name) + def fill_null( + self, node: FunctionExpr[F.FillNull], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + expr, value = node.function.unwrap_input(node) + native = self._dispatch(expr, frame, name).native + value_ = 
self._dispatch(value, frame, "value").native + return self._with_native(pc.fill_null(native, value_), name) + + def is_between( + self, node: FunctionExpr[IsBetween], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + expr, lower_bound, upper_bound = node.function.unwrap_input(node) + native = self._dispatch(expr, frame, name).native + lower = self._dispatch(lower_bound, frame, "lower").native + upper = self._dispatch(upper_bound, frame, "upper").native + fn_lhs, fn_rhs = IS_BETWEEN[node.function.closed] + result = pc.and_kleene(fn_lhs(native, lower), fn_rhs(native, upper)) + return self._with_native(result, name) + class ArrowExpr( # type: ignore[misc] _ArrowDispatch["ArrowExpr | ArrowScalar"], diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index e251a37913..9c336efafa 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -156,6 +156,12 @@ def _with_native(self, native: Any, name: str, /) -> Self: # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... + def fill_null( + self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str + ) -> Self: ... + def is_between( + self, node: FunctionExpr[boolean.IsBetween], frame: FrameT_contra, name: str + ) -> Self: ... def binary_expr( self, node: expr.BinaryExpr, frame: FrameT_contra, name: str ) -> Self: ... @@ -308,6 +314,10 @@ class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): frame, name: self.__narwhals_namespace__().concat_str(node, frame, name), F.Pow: lambda self, node, frame, name: self.pow(node, frame, name), + F.FillNull: lambda self, node, frame, name: self.fill_null(node, frame, name), + boolean.IsBetween: lambda self, node, frame, name: self.is_between( + node, frame, name + ), } def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index 9600149d5a..c8d9b12722 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -21,7 +21,13 @@ @pytest.fixture def data_small() -> dict[str, Any]: - return {"a": ["A", "B", "A"], "b": [1, 2, 3], "c": [9, 2, 4], "d": [8, 7, 8]} + return { + "a": ["A", "B", "A"], + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [None, 9, 7], + } def _ids_ir(expr: DummyExpr | Any) -> str: @@ -99,6 +105,45 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: [2 ** nwd.col("b"), (nwd.lit(2.0) ** nwd.nth(1)).alias("lit")], {"literal": [2, 4, 8], "lit": [2, 4, 8]}, ), + ( + [ + nwd.col("b").is_between(2, 3, "left").alias("left"), + nwd.col("b").is_between(2, 3, "right").alias("right"), + nwd.col("b").is_between(2, 3, "none").alias("none"), + nwd.col("b").is_between(2, 3, "both").alias("both"), + nwd.col("c").is_between( + nwd.col("c").mean() - 1, 7 - nwd.col("b"), "both" + ), + nwd.col("c") + .alias("c_right") + .is_between(nwd.col("c").mean() - 1, 7 - nwd.col("b"), "right"), + ], + { + "left": [False, True, False], + "right": [False, False, True], + "none": [False, False, False], + "both": [False, True, True], + "c": [False, False, True], + "c_right": [False, False, False], + }, + ), + ( + [ + nwd.col("e").fill_null(0).alias("e_0"), + nwd.col("e").fill_null(nwd.col("b")).alias("e_b"), + nwd.col("e").fill_null(nwd.col("b").last()).alias("e_b_last"), + nwd.col("e") + .sort(nulls_last=True) + .fill_null(nwd.col("d").last() - nwd.col("c")) + 
.alias("e_sort_wild"), + ], + { + "e_0": [0, 9, 7], + "e_b": [1, 9, 7], + "e_b_last": [3, 9, 7], + "e_sort_wild": [7, 9, 4], + }, + ), ], ids=_ids_ir, ) From cccc323ae9351d188036802495161d68e0409b9a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 12 Jul 2025 18:06:53 +0100 Subject: [PATCH 317/368] feat(pyarrow): Impl 6x boolean unary functions --- narwhals/_plan/arrow/expr.py | 53 +++++++++++++++++++++++++++++++-- narwhals/_plan/protocols.py | 26 ++++++++++++++++ tests/plan/to_compliant_test.py | 5 ++++ 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index f1a443757d..2e61ec4aa2 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -12,7 +12,7 @@ floordiv_compat, narwhals_to_native_dtype, ) -from narwhals._plan import functions as F, operators as ops # noqa: N812 +from narwhals._plan import boolean, functions as F, operators as ops # noqa: N812 from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR, into_dtype from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch @@ -50,7 +50,7 @@ ) from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.boolean import IsBetween + from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull from narwhals._plan.expr import BinaryExpr, FunctionExpr from narwhals.typing import ClosedInterval, IntoDType, PythonLiteral @@ -74,6 +74,14 @@ def modulus(lhs: Any, rhs: Any) -> Any: return pc.subtract(lhs, pc.multiply(floor_div, rhs)) +def any_(native: Any) -> pa.BooleanScalar: + return pc.any(native, min_count=0) + + +def all_(native: Any) -> pa.BooleanScalar: + return pc.all(native, min_count=0) + + DISPATCH_BINARY: Mapping[type[ops.Operator], BinOp] = { ops.Eq: pc.equal, ops.NotEq: pc.not_equal, @@ -145,6 +153,47 @@ def is_between( result = pc.and_kleene(fn_lhs(native, lower), fn_rhs(native, upper)) return self._with_native(result, name) + def _unary_function( + self, fn: Callable[[Any], Any], / + ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], _StoresNativeT_co]: + def func( + node: FunctionExpr[Any], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + native = self._dispatch(node.input[0], frame, name).native + return self._with_native(fn(native), name) + + return func + + def not_( + self, node: FunctionExpr[boolean.Not], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + return self._unary_function(pc.invert)(node, frame, name) + + def all( + self, node: FunctionExpr[boolean.All], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + return self._unary_function(all_)(node, frame, name) + + def any( + self, node: FunctionExpr[boolean.Any], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + return self._unary_function(any_)(node, frame, name) + + def is_finite( + self, node: FunctionExpr[IsFinite], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + return self._unary_function(pc.is_finite)(node, frame, name) + + def is_nan( + self, node: FunctionExpr[IsNan], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + return self._unary_function(pc.is_nan)(node, frame, name) + + def is_null( + self, node: FunctionExpr[IsNull], frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + return self._unary_function(pc.is_null)(node, frame, name) + class ArrowExpr( # type: ignore[misc] 
_ArrowDispatch["ArrowExpr | ArrowScalar"], diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 9c336efafa..30927ac445 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -156,12 +156,24 @@ def _with_native(self, native: Any, name: str, /) -> Self: # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... + def not_( + self, node: FunctionExpr[boolean.Not], frame: FrameT_contra, name: str + ) -> Self: ... def fill_null( self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str ) -> Self: ... def is_between( self, node: FunctionExpr[boolean.IsBetween], frame: FrameT_contra, name: str ) -> Self: ... + def is_finite( + self, node: FunctionExpr[boolean.IsFinite], frame: FrameT_contra, name: str + ) -> Self: ... + def is_nan( + self, node: FunctionExpr[boolean.IsNan], frame: FrameT_contra, name: str + ) -> Self: ... + def is_null( + self, node: FunctionExpr[boolean.IsNull], frame: FrameT_contra, name: str + ) -> Self: ... def binary_expr( self, node: expr.BinaryExpr, frame: FrameT_contra, name: str ) -> Self: ... @@ -225,6 +237,12 @@ def median( def min( self, node: agg.Min, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def all( + self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def any( + self, node: FunctionExpr[boolean.Any], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): @@ -318,6 +336,14 @@ class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): boolean.IsBetween: lambda self, node, frame, name: self.is_between( node, frame, name ), + boolean.IsFinite: lambda self, node, frame, name: self.is_finite( + node, frame, name + ), + boolean.IsNan: lambda self, node, frame, name: self.is_nan(node, frame, name), + boolean.IsNull: lambda self, node, frame, name: self.is_null(node, frame, name), + boolean.Not: lambda self, node, frame, name: self.not_(node, frame, name), + boolean.Any: lambda self, node, frame, name: self.any(node, frame, name), + boolean.All: lambda self, node, frame, name: self.all(node, frame, name), } def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index c8d9b12722..cf88c3d456 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -144,6 +144,11 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: "e_sort_wild": [7, 9, 4], }, ), + (nwd.col("e", "d").is_null().any(), {"e": [True], "d": [False]}), + ( + [(~nwd.col("e", "d").is_null()).all(), "b"], + {"e": [False, False, False], "d": [True, True, True], "b": [1, 2, 3]}, + ), ], ids=_ids_ir, ) From 638b585faa4b710da0b2ed28e14a84d648fe79d2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 12 Jul 2025 20:44:20 +0100 Subject: [PATCH 318/368] chore: remove `to_compliant`/`_to_compliant` - Not planning to go down that route any further (92694ce439ea6882e8ea5994142d37bb500042f0) --- narwhals/_plan/common.py | 5 ----- narwhals/_plan/dummy.py | 5 +---- narwhals/_plan/expr.py | 22 +--------------------- narwhals/_plan/typing.py | 11 ----------- tests/plan/to_compliant_test.py | 22 
---------------------- 5 files changed, 2 insertions(+), 63 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 557242f5d9..495a4e18ae 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -9,13 +9,11 @@ DTypeT, ExprIRT, ExprIRT2, - ExprT, IRNamespaceT, MapIR, NamedOrExprIRT, NativeSeriesT, NonNestedDTypeT, - Ns, Seq, ) from narwhals._utils import _hasattr_static @@ -179,9 +177,6 @@ def to_narwhals(self, version: Version = Version.MAIN) -> DummyExpr: return dummy.DummyExpr._from_ir(self) return dummy.DummyExprV1._from_ir(self) - def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: - raise NotImplementedError - @property def is_scalar(self) -> bool: return False diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 541c463c60..4bad96a7ee 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -49,7 +49,7 @@ from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace from narwhals._plan.temporal import ExprDateTimeNamespace - from narwhals._plan.typing import ExprT, IntoExpr, IntoExprColumn, Ns, Seq, Udf + from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq, Udf from narwhals.dtypes import DType from narwhals.typing import ( ClosedInterval, @@ -104,9 +104,6 @@ def _from_ir(cls, ir: ExprIR, /) -> Self: obj._ir = ir return obj - def _to_compliant(self, plx: Ns[ExprT], /) -> ExprT: - return self._ir.to_compliant(plx) - @property def version(self) -> Version: return self._version diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index cba6ed65b4..561aa3d96e 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -7,24 +7,16 @@ import typing as t from narwhals._plan.aggregation import AggExpr, OrderableAggExpr -from narwhals._plan.common import ( - ExprIR, - SelectorIR, - collect, - is_non_nested_literal, - is_regex_projection, -) +from narwhals._plan.common import ExprIR, SelectorIR, collect, is_regex_projection from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( - ExprT, FunctionT, LeftSelectorT, LeftT, LeftT2, LiteralT, MapIR, - Ns, OperatorT, RangeT, RightSelectorT, @@ -118,9 +110,6 @@ class Column(ExprIR): def __repr__(self) -> str: return f"col({self.name!r})" - def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: - return plx.col(self.name) - def with_name(self, name: str, /) -> Column: return self if name == self.name else col(name) @@ -142,9 +131,6 @@ class Columns(_ColumnSelection): def __repr__(self) -> str: return f"cols({list(self.names)!r})" - def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: - return plx.col(*self.names) - class Nth(_ColumnSelection): __slots__ = ("index",) @@ -237,12 +223,6 @@ def name(self) -> str: def __repr__(self) -> str: return f"lit({self.value!r})" - def to_compliant(self, plx: Ns[ExprT], /) -> ExprT: - value = self.unwrap() - if is_non_nested_literal(value): - return plx.lit(value, self.dtype) - raise NotImplementedError(type(self.value)) - def unwrap(self) -> LiteralT: return self.value.unwrap() diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index e7f5fe03bc..6dfdaa8d21 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -8,8 +8,6 @@ from typing_extensions import TypeAlias from narwhals import dtypes - from narwhals._compliant import CompliantNamespace as Namespace - from narwhals._compliant.typing import CompliantExprAny from narwhals._plan 
import operators as ops from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR from narwhals._plan.dummy import DummyExpr, DummySeries @@ -79,15 +77,6 @@ MapIR: TypeAlias = "t.Callable[[ExprIR], ExprIR]" """A function to apply to all nodes in this tree.""" -# NOTE: Shorter aliases of `_compliant.typing` -# - Aiming to try and preserve the types as much as possible -# - Recursion between `Expr` and `Frame` is an issue -Expr: TypeAlias = "CompliantExprAny" -ExprT = TypeVar("ExprT", bound="Expr") -Ns: TypeAlias = "Namespace[t.Any, ExprT]" -"""A `CompliantNamespace`, ignoring the `Frame` type.""" - - T = TypeVar("T") Seq: TypeAlias = "tuple[T,...]" diff --git a/tests/plan/to_compliant_test.py b/tests/plan/to_compliant_test.py index cf88c3d456..b92db659ac 100644 --- a/tests/plan/to_compliant_test.py +++ b/tests/plan/to_compliant_test.py @@ -8,14 +8,11 @@ from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.common import is_expr from narwhals.exceptions import ComputeError -from narwhals.utils import Version -from tests.namespace_test import backends from tests.utils import assert_equal_data if TYPE_CHECKING: from collections.abc import Sequence - from narwhals._namespace import BackendName from narwhals._plan.dummy import DummyExpr @@ -36,25 +33,6 @@ def _ids_ir(expr: DummyExpr | Any) -> str: return repr(expr) -@pytest.mark.parametrize( - ("expr"), - [ - nwd.col("a"), - nwd.col("a", "b"), - nwd.lit(1), - nwd.lit(2.0), - nwd.lit(None, nw.String()), - ], - ids=_ids_ir, -) -@backends -def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None: - pytest.importorskip(backend) - namespace = Version.MAIN.namespace.from_backend(backend).compliant - compliant_expr = expr._to_compliant(namespace) - assert isinstance(compliant_expr, namespace._expr) - - XFAIL_REWRITE_SPECIAL_ALIASES = pytest.mark.xfail( reason="Bug in `meta` namespace impl", raises=ComputeError ) From ccfa7da966c76310124515afc5cce68a5f3236b1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 12 Jul 2025 20:46:54 +0100 Subject: [PATCH 319/368] rename `to_compliant_test` -> `compliant_test` Old name was based on a dead api --- tests/plan/{to_compliant_test.py => compliant_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/plan/{to_compliant_test.py => compliant_test.py} (100%) diff --git a/tests/plan/to_compliant_test.py b/tests/plan/compliant_test.py similarity index 100% rename from tests/plan/to_compliant_test.py rename to tests/plan/compliant_test.py From df52814bbc2647c2802510c5c565931a0923515a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 13 Jul 2025 19:53:25 +0100 Subject: [PATCH 320/368] test: Add failing `when` tests Will be demonstrating #668, which just needs a backend impl --- tests/plan/compliant_test.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index b92db659ac..482fbef5b4 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -37,6 +37,13 @@ def _ids_ir(expr: DummyExpr | Any) -> str: reason="Bug in `meta` namespace impl", raises=ComputeError ) +XFAIL_NOT_IMPL_WHEN = pytest.mark.xfail( + reason="Not implemented when-then-otherwise", raises=NotImplementedError +) +XFAIL_NOT_ALL_HORIZONTAL = pytest.mark.xfail( + reason="Not implemented all_horizontal", raises=NotImplementedError +) + 
@pytest.mark.parametrize( ("expr", "expected"), @@ -127,6 +134,34 @@ def _ids_ir(expr: DummyExpr | Any) -> str: [(~nwd.col("e", "d").is_null()).all(), "b"], {"e": [False, False, False], "d": [True, True, True], "b": [1, 2, 3]}, ), + pytest.param( + nwd.when(d=8).then("c"), {"c": [9, None, 4]}, marks=XFAIL_NOT_IMPL_WHEN + ), + pytest.param( + nwd.when(nwd.col("e").is_null()) + .then(nwd.col("b") + nwd.col("c")) + .otherwise(50), + {"b": [10, 50, 50]}, + marks=XFAIL_NOT_IMPL_WHEN, + ), + pytest.param( + nwd.when(nwd.col("a") == nwd.lit("C")) + .then(nwd.lit("c")) + .when(nwd.col("a") == nwd.lit("D")) + .then(nwd.lit("d")) + .when(nwd.col("a") == nwd.lit("B")) + .then(nwd.lit("b")) + .when(nwd.col("a") == nwd.lit("A")) + .then(nwd.lit("a")) + .alias("A"), + {"A": ["a", "b", "a"]}, + marks=XFAIL_NOT_IMPL_WHEN, + ), + pytest.param( + nwd.when(nwd.col("c") > 5, b=1).then(999), + {"literal": [999, None, None]}, + marks=[XFAIL_NOT_IMPL_WHEN, XFAIL_NOT_ALL_HORIZONTAL], + ), ], ids=_ids_ir, ) From 3140a481822c8bf36414bcf7b48930b5e4cd1297 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 13 Jul 2025 21:40:54 +0100 Subject: [PATCH 321/368] feat(pyarrow): Impl complex `when-then-otherwise` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If this works on all `pyarrow` versions, it might be the most compelling part of the PR so far 🤞 --- narwhals/_plan/arrow/expr.py | 11 ++++++++++- tests/plan/compliant_test.py | 31 ++++++++++++++++++++++++------- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 2e61ec4aa2..6ce889f2f1 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -51,7 +51,7 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull - from narwhals._plan.expr import BinaryExpr, FunctionExpr + from narwhals._plan.expr import BinaryExpr, FunctionExpr, Ternary from narwhals.typing import ClosedInterval, IntoDType, PythonLiteral NativeScalar: TypeAlias = "pa.Scalar[Any]" @@ -194,6 +194,15 @@ def is_null( ) -> _StoresNativeT_co: return self._unary_function(pc.is_null)(node, frame, name) + def ternary_expr( + self, node: Ternary, frame: ArrowDataFrame, name: str + ) -> _StoresNativeT_co: + when = self._dispatch(node.predicate, frame, name) + then = self._dispatch(node.truthy, frame, name) + otherwise = self._dispatch(node.falsy, frame, name) + result = pc.if_else(when.native, then.native, otherwise.native) + return self._with_native(result, name) + class ArrowExpr( # type: ignore[misc] _ArrowDispatch["ArrowExpr | ArrowScalar"], diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 482fbef5b4..48debbfc44 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +# ruff: noqa: FBT003 from typing import TYPE_CHECKING, Any import pytest @@ -37,9 +38,6 @@ def _ids_ir(expr: DummyExpr | Any) -> str: reason="Bug in `meta` namespace impl", raises=ComputeError ) -XFAIL_NOT_IMPL_WHEN = pytest.mark.xfail( - reason="Not implemented when-then-otherwise", raises=NotImplementedError -) XFAIL_NOT_ALL_HORIZONTAL = pytest.mark.xfail( reason="Not implemented all_horizontal", raises=NotImplementedError ) @@ -135,14 +133,14 @@ def _ids_ir(expr: DummyExpr | Any) -> str: {"e": [False, False, False], "d": 
[True, True, True], "b": [1, 2, 3]}, ), pytest.param( - nwd.when(d=8).then("c"), {"c": [9, None, 4]}, marks=XFAIL_NOT_IMPL_WHEN + nwd.when(d=8).then("c"), {"c": [9, None, 4]}, id="When-otherwise-none" ), pytest.param( nwd.when(nwd.col("e").is_null()) .then(nwd.col("b") + nwd.col("c")) .otherwise(50), {"b": [10, 50, 50]}, - marks=XFAIL_NOT_IMPL_WHEN, + id="When-otherwise-native-broadcast", ), pytest.param( nwd.when(nwd.col("a") == nwd.lit("C")) @@ -155,12 +153,31 @@ def _ids_ir(expr: DummyExpr | Any) -> str: .then(nwd.lit("a")) .alias("A"), {"A": ["a", "b", "a"]}, - marks=XFAIL_NOT_IMPL_WHEN, + id="When-then-x4", ), pytest.param( nwd.when(nwd.col("c") > 5, b=1).then(999), {"literal": [999, None, None]}, - marks=[XFAIL_NOT_IMPL_WHEN, XFAIL_NOT_ALL_HORIZONTAL], + marks=[XFAIL_NOT_ALL_HORIZONTAL], + id="When-multiple-predicates", + ), + pytest.param( + nwd.when(nwd.lit(True)).then("c"), + {"c": [9, 2, 4]}, + id="When-literal-then-column", + ), + pytest.param( + nwd.when(nwd.lit(True)).then(nwd.col("c").mean()), + {"c": [5.0]}, + id="When-literal-then-agg", + ), + pytest.param( + [ + nwd.when(nwd.lit(True)).then(nwd.col("e").last()), + nwd.col("b").sort(descending=True), + ], + {"e": [7, 7, 7], "b": [3, 2, 1]}, + id="When-literal-then-agg-broadcast", ), ], ids=_ids_ir, From b5550b40318971380c4d7806cadd9baa9cf18f7b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 14 Jul 2025 17:08:51 +0000 Subject: [PATCH 322/368] refactor: Remove `pyarrow<13` compat Not needed since #2825 --- narwhals/_plan/arrow/dataframe.py | 2 +- narwhals/_plan/arrow/expr.py | 22 +++++----------------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 7926b0b5d6..1a2209db0f 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -54,7 +54,7 @@ def schema(self) -> dict[str, DType]: } def __len__(self) -> int: - return len(self.native) + return self.native.num_rows def to_narwhals(self) -> DummyFrame[pa.Table, ChunkedArrayAny]: return DummyFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 6ce889f2f1..c42a276695 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, Protocol, cast, overload +from typing import TYPE_CHECKING, Any, Protocol, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -299,15 +299,13 @@ def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowEx def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = lit(native[0]) if len(prev) else lit(None, native.type) + result = native[0] if len(prev) else lit(None, native.type) return self._with_native(result, name) def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = ( - lit(native[height - 1]) if (height := len(prev)) else lit(None, native.type) - ) + result = native[height - 1] if (height := len(prev)) else lit(None, native.type) return self._with_native(result, name) def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: @@ -383,23 
+381,13 @@ def binary_expr( # type: ignore[override] def lit(value: Any, dtype: pa.DataType | None = None) -> NativeScalar: - # NOTE: Needed for `pyarrow<13` - if isinstance(value, pa.Scalar): - return value # NOTE: PR that fixed this the overloads was closed # https://github.com/zen-xu/pyarrow-stubs/pull/208 return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) -# NOTE: https://github.com/apache/arrow/issues/21761 -# fmt: off -if BACKEND_VERSION >= (13,): - def array(value: NativeScalar) -> ArrayAny: - return pa.array([value], value.type) -else: - def array(value: NativeScalar) -> ArrayAny: - return cast("ArrayAny", pa.array([value.as_py()], value.type)) -# fmt: on +def array(value: NativeScalar) -> ArrayAny: + return pa.array([value], value.type) def chunked_array( From 97453bb0f9bcb8c57382be7a3a0805a0a75ffb72 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 14 Jul 2025 17:14:44 +0000 Subject: [PATCH 323/368] refactor: Moving around --- narwhals/_plan/arrow/expr.py | 32 ++--- narwhals/_plan/protocols.py | 238 +++++++++++++++++------------------ 2 files changed, 135 insertions(+), 135 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index c42a276695..c6931f3dea 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -82,6 +82,22 @@ def all_(native: Any) -> pa.BooleanScalar: return pc.all(native, min_count=0) +def lit(value: Any, dtype: pa.DataType | None = None) -> NativeScalar: + # NOTE: PR that fixed these the overloads was closed + # https://github.com/zen-xu/pyarrow-stubs/pull/208 + return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) + + +def array(value: NativeScalar) -> ArrayAny: + return pa.array([value], value.type) + + +def chunked_array( + arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, / +) -> ChunkedArrayAny: + return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) + + DISPATCH_BINARY: Mapping[type[ops.Operator], BinOp] = { ops.Eq: pc.equal, ops.NotEq: pc.not_equal, @@ -380,22 +396,6 @@ def binary_expr( # type: ignore[override] return self._with_native(result, name) -def lit(value: Any, dtype: pa.DataType | None = None) -> NativeScalar: - # NOTE: PR that fixed this the overloads was closed - # https://github.com/zen-xu/pyarrow-stubs/pull/208 - return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) - - -def array(value: NativeScalar) -> ArrayAny: - return pa.array([value], value.type) - - -def chunked_array( - arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, / -) -> ChunkedArrayAny: - return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) - - class ArrowScalar( _ArrowDispatch["ArrowScalar"], _StoresNative[NativeScalar], diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 30927ac445..cc31fe103f 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -126,125 +126,6 @@ def _length_required( return max_length if required else None -class CompliantExpr(StoresVersion, Protocol[FrameT_contra, SeriesT_co]): - """Everything common to `Expr`/`Series` and `Scalar` literal values. 
- - Early notes: - - Separating series/scalar makes a lot of sense - - Handling the recursive case *without* intermediate (non-pyarrow) objects seems unachievable - - Everywhere would need to first check if it a scalar, which isn't ergonomic - - Broadcasting being separated is working - - A lot of `pyarrow.compute` (section 2) can work on either scalar or series (`FunctionExpr`) - - Aggregation can't, but that is already handled in `ExprIR` - - `polars` noops on aggregating a scalar, which we might be able to support this way - """ - - _evaluated: Any - """Compliant or native value.""" - - @property - def name(self) -> str: ... - - @classmethod - def from_native( - cls, native: Any, name: str = "", /, version: Version = Version.MAIN - ) -> Self: ... - - def _with_native(self, native: Any, name: str, /) -> Self: - return self.from_native(native, name or self.name, self.version) - - # series & scalar - def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... - def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... - def not_( - self, node: FunctionExpr[boolean.Not], frame: FrameT_contra, name: str - ) -> Self: ... - def fill_null( - self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str - ) -> Self: ... - def is_between( - self, node: FunctionExpr[boolean.IsBetween], frame: FrameT_contra, name: str - ) -> Self: ... - def is_finite( - self, node: FunctionExpr[boolean.IsFinite], frame: FrameT_contra, name: str - ) -> Self: ... - def is_nan( - self, node: FunctionExpr[boolean.IsNan], frame: FrameT_contra, name: str - ) -> Self: ... - def is_null( - self, node: FunctionExpr[boolean.IsNull], frame: FrameT_contra, name: str - ) -> Self: ... - def binary_expr( - self, node: expr.BinaryExpr, frame: FrameT_contra, name: str - ) -> Self: ... - def ternary_expr( - self, node: expr.Ternary, frame: FrameT_contra, name: str - ) -> Self: ... - def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... - def over_ordered( - self, node: expr.OrderedWindowExpr, frame: FrameT_contra, name: str - ) -> Self: ... - def map_batches( - self, node: expr.AnonymousExpr, frame: FrameT_contra, name: str - ) -> Self: ... - def rolling_expr( - self, node: expr.RollingExpr, frame: FrameT_contra, name: str - ) -> Self: ... - # series only (section 3) - def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: ... - def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: ... - def filter(self, node: expr.Filter, frame: FrameT_contra, name: str) -> Self: ... - # series -> scalar - def first( - self, node: agg.First, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def last( - self, node: agg.Last, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def arg_min( - self, node: agg.ArgMin, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def arg_max( - self, node: agg.ArgMax, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def sum( - self, node: agg.Sum, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def n_unique( - self, node: agg.NUnique, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def std( - self, node: agg.Std, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... 
- def var( - self, node: agg.Var, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def quantile( - self, node: agg.Quantile, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def count( - self, node: agg.Count, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def max( - self, node: agg.Max, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def mean( - self, node: agg.Mean, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def median( - self, node: agg.Median, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def min( - self, node: agg.Min, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def all( - self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def any( - self, node: FunctionExpr[boolean.Any], frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - - class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = { expr.Column: lambda self, node, frame, name: self.__narwhals_namespace__().col( @@ -379,6 +260,125 @@ def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: def __narwhals_namespace__(self) -> NamespaceT_co: ... +class CompliantExpr(StoresVersion, Protocol[FrameT_contra, SeriesT_co]): + """Everything common to `Expr`/`Series` and `Scalar` literal values. + + Early notes: + - Separating series/scalar makes a lot of sense + - Handling the recursive case *without* intermediate (non-pyarrow) objects seems unachievable + - Everywhere would need to first check if it a scalar, which isn't ergonomic + - Broadcasting being separated is working + - A lot of `pyarrow.compute` (section 2) can work on either scalar or series (`FunctionExpr`) + - Aggregation can't, but that is already handled in `ExprIR` + - `polars` noops on aggregating a scalar, which we might be able to support this way + """ + + _evaluated: Any + """Compliant or native value.""" + + @property + def name(self) -> str: ... + + @classmethod + def from_native( + cls, native: Any, name: str = "", /, version: Version = Version.MAIN + ) -> Self: ... + + def _with_native(self, native: Any, name: str, /) -> Self: + return self.from_native(native, name or self.name, self.version) + + # series & scalar + def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... + def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... + def not_( + self, node: FunctionExpr[boolean.Not], frame: FrameT_contra, name: str + ) -> Self: ... + def fill_null( + self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str + ) -> Self: ... + def is_between( + self, node: FunctionExpr[boolean.IsBetween], frame: FrameT_contra, name: str + ) -> Self: ... + def is_finite( + self, node: FunctionExpr[boolean.IsFinite], frame: FrameT_contra, name: str + ) -> Self: ... + def is_nan( + self, node: FunctionExpr[boolean.IsNan], frame: FrameT_contra, name: str + ) -> Self: ... + def is_null( + self, node: FunctionExpr[boolean.IsNull], frame: FrameT_contra, name: str + ) -> Self: ... + def binary_expr( + self, node: expr.BinaryExpr, frame: FrameT_contra, name: str + ) -> Self: ... 
+ def ternary_expr( + self, node: expr.Ternary, frame: FrameT_contra, name: str + ) -> Self: ... + def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... + def over_ordered( + self, node: expr.OrderedWindowExpr, frame: FrameT_contra, name: str + ) -> Self: ... + def map_batches( + self, node: expr.AnonymousExpr, frame: FrameT_contra, name: str + ) -> Self: ... + def rolling_expr( + self, node: expr.RollingExpr, frame: FrameT_contra, name: str + ) -> Self: ... + # series only (section 3) + def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: ... + def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: ... + def filter(self, node: expr.Filter, frame: FrameT_contra, name: str) -> Self: ... + # series -> scalar + def first( + self, node: agg.First, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def last( + self, node: agg.Last, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def arg_min( + self, node: agg.ArgMin, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def arg_max( + self, node: agg.ArgMax, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def sum( + self, node: agg.Sum, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def n_unique( + self, node: agg.NUnique, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def std( + self, node: agg.Std, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def var( + self, node: agg.Var, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def quantile( + self, node: agg.Quantile, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def count( + self, node: agg.Count, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def max( + self, node: agg.Max, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def mean( + self, node: agg.Mean, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def median( + self, node: agg.Median, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def min( + self, node: agg.Min, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def all( + self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def any( + self, node: FunctionExpr[boolean.Any], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... 
+ + class CompliantScalar( CompliantExpr[FrameT_contra, SeriesT_co], Protocol[FrameT_contra, SeriesT_co] ): From fbb463ba8a73260c1127e08ce23f19cbe9cb1026 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 14 Jul 2025 19:56:27 +0000 Subject: [PATCH 324/368] feat(pyarrow): Impl `all_horizontal` --- narwhals/_plan/arrow/namespace.py | 22 +++++++++------ narwhals/_plan/protocols.py | 14 +++++----- tests/plan/compliant_test.py | 46 +++++++++++++++++++++++++++---- 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 16dc0f6c0f..98f4f59a26 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -1,8 +1,10 @@ from __future__ import annotations +from functools import reduce from typing import TYPE_CHECKING, overload import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan.literal import is_literal_scalar @@ -97,37 +99,41 @@ def lit( def any_horizontal( self, node: FunctionExpr[AnyHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: + ) -> ArrowExpr | ArrowScalar: raise NotImplementedError def all_horizontal( self, node: FunctionExpr[AllHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: - raise NotImplementedError + ) -> ArrowExpr | ArrowScalar: + it = (self._expr.from_ir(e, frame, name).native for e in node.input) + result = reduce(pc.and_kleene, it) # type: ignore[arg-type] + if isinstance(result, pa.Scalar): + return self._scalar.from_native(result, name, self.version) + return self._expr.from_native(result, name, self.version) def sum_horizontal( self, node: FunctionExpr[F.SumHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: + ) -> ArrowExpr | ArrowScalar: raise NotImplementedError def min_horizontal( self, node: FunctionExpr[F.MinHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: + ) -> ArrowExpr | ArrowScalar: raise NotImplementedError def max_horizontal( self, node: FunctionExpr[F.MaxHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: + ) -> ArrowExpr | ArrowScalar: raise NotImplementedError def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: + ) -> ArrowExpr | ArrowScalar: raise NotImplementedError def concat_str( self, node: FunctionExpr[ConcatHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: + ) -> ArrowExpr | ArrowScalar: raise NotImplementedError def int_range( diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index cc31fe103f..3d860c5bf3 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -539,25 +539,25 @@ def lit( def len(self, node: expr.Len, frame: FrameT, name: str) -> ScalarT_co: ... def any_horizontal( self, node: FunctionExpr[boolean.AnyHorizontal], frame: FrameT, name: str - ) -> ExprT_co: ... + ) -> ExprT_co | ScalarT_co: ... def all_horizontal( self, node: FunctionExpr[boolean.AllHorizontal], frame: FrameT, name: str - ) -> ExprT_co: ... + ) -> ExprT_co | ScalarT_co: ... def sum_horizontal( self, node: FunctionExpr[F.SumHorizontal], frame: FrameT, name: str - ) -> ExprT_co: ... + ) -> ExprT_co | ScalarT_co: ... def min_horizontal( self, node: FunctionExpr[F.MinHorizontal], frame: FrameT, name: str - ) -> ExprT_co: ... + ) -> ExprT_co | ScalarT_co: ... 
def max_horizontal( self, node: FunctionExpr[F.MaxHorizontal], frame: FrameT, name: str - ) -> ExprT_co: ... + ) -> ExprT_co | ScalarT_co: ... def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: FrameT, name: str - ) -> ExprT_co: ... + ) -> ExprT_co | ScalarT_co: ... def concat_str( self, node: FunctionExpr[strings.ConcatHorizontal], frame: FrameT, name: str - ) -> ExprT_co: ... + ) -> ExprT_co | ScalarT_co: ... def int_range( self, node: RangeExpr[IntRange], frame: FrameT, name: str ) -> ExprT_co: ... diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 48debbfc44..90fcbc270a 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -38,10 +38,6 @@ def _ids_ir(expr: DummyExpr | Any) -> str: reason="Bug in `meta` namespace impl", raises=ComputeError ) -XFAIL_NOT_ALL_HORIZONTAL = pytest.mark.xfail( - reason="Not implemented all_horizontal", raises=NotImplementedError -) - @pytest.mark.parametrize( ("expr", "expected"), @@ -158,9 +154,17 @@ def _ids_ir(expr: DummyExpr | Any) -> str: pytest.param( nwd.when(nwd.col("c") > 5, b=1).then(999), {"literal": [999, None, None]}, - marks=[XFAIL_NOT_ALL_HORIZONTAL], id="When-multiple-predicates", ), + pytest.param( + nwd.when(nwd.col("b") == nwd.col("c"), nwd.col("d").mean() > nwd.col("d")) + .then(123) + .when(nwd.lit(True), ~nwd.nth(-1).is_null()) + .then(456) + .otherwise(nwd.col("c")), + {"literal": [9, 123, 456]}, + id="When-multiple-predicates-mixed-broadcast", + ), pytest.param( nwd.when(nwd.lit(True)).then("c"), {"c": [9, 2, 4]}, @@ -179,6 +183,38 @@ def _ids_ir(expr: DummyExpr | Any) -> str: {"e": [7, 7, 7], "b": [3, 2, 1]}, id="When-literal-then-agg-broadcast", ), + ( + [ + nwd.all_horizontal( + nwd.col("b") < nwd.col("c"), + nwd.col("a") != nwd.lit("B"), + nwd.col("e").cast(nw.Boolean), + nwd.lit(True), + ), + nwd.nth(1).last().name.suffix("_last"), + ], + {"b": [None, False, True], "b_last": [3, 3, 3]}, + ), + ( + [ + nwd.all_horizontal(nwd.lit(True), nwd.lit(True)).alias("a"), + nwd.all_horizontal(nwd.lit(False), nwd.lit(True)).alias("b"), + nwd.all_horizontal(nwd.lit(False), nwd.lit(False)).alias("c"), + nwd.all_horizontal(nwd.lit(None, nw.Boolean), nwd.lit(True)).alias("d"), + nwd.all_horizontal(nwd.lit(None, nw.Boolean), nwd.lit(False)).alias("e"), + nwd.all_horizontal( + nwd.lit(None, nw.Boolean), nwd.lit(None, nw.Boolean) + ).alias("f"), + ], + { + "a": [True], + "b": [False], + "c": [False], + "d": [None], + "e": [False], + "f": [None], + }, + ), ], ids=_ids_ir, ) From 21bf3dbc4e1b9d09df70af5016008f1f1025d483 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 14 Jul 2025 21:48:27 +0000 Subject: [PATCH 325/368] feat(pyarrow): Impl `{any,sum,min,max}_horizontal` --- narwhals/_plan/arrow/namespace.py | 36 ++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 98f4f59a26..568e310675 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import reduce -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -13,6 +13,8 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: + from collections.abc import Callable + from narwhals._arrow.typing import 
ChunkedArrayAny from narwhals._plan import expr, functions as F # noqa: N812 from narwhals._plan.arrow.dataframe import ArrowDataFrame @@ -97,40 +99,54 @@ def lit( nw_ser.to_native(), name or node.name, nw_ser.version ) + # NOTE: Update with `ignore_nulls`/`fill_null` behavior once added to each `Function` + # https://github.com/narwhals-dev/narwhals/pull/2719 + def _horizontal_function( + self, fn: Callable[[Any, Any], Any], / + ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], ArrowExpr | ArrowScalar]: + def func( + node: FunctionExpr[Any], frame: ArrowDataFrame, name: str + ) -> ArrowExpr | ArrowScalar: + it = (self._expr.from_ir(e, frame, name).native for e in node.input) + result = reduce(fn, it) + if isinstance(result, pa.Scalar): + return self._scalar.from_native(result, name, self.version) + return self._expr.from_native(result, name, self.version) + + return func + def any_horizontal( self, node: FunctionExpr[AnyHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - raise NotImplementedError + return self._horizontal_function(pc.or_kleene)(node, frame, name) def all_horizontal( self, node: FunctionExpr[AllHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - it = (self._expr.from_ir(e, frame, name).native for e in node.input) - result = reduce(pc.and_kleene, it) # type: ignore[arg-type] - if isinstance(result, pa.Scalar): - return self._scalar.from_native(result, name, self.version) - return self._expr.from_native(result, name, self.version) + return self._horizontal_function(pc.and_kleene)(node, frame, name) def sum_horizontal( self, node: FunctionExpr[F.SumHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - raise NotImplementedError + return self._horizontal_function(pc.add)(node, frame, name) def min_horizontal( self, node: FunctionExpr[F.MinHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - raise NotImplementedError + return self._horizontal_function(pc.min_element_wise)(node, frame, name) def max_horizontal( self, node: FunctionExpr[F.MaxHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - raise NotImplementedError + return self._horizontal_function(pc.max_element_wise)(node, frame, name) + # TODO @dangotbanned: Impl `mean_horizontal` def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: raise NotImplementedError + # TODO @dangotbanned: Impl `concat_str` def concat_str( self, node: FunctionExpr[ConcatHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: From ec06a5e183abd65f2034ae64169513c3a078411a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 15 Jul 2025 14:32:02 +0000 Subject: [PATCH 326/368] test: Shorten some test ids --- tests/plan/compliant_test.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 90fcbc270a..5fa46cab15 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -84,7 +84,7 @@ def _ids_ir(expr: DummyExpr | Any) -> str: [2 ** nwd.col("b"), (nwd.lit(2.0) ** nwd.nth(1)).alias("lit")], {"literal": [2, 4, 8], "lit": [2, 4, 8]}, ), - ( + pytest.param( [ nwd.col("b").is_between(2, 3, "left").alias("left"), nwd.col("b").is_between(2, 3, "right").alias("right"), @@ -105,8 +105,9 @@ def _ids_ir(expr: DummyExpr | Any) -> str: "c": [False, False, True], "c_right": 
[False, False, False], }, + id="is_between", ), - ( + pytest.param( [ nwd.col("e").fill_null(0).alias("e_0"), nwd.col("e").fill_null(nwd.col("b")).alias("e_b"), @@ -122,6 +123,7 @@ def _ids_ir(expr: DummyExpr | Any) -> str: "e_b_last": [3, 9, 7], "e_sort_wild": [7, 9, 4], }, + id="sort", ), (nwd.col("e", "d").is_null().any(), {"e": [True], "d": [False]}), ( @@ -159,7 +161,7 @@ def _ids_ir(expr: DummyExpr | Any) -> str: pytest.param( nwd.when(nwd.col("b") == nwd.col("c"), nwd.col("d").mean() > nwd.col("d")) .then(123) - .when(nwd.lit(True), ~nwd.nth(-1).is_null()) + .when(nwd.lit(True), ~nwd.nth(4).is_null()) .then(456) .otherwise(nwd.col("c")), {"literal": [9, 123, 456]}, @@ -183,7 +185,7 @@ def _ids_ir(expr: DummyExpr | Any) -> str: {"e": [7, 7, 7], "b": [3, 2, 1]}, id="When-literal-then-agg-broadcast", ), - ( + pytest.param( [ nwd.all_horizontal( nwd.col("b") < nwd.col("c"), @@ -194,8 +196,9 @@ def _ids_ir(expr: DummyExpr | Any) -> str: nwd.nth(1).last().name.suffix("_last"), ], {"b": [None, False, True], "b_last": [3, 3, 3]}, + id="all-horizontal-mixed-broadcast", ), - ( + pytest.param( [ nwd.all_horizontal(nwd.lit(True), nwd.lit(True)).alias("a"), nwd.all_horizontal(nwd.lit(False), nwd.lit(True)).alias("b"), @@ -214,6 +217,8 @@ def _ids_ir(expr: DummyExpr | Any) -> str: "e": [False], "f": [None], }, + id="all-horizontal-kleene", + ), ), ], ids=_ids_ir, From 031d547b3deecb4f36ba4cbf2216c627ec9eefd6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 15 Jul 2025 14:32:57 +0000 Subject: [PATCH 327/368] test(pyarrow): Add `any_horizontal` tests --- tests/plan/compliant_test.py | 46 +++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 5fa46cab15..8c352fb058 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -5,6 +5,9 @@ import pytest +pytest.importorskip("pyarrow") +import pyarrow as pa + import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.common import is_expr @@ -25,6 +28,12 @@ def data_small() -> dict[str, Any]: "c": [9, 2, 4], "d": [8, 7, 8], "e": [None, 9, 7], + "f": [True, False, None], + "g": [False, None, False], + "h": [None, None, True], + "i": [None, None, None], + "j": [12.1, 13.2, 4.0], + "k": [42, 10, 12], } @@ -37,6 +46,12 @@ def _ids_ir(expr: DummyExpr | Any) -> str: XFAIL_REWRITE_SPECIAL_ALIASES = pytest.mark.xfail( reason="Bug in `meta` namespace impl", raises=ComputeError ) +XFAIL_KLEENE_ALL_NULL = pytest.mark.xfail( + reason="`pyarrow` uses `pa.null()`, which also fails in current `narwhals`.\n" + "In `polars`, the same op is supported and it uses `pl.Null`.\n\n" + "Function 'or_kleene' has no kernel matching input types (bool, null)", + raises=pa.ArrowNotImplementedError, +) @pytest.mark.parametrize( @@ -219,6 +234,34 @@ def _ids_ir(expr: DummyExpr | Any) -> str: }, id="all-horizontal-kleene", ), + pytest.param( + [ + nwd.any_horizontal("f", "g"), + nwd.any_horizontal("g", "h"), + nwd.any_horizontal(nwd.lit(False), nwd.col("g").last()).alias( + "False-False" + ), + ], + { + "f": [True, None, None], + "g": [None, None, True], + "False-False": [False, False, False], + }, + id="any-horizontal-kleene", + ), + pytest.param( + [ + nwd.any_horizontal(nwd.lit(None, nw.Boolean), "i").alias("None-None"), + nwd.any_horizontal(nwd.lit(True), "i").alias("True-None"), + nwd.any_horizontal(nwd.lit(False), "i").alias("False-None"), + ], + { + 
"None-None": [None, None, None], + "True-None": [True, True, True], + "False-None": [None, None, None], + }, + id="any-horizontal-kleene-full-null", + marks=XFAIL_KLEENE_ALL_NULL, ), ], ids=_ids_ir, @@ -228,9 +271,6 @@ def test_select( expected: dict[str, Any], data_small: dict[str, Any], ) -> None: - pytest.importorskip("pyarrow") - import pyarrow as pa - from narwhals._plan.dummy import DummyFrame frame = pa.table(data_small) From 51ed0fbe9a1fcb48e5c9585fca428d1d315eab89 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 15 Jul 2025 16:56:34 +0000 Subject: [PATCH 328/368] test: Add equiv to `test_sumh_broadcasting` The rest are working, just need to fill out the method --- tests/plan/compliant_test.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 8c352fb058..775b641293 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -34,6 +34,8 @@ def data_small() -> dict[str, Any]: "i": [None, None, None], "j": [12.1, 13.2, 4.0], "k": [42, 10, 12], + "l": [4, 5, 6], + "m": [0, 1, 2], } @@ -263,6 +265,34 @@ def _ids_ir(expr: DummyExpr | Any) -> str: id="any-horizontal-kleene-full-null", marks=XFAIL_KLEENE_ALL_NULL, ), + pytest.param( + [ + nwd.col("b").alias("a"), + nwd.col("l").alias("b"), + nwd.col("m").alias("i"), + nwd.any_horizontal(nwd.sum("b", "l").cast(nw.Boolean)).alias("any"), + nwd.all_horizontal(nwd.sum("b", "l").cast(nw.Boolean)).alias("all"), + nwd.max_horizontal(nwd.sum("b"), nwd.sum("l")).alias("max"), + nwd.min_horizontal(nwd.sum("b"), nwd.sum("l")).alias("min"), + nwd.sum_horizontal(nwd.sum("b"), nwd.sum("l")).alias("sum"), + nwd.mean_horizontal(nwd.sum("b"), nwd.sum("l")).alias("mean"), + ], + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "i": [0, 1, 2], + "any": [True, True, True], + "all": [True, True, True], + "max": [15, 15, 15], + "min": [6, 6, 6], + "sum": [21, 21, 21], + "mean": [10.5, 10.5, 10.5], + }, + id="sumh_broadcasting", + marks=pytest.mark.xfail( + reason="`mean_horizontal` not implemented", raises=NotImplementedError + ), + ), ], ids=_ids_ir, ) From b0655736747550284566c2098e63fd4790438bba Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 15 Jul 2025 18:46:53 +0000 Subject: [PATCH 329/368] feat(pyarrow): Impl `mean_horizontal` --- narwhals/_plan/arrow/namespace.py | 16 +++++++++++++--- tests/plan/compliant_test.py | 12 +++++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 568e310675..b6bfa5b1f2 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from narwhals._arrow.typing import ChunkedArrayAny + from narwhals._arrow.typing import ChunkedArrayAny, Incomplete from narwhals._plan import expr, functions as F # noqa: N812 from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar @@ -140,11 +140,21 @@ def max_horizontal( ) -> ArrowExpr | ArrowScalar: return self._horizontal_function(pc.max_element_wise)(node, frame, name) - # TODO @dangotbanned: Impl `mean_horizontal` def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - raise NotImplementedError + from narwhals._plan.arrow.expr import lit, 
truediv_compat + + # NOTE: Overloads too broken + add: Incomplete = pc.add + sub = pc.subtract + inputs = [self._expr.from_ir(e, frame, name).native for e in node.input] + filled = (pc.fill_null(native, lit(0)) for native in inputs) + not_null = (sub(lit(1), pc.is_null(native).cast(pa.int64())) for native in inputs) + result = truediv_compat(reduce(add, filled), reduce(add, not_null)) + if isinstance(result, pa.Scalar): + return self._scalar.from_native(result, name, self.version) + return self._expr.from_native(result, name, self.version) # TODO @dangotbanned: Impl `concat_str` def concat_str( diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 775b641293..0efb812823 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -32,8 +32,8 @@ def data_small() -> dict[str, Any]: "g": [False, None, False], "h": [None, None, True], "i": [None, None, None], - "j": [12.1, 13.2, 4.0], - "k": [42, 10, 12], + "j": [12.1, None, 4.0], + "k": [42, 10, None], "l": [4, 5, 6], "m": [0, 1, 2], } @@ -289,9 +289,11 @@ def _ids_ir(expr: DummyExpr | Any) -> str: "mean": [10.5, 10.5, 10.5], }, id="sumh_broadcasting", - marks=pytest.mark.xfail( - reason="`mean_horizontal` not implemented", raises=NotImplementedError - ), + ), + pytest.param( + nwd.mean_horizontal("j", nwd.col("k"), "e"), + {"j": [27.05, 9.5, 5.5]}, + id="mean_horizontal-null", ), ], ids=_ids_ir, From 3732e5a6b56411157f13307dfdbd25e397d5b8e6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 15 Jul 2025 18:56:38 +0000 Subject: [PATCH 330/368] fix: fill nulls in `sum_horizontal` --- narwhals/_plan/arrow/namespace.py | 8 ++++++-- tests/plan/compliant_test.py | 5 +++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index b6bfa5b1f2..469fa7a885 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -102,12 +102,16 @@ def lit( # NOTE: Update with `ignore_nulls`/`fill_null` behavior once added to each `Function` # https://github.com/narwhals-dev/narwhals/pull/2719 def _horizontal_function( - self, fn: Callable[[Any, Any], Any], / + self, fn: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], ArrowExpr | ArrowScalar]: + from narwhals._plan.arrow.expr import lit + def func( node: FunctionExpr[Any], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: it = (self._expr.from_ir(e, frame, name).native for e in node.input) + if fill is not None: + it = (pc.fill_null(native, lit(fill)) for native in it) result = reduce(fn, it) if isinstance(result, pa.Scalar): return self._scalar.from_native(result, name, self.version) @@ -128,7 +132,7 @@ def all_horizontal( def sum_horizontal( self, node: FunctionExpr[F.SumHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - return self._horizontal_function(pc.add)(node, frame, name) + return self._horizontal_function(pc.add, fill=0)(node, frame, name) def min_horizontal( self, node: FunctionExpr[F.MinHorizontal], frame: ArrowDataFrame, name: str diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 0efb812823..0cff170fd2 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -295,6 +295,11 @@ def _ids_ir(expr: DummyExpr | Any) -> str: {"j": [27.05, 9.5, 5.5]}, id="mean_horizontal-null", ), + pytest.param( + nwd.sum_horizontal("j", nwd.col("k"), "e"), + {"j": 
[54.1, 19.0, 11.0]}, + id="sum_horizontal-null", + ), ], ids=_ids_ir, ) From 2b44d63a499443f2262022910ff835e097cc0ca9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 15 Jul 2025 20:49:25 +0000 Subject: [PATCH 331/368] test: Add detail to xfail message --- tests/plan/compliant_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 0cff170fd2..53a27aba61 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -46,7 +46,10 @@ def _ids_ir(expr: DummyExpr | Any) -> str: XFAIL_REWRITE_SPECIAL_ALIASES = pytest.mark.xfail( - reason="Bug in `meta` namespace impl", raises=ComputeError + reason="https://github.com/narwhals-dev/narwhals/blob/3732e5a6b56411157f13307dfdbd25e397d5b8e6/narwhals/_plan/meta.py#L142-L162\n" + "Matches behavior of `polars`\n" + "pl.select(pl.lit(1).name.suffix('_suffix'))", + raises=ComputeError, ) XFAIL_KLEENE_ALL_NULL = pytest.mark.xfail( reason="`pyarrow` uses `pa.null()`, which also fails in current `narwhals`.\n" From 17634cd19a8356c61a92134759e05f9e55387fc3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 16 Jul 2025 16:35:29 +0000 Subject: [PATCH 332/368] refactor: Split out `functions`, `typing` --- narwhals/_plan/arrow/dataframe.py | 7 +- narwhals/_plan/arrow/expr.py | 183 ++++++++-------------------- narwhals/_plan/arrow/functions.py | 190 ++++++++++++++++++++++++++++++ narwhals/_plan/arrow/namespace.py | 38 +++--- narwhals/_plan/arrow/typing.py | 103 ++++++++++++++++ 5 files changed, 364 insertions(+), 157 deletions(-) create mode 100644 narwhals/_plan/arrow/functions.py create mode 100644 narwhals/_plan/arrow/typing.py diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 1a2209db0f..d27784b765 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -16,9 +16,9 @@ if t.TYPE_CHECKING: from collections.abc import Iterable, Iterator - from typing_extensions import Self, TypeAlias, TypeIs + from typing_extensions import Self, TypeIs - from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny + from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.options import SortMultipleOptions @@ -28,9 +28,6 @@ from narwhals.schema import Schema -UnaryFn: TypeAlias = "t.Callable[[ChunkedArrayAny], ScalarAny]" - - def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: return isinstance(obj, ArrowSeries) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index c6931f3dea..e822017506 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,37 +1,27 @@ from __future__ import annotations -from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, Protocol, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import -from narwhals._arrow.utils import ( - cast_for_truediv, - chunked_array as _chunked_array, - floordiv_compat, - narwhals_to_native_dtype, -) -from narwhals._plan import boolean, functions as F, operators as ops # noqa: N812 +from narwhals._arrow.utils import narwhals_to_native_dtype +from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow.functions import lit from narwhals._plan.arrow.series import ArrowSeries +from 
narwhals._plan.arrow.typing import NativeScalar, StoresNativeT_co from narwhals._plan.common import ExprIR, into_dtype from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch -from narwhals._typing_compat import TypeVar from narwhals._utils import Implementation, Version, _StoresNative, not_implemented from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable - from typing_extensions import Self, TypeAlias + from typing_extensions import Self - from narwhals._arrow.typing import ( - ArrayAny, - ArrayOrScalar, - ChunkedArrayAny, - Incomplete, - ) - from narwhals._plan import expr + from narwhals._arrow.typing import ChunkedArrayAny, Incomplete + from narwhals._plan import boolean, expr from narwhals._plan.aggregation import ( ArgMax, ArgMin, @@ -52,80 +42,15 @@ from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull from narwhals._plan.expr import BinaryExpr, FunctionExpr, Ternary - from narwhals.typing import ClosedInterval, IntoDType, PythonLiteral + from narwhals._plan.functions import FillNull, Pow + from narwhals.typing import IntoDType, PythonLiteral -NativeScalar: TypeAlias = "pa.Scalar[Any]" -BinOp: TypeAlias = Callable[..., "ChunkedArrayAny | NativeScalar"] -LogicalOp: TypeAlias = Callable[ - ..., "pa.ChunkedArray[pa.BooleanScalar] | pa.BooleanScalar" -] BACKEND_VERSION = Implementation.PYARROW._backend_version() -_StoresNativeAny: TypeAlias = _StoresNative[Any] -_StoresNativeT_co = TypeVar("_StoresNativeT_co", bound=_StoresNativeAny, covariant=True) - - -def truediv_compat(lhs: Any, rhs: Any) -> Any: - return pc.divide(*cast_for_truediv(lhs, rhs)) - - -def modulus(lhs: Any, rhs: Any) -> Any: - floor_div = floordiv_compat(lhs, rhs) - return pc.subtract(lhs, pc.multiply(floor_div, rhs)) - - -def any_(native: Any) -> pa.BooleanScalar: - return pc.any(native, min_count=0) - - -def all_(native: Any) -> pa.BooleanScalar: - return pc.all(native, min_count=0) - - -def lit(value: Any, dtype: pa.DataType | None = None) -> NativeScalar: - # NOTE: PR that fixed these the overloads was closed - # https://github.com/zen-xu/pyarrow-stubs/pull/208 - return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) - - -def array(value: NativeScalar) -> ArrayAny: - return pa.array([value], value.type) - - -def chunked_array( - arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, / -) -> ChunkedArrayAny: - return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) - - -DISPATCH_BINARY: Mapping[type[ops.Operator], BinOp] = { - ops.Eq: pc.equal, - ops.NotEq: pc.not_equal, - ops.Lt: pc.less, - ops.LtEq: pc.less_equal, - ops.Gt: pc.greater, - ops.GtEq: pc.greater_equal, - ops.Add: pc.add, - ops.Sub: pc.subtract, - ops.Multiply: pc.multiply, - ops.TrueDivide: truediv_compat, - ops.FloorDivide: floordiv_compat, - ops.Modulus: modulus, - ops.And: pc.and_kleene, - ops.Or: pc.or_kleene, - ops.ExclusiveOr: pc.xor, -} - -IS_BETWEEN: Mapping[ClosedInterval, tuple[LogicalOp, LogicalOp]] = { - "left": (pc.greater_equal, pc.less), - "right": (pc.greater, pc.less_equal), - "none": (pc.greater, pc.less), - "both": (pc.greater_equal, pc.less_equal), -} class _ArrowDispatch( - ExprDispatch["ArrowDataFrame", _StoresNativeT_co, "ArrowNamespace"], Protocol + ExprDispatch["ArrowDataFrame", StoresNativeT_co, "ArrowNamespace"], Protocol ): """Common to `Expr`, `Scalar` + their 
dependencies.""" @@ -134,25 +59,23 @@ def __narwhals_namespace__(self) -> ArrowNamespace: return ArrowNamespace(self.version) - def _with_native(self, native: Any, name: str, /) -> _StoresNativeT_co: ... - def cast( - self, node: expr.Cast, frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: + def _with_native(self, native: Any, name: str, /) -> StoresNativeT_co: ... + def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> StoresNativeT_co: data_type = narwhals_to_native_dtype(node.dtype, frame.version) native = self._dispatch(node.expr, frame, name).native - return self._with_native(pc.cast(native, data_type), name) + return self._with_native(fn.cast(native, data_type), name) def pow( - self, node: FunctionExpr[F.Pow], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: + self, node: FunctionExpr[Pow], frame: ArrowDataFrame, name: str + ) -> StoresNativeT_co: base, exponent = node.function.unwrap_input(node) base_ = self._dispatch(base, frame, "base").native exponent_ = self._dispatch(exponent, frame, "exponent").native return self._with_native(pc.power(base_, exponent_), name) def fill_null( - self, node: FunctionExpr[F.FillNull], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: + self, node: FunctionExpr[FillNull], frame: ArrowDataFrame, name: str + ) -> StoresNativeT_co: expr, value = node.function.unwrap_input(node) native = self._dispatch(expr, frame, name).native value_ = self._dispatch(value, frame, "value").native @@ -160,59 +83,58 @@ def fill_null( def is_between( self, node: FunctionExpr[IsBetween], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: + ) -> StoresNativeT_co: expr, lower_bound, upper_bound = node.function.unwrap_input(node) native = self._dispatch(expr, frame, name).native lower = self._dispatch(lower_bound, frame, "lower").native upper = self._dispatch(upper_bound, frame, "upper").native - fn_lhs, fn_rhs = IS_BETWEEN[node.function.closed] - result = pc.and_kleene(fn_lhs(native, lower), fn_rhs(native, upper)) + result = fn.is_between(native, lower, upper, node.function.closed) return self._with_native(result, name) def _unary_function( - self, fn: Callable[[Any], Any], / - ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], _StoresNativeT_co]: + self, fn_native: Callable[[Any], Any], / + ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], StoresNativeT_co]: def func( node: FunctionExpr[Any], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: + ) -> StoresNativeT_co: native = self._dispatch(node.input[0], frame, name).native - return self._with_native(fn(native), name) + return self._with_native(fn_native(native), name) return func def not_( self, node: FunctionExpr[boolean.Not], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: + ) -> StoresNativeT_co: return self._unary_function(pc.invert)(node, frame, name) def all( self, node: FunctionExpr[boolean.All], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: - return self._unary_function(all_)(node, frame, name) + ) -> StoresNativeT_co: + return self._unary_function(fn.all_)(node, frame, name) def any( self, node: FunctionExpr[boolean.Any], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: - return self._unary_function(any_)(node, frame, name) + ) -> StoresNativeT_co: + return self._unary_function(fn.any_)(node, frame, name) def is_finite( self, node: FunctionExpr[IsFinite], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: - return self._unary_function(pc.is_finite)(node, frame, name) + ) -> StoresNativeT_co: + return 
self._unary_function(fn.is_finite)(node, frame, name) def is_nan( self, node: FunctionExpr[IsNan], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: - return self._unary_function(pc.is_nan)(node, frame, name) + ) -> StoresNativeT_co: + return self._unary_function(fn.is_nan)(node, frame, name) def is_null( self, node: FunctionExpr[IsNull], frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: - return self._unary_function(pc.is_null)(node, frame, name) + ) -> StoresNativeT_co: + return self._unary_function(fn.is_null)(node, frame, name) def ternary_expr( self, node: Ternary, frame: ArrowDataFrame, name: str - ) -> _StoresNativeT_co: + ) -> StoresNativeT_co: when = self._dispatch(node.predicate, frame, name) then = self._dispatch(node.truthy, frame, name) otherwise = self._dispatch(node.falsy, frame, name) @@ -326,38 +248,36 @@ def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: native = self._dispatch_expr(node.expr, frame, name).native - result = pc.index(native, pc.min(native)) + result = pc.index(native, fn.min_(native)) return self._with_native(result, name) def arg_max(self, node: ArgMax, frame: ArrowDataFrame, name: str) -> ArrowScalar: native = self._dispatch_expr(node.expr, frame, name).native - result: NativeScalar = pc.index(native, pc.max(native)) + result: NativeScalar = pc.index(native, fn.max_(native)) return self._with_native(result, name) def sum(self, node: Sum, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result: NativeScalar = pc.sum( - self._dispatch_expr(node.expr, frame, name).native, min_count=0 - ) + result = fn.sum_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def n_unique(self, node: NUnique, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result = pc.count(self._dispatch_expr(node.expr, frame, name).native, mode="all") + result = fn.n_unique(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def std(self, node: Std, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result = pc.stddev( + result = fn.std( self._dispatch_expr(node.expr, frame, name).native, ddof=node.ddof ) return self._with_native(result, name) def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result = pc.variance( + result = fn.var( self._dispatch_expr(node.expr, frame, name).native, ddof=node.ddof ) return self._with_native(result, name) def quantile(self, node: Quantile, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result = pc.quantile( + result = fn.quantile( self._dispatch_expr(node.expr, frame, name).native, q=node.quantile, interpolation=node.interpolation, @@ -365,23 +285,23 @@ def quantile(self, node: Quantile, frame: ArrowDataFrame, name: str) -> ArrowSca return self._with_native(result, name) def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result = pc.count(self._dispatch_expr(node.expr, frame, name).native) + result = fn.count(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def max(self, node: Max, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result: NativeScalar = pc.max(self._dispatch_expr(node.expr, frame, name).native) + result: NativeScalar = fn.max_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def mean(self, node: Mean, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result = 
pc.mean(self._dispatch_expr(node.expr, frame, name).native) + result = fn.mean(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def median(self, node: Median, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result = pc.approximate_median(self._dispatch_expr(node.expr, frame, name).native) + result = fn.median(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: - result: NativeScalar = pc.min(self._dispatch_expr(node.expr, frame, name).native) + result: NativeScalar = fn.min_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def binary_expr( # type: ignore[override] @@ -391,8 +311,7 @@ def binary_expr( # type: ignore[override] self._dispatch(node.left, frame, name), self._dispatch(node.right, frame, name), ) - fn = DISPATCH_BINARY[node.op.__class__] - result = fn(lhs.native, rhs.native) + result = fn.binary(lhs.native, node.op.__class__, rhs.native) return self._with_native(result, name) @@ -468,12 +387,12 @@ def to_python(self) -> PythonLiteral: def broadcast(self, length: int) -> ArrowSeries: scalar = self.native if length == 1: - chunked = chunked_array(scalar) + chunked = fn.chunked_array(scalar) else: # NOTE: Same issue as `pa.scalar` overlapping overloads # https://github.com/zen-xu/pyarrow-stubs/pull/209 pa_repeat: Incomplete = pa.repeat - chunked = chunked_array(pa_repeat(scalar, length)) + chunked = fn.chunked_array(pa_repeat(scalar, length)) return ArrowSeries.from_native(chunked, self.name, version=self.version) def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py new file mode 100644 index 0000000000..8959abd248 --- /dev/null +++ b/narwhals/_plan/arrow/functions.py @@ -0,0 +1,190 @@ +"""Native functions, aliased and/or with behavior aligned to `polars`.""" + +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING, Any + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._arrow.utils import ( + cast_for_truediv, + chunked_array as _chunked_array, + floordiv_compat as floordiv, +) +from narwhals._plan import operators as ops + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + + from narwhals._arrow.typing import ArrayAny, ArrayOrScalar, ChunkedArrayAny + from narwhals._plan.arrow.typing import ( + BinaryComp, + BinaryLogical, + BinaryNumericTemporal, + BinOp, + ChunkedArray, + ChunkedOrScalar, + ChunkedOrScalarAny, + DataType, + DataTypeT, + NativeScalar, + Scalar, + ScalarAny, + ScalarT, + UnaryFunction, + ) + from narwhals.typing import ClosedInterval + +is_null = pc.is_null +is_not_null = t.cast("UnaryFunction[ScalarAny,pa.BooleanScalar]", pc.is_valid) +is_nan = pc.is_nan +is_finite = pc.is_finite + +and_ = t.cast("BinaryLogical", pc.and_kleene) +or_ = t.cast("BinaryLogical", pc.or_kleene) +xor = t.cast("BinaryLogical", pc.xor) + +eq = t.cast("BinaryComp", pc.equal) +not_eq = t.cast("BinaryComp", pc.not_equal) +gt_eq = t.cast("BinaryComp", pc.greater_equal) +gt = t.cast("BinaryComp", pc.greater) +lt_eq = t.cast("BinaryComp", pc.less_equal) +lt = t.cast("BinaryComp", pc.less) + + +add = t.cast("BinaryNumericTemporal", pc.add) +sub = pc.subtract +multiply = pc.multiply + + +def truediv(lhs: Any, rhs: Any) -> Any: + return pc.divide(*cast_for_truediv(lhs, 
rhs)) + + +def modulus(lhs: Any, rhs: Any) -> Any: + floor_div = floordiv(lhs, rhs) + return sub(lhs, multiply(floor_div, rhs)) + + +_DISPATCH_BINARY: Mapping[type[ops.Operator], BinOp] = { + ops.Eq: eq, + ops.NotEq: not_eq, + ops.Lt: lt, + ops.LtEq: lt_eq, + ops.Gt: gt, + ops.GtEq: gt_eq, + ops.Add: add, + ops.Sub: sub, + ops.Multiply: multiply, + ops.TrueDivide: truediv, + ops.FloorDivide: floordiv, + ops.Modulus: modulus, + ops.And: and_, + ops.Or: or_, + ops.ExclusiveOr: xor, +} + +_IS_BETWEEN: Mapping[ClosedInterval, tuple[BinaryComp, BinaryComp]] = { + "left": (gt_eq, lt), + "right": (gt, lt_eq), + "none": (gt, lt), + "both": (gt_eq, lt_eq), +} + + +@t.overload +def cast( + native: Scalar[Any], target_type: DataTypeT, *, safe: bool | None = ... +) -> Scalar[DataTypeT]: ... + + +@t.overload +def cast( + native: ChunkedArray[Any], target_type: DataTypeT, *, safe: bool | None = ... +) -> ChunkedArray[Scalar[DataTypeT]]: ... + + +@t.overload +def cast( + native: ChunkedOrScalar[Scalar[Any]], + target_type: DataTypeT, + *, + safe: bool | None = ..., +) -> ChunkedArray[Scalar[DataTypeT]] | Scalar[DataTypeT]: ... + + +def cast( + native: ChunkedOrScalar[Scalar[Any]], + target_type: DataTypeT, + *, + safe: bool | None = None, +) -> ChunkedArray[Scalar[DataTypeT]] | Scalar[DataTypeT]: + return pc.cast(native, target_type, safe=safe) + + +def any_(native: Any) -> pa.BooleanScalar: + return pc.any(native, min_count=0) + + +def all_(native: Any) -> pa.BooleanScalar: + return pc.all(native, min_count=0) + + +def sum_(native: Any) -> NativeScalar: + return pc.sum(native, min_count=0) + + +min_ = pc.min +min_horizontal = pc.min_element_wise +max_ = pc.max +max_horizontal = pc.max_element_wise +mean = pc.mean +count = pc.count +median = pc.approximate_median +std = pc.stddev +var = pc.variance +quantile = pc.quantile + + +def n_unique(native: Any) -> pa.Int64Scalar: + return count(native, mode="all") + + +def is_between( + native: ChunkedOrScalar[ScalarT], + lower: ChunkedOrScalar[ScalarT], + upper: ChunkedOrScalar[ScalarT], + closed: ClosedInterval, +) -> ChunkedOrScalar[pa.BooleanScalar]: + fn_lhs, fn_rhs = _IS_BETWEEN[closed] + return and_(fn_lhs(native, lower), fn_rhs(native, upper)) + + +def binary( + lhs: ChunkedOrScalarAny, op: type[ops.Operator], rhs: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + return _DISPATCH_BINARY[op](lhs, rhs) + + +def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: + # NOTE: PR that fixed these the overloads was closed + # https://github.com/zen-xu/pyarrow-stubs/pull/208 + return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) + + +def array( + value: NativeScalar | Iterable[Any], dtype: DataType | None = None, / +) -> ArrayAny: + return ( + pa.array([value], value.type) + if isinstance(value, pa.Scalar) + else pa.array(value, dtype) + ) + + +def chunked_array( + arr: ArrayOrScalar | list[Iterable[Any]], dtype: DataType | None = None, / +) -> ChunkedArrayAny: + return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 469fa7a885..dbe539a728 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -7,6 +7,8 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import narwhals_to_native_dtype +from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow.functions import lit from narwhals._plan.literal import is_literal_scalar from 
narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version @@ -15,7 +17,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from narwhals._arrow.typing import ChunkedArrayAny, Incomplete + from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan import expr, functions as F # noqa: N812 from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar @@ -102,17 +104,15 @@ def lit( # NOTE: Update with `ignore_nulls`/`fill_null` behavior once added to each `Function` # https://github.com/narwhals-dev/narwhals/pull/2719 def _horizontal_function( - self, fn: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None + self, fn_native: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], ArrowExpr | ArrowScalar]: - from narwhals._plan.arrow.expr import lit - def func( node: FunctionExpr[Any], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: it = (self._expr.from_ir(e, frame, name).native for e in node.input) if fill is not None: it = (pc.fill_null(native, lit(fill)) for native in it) - result = reduce(fn, it) + result = reduce(fn_native, it) if isinstance(result, pa.Scalar): return self._scalar.from_native(result, name, self.version) return self._expr.from_native(result, name, self.version) @@ -122,40 +122,40 @@ def func( def any_horizontal( self, node: FunctionExpr[AnyHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - return self._horizontal_function(pc.or_kleene)(node, frame, name) + return self._horizontal_function(fn.or_)(node, frame, name) def all_horizontal( self, node: FunctionExpr[AllHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - return self._horizontal_function(pc.and_kleene)(node, frame, name) + return self._horizontal_function(fn.and_)(node, frame, name) def sum_horizontal( self, node: FunctionExpr[F.SumHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - return self._horizontal_function(pc.add, fill=0)(node, frame, name) + return self._horizontal_function(fn.add, fill=0)(node, frame, name) def min_horizontal( self, node: FunctionExpr[F.MinHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - return self._horizontal_function(pc.min_element_wise)(node, frame, name) + return self._horizontal_function(fn.min_horizontal)(node, frame, name) def max_horizontal( self, node: FunctionExpr[F.MaxHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - return self._horizontal_function(pc.max_element_wise)(node, frame, name) + return self._horizontal_function(fn.max_horizontal)(node, frame, name) def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - from narwhals._plan.arrow.expr import lit, truediv_compat - - # NOTE: Overloads too broken - add: Incomplete = pc.add - sub = pc.subtract + int64 = pa.int64() inputs = [self._expr.from_ir(e, frame, name).native for e in node.input] filled = (pc.fill_null(native, lit(0)) for native in inputs) - not_null = (sub(lit(1), pc.is_null(native).cast(pa.int64())) for native in inputs) - result = truediv_compat(reduce(add, filled), reduce(add, not_null)) + # NOTE: `mypy` doesn't like that `add` is overloaded + sum_not_null = reduce( + fn.add, # type: ignore[arg-type] + (fn.cast(fn.is_not_null(native), int64) for native in inputs), + ) + result = fn.truediv(reduce(fn.add, filled), 
sum_not_null) if isinstance(result, pa.Scalar): return self._scalar.from_native(result, name, self.version) return self._expr.from_native(result, name, self.version) @@ -192,10 +192,8 @@ def int_range( if isinstance(start_, int) and isinstance(end_, int): import numpy as np # ignore-banned-import - from narwhals._plan.arrow.expr import chunked_array - pa_dtype = narwhals_to_native_dtype(dtype, self.version) - native = chunked_array(pa.array(np.arange(start_, end_, step), pa_dtype)) + native = fn.chunked_array(fn.array(np.arange(start_, end_, step), pa_dtype)) return self._expr.from_native(native, name, self.version) else: diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py new file mode 100644 index 0000000000..81d0703f81 --- /dev/null +++ b/narwhals/_plan/arrow/typing.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Protocol, overload + +from narwhals._typing_compat import TypeVar +from narwhals._utils import _StoresNative as StoresNative + +if TYPE_CHECKING: + import pyarrow as pa + import pyarrow.compute as pc + from typing_extensions import TypeAlias + + +ScalarT = TypeVar("ScalarT", bound="pa.Scalar[Any]", default="pa.Scalar[Any]") +ScalarPT_contra = TypeVar( + "ScalarPT_contra", + bound="pa.Scalar[Any]", + default="pa.Scalar[Any]", + contravariant=True, +) +ScalarRT_co = TypeVar( + "ScalarRT_co", bound="pa.Scalar[Any]", default="pa.Scalar[Any]", covariant=True +) +NumericOrTemporalScalar: TypeAlias = "pc.NumericOrTemporalScalar" +NumericOrTemporalScalarT = TypeVar( + "NumericOrTemporalScalarT", + bound=NumericOrTemporalScalar, + default=NumericOrTemporalScalar, +) + + +class UnaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]): + @overload + def __call__(self, data: ScalarPT_contra, *args: Any, **kwds: Any) -> ScalarRT_co: ... + + @overload + def __call__( + self, data: ChunkedArray[ScalarPT_contra], *args: Any, **kwds: Any + ) -> ChunkedArray[ScalarRT_co]: ... + + @overload + def __call__( + self, data: ChunkedOrScalar[ScalarPT_contra], *args: Any, **kwds: Any + ) -> ChunkedOrScalar[ScalarRT_co]: ... + + def __call__( + self, data: ChunkedOrScalar[ScalarPT_contra], *args: Any, **kwds: Any + ) -> ChunkedOrScalar[ScalarRT_co]: ... + + +class BinaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]): + @overload + def __call__(self, x: ScalarPT_contra, y: ScalarPT_contra, /) -> ScalarRT_co: ... + + @overload + def __call__( + self, x: ChunkedArray[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / + ) -> ChunkedArray[ScalarRT_co]: ... + + @overload + def __call__( + self, x: ScalarPT_contra, y: ChunkedArray[ScalarPT_contra], / + ) -> ChunkedArray[ScalarRT_co]: ... + + @overload + def __call__( + self, x: ChunkedArray[ScalarPT_contra], y: ScalarPT_contra, / + ) -> ChunkedArray[ScalarRT_co]: ... + + @overload + def __call__( + self, x: ChunkedOrScalar[ScalarPT_contra], y: ChunkedOrScalar[ScalarPT_contra], / + ) -> ChunkedOrScalar[ScalarRT_co]: ... + + def __call__( + self, x: ChunkedOrScalar[ScalarPT_contra], y: ChunkedOrScalar[ScalarPT_contra], / + ) -> ChunkedOrScalar[ScalarRT_co]: ... + + +class BinaryComp( + BinaryFunction[ScalarPT_contra, "pa.BooleanScalar"], Protocol[ScalarPT_contra] +): ... + + +class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Protocol): ... 
+ + +BinaryNumericTemporal: TypeAlias = BinaryFunction[ + NumericOrTemporalScalarT, NumericOrTemporalScalarT +] +DataType: TypeAlias = "pa.DataType" +DataTypeT = TypeVar("DataTypeT", bound=DataType, default=Any) +DataTypeT_co = TypeVar("DataTypeT_co", bound=DataType, covariant=True, default=Any) +ScalarT_co = TypeVar("ScalarT_co", bound="pa.Scalar[Any]", covariant=True, default=Any) +Scalar: TypeAlias = "pa.Scalar[DataTypeT_co]" +ChunkedArray: TypeAlias = "pa.ChunkedArray[ScalarT_co]" +ChunkedOrScalar: TypeAlias = "ChunkedArray[ScalarT_co] | ScalarT_co" +ScalarAny: TypeAlias = "Scalar[Any]" +ChunkedOrScalarAny: TypeAlias = "ChunkedOrScalar[ScalarAny]" +NativeScalar: TypeAlias = ScalarAny +BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] +StoresNativeT_co = TypeVar("StoresNativeT_co", bound=StoresNative[Any], covariant=True) From 02092b4dfce21def8a89eee820ae64c5405e6f1a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 16 Jul 2025 18:17:53 +0000 Subject: [PATCH 333/368] feat(pyarrow): Impl `concat_str` --- narwhals/_plan/arrow/functions.py | 30 ++++++++++++++++++++++++++++-- narwhals/_plan/arrow/namespace.py | 10 ++++++++-- narwhals/_plan/arrow/typing.py | 3 +++ tests/plan/compliant_test.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 8959abd248..186beaa3a3 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -16,9 +16,14 @@ from narwhals._plan import operators as ops if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Iterable, Iterator, Mapping, Sequence - from narwhals._arrow.typing import ArrayAny, ArrayOrScalar, ChunkedArrayAny + from narwhals._arrow.typing import ( + ArrayAny, + ArrayOrScalar, + ChunkedArrayAny, + Incomplete, + ) from narwhals._plan.arrow.typing import ( BinaryComp, BinaryLogical, @@ -33,6 +38,7 @@ Scalar, ScalarAny, ScalarT, + StringScalar, UnaryFunction, ) from narwhals.typing import ClosedInterval @@ -168,6 +174,26 @@ def binary( return _DISPATCH_BINARY[op](lhs, rhs) +def concat_str( + *arrays: ChunkedArrayAny, separator: str = "", ignore_nulls: bool = False +) -> ChunkedArray[StringScalar]: + fn: Incomplete = pc.binary_join_element_wise + it, sep = _cast_to_comparable_string_types(arrays, separator) + return fn(*it, sep, null_handling="skip" if ignore_nulls else "emit_null") # type: ignore[no-any-return] + + +def _cast_to_comparable_string_types( + arrays: Sequence[ChunkedArrayAny], /, separator: str +) -> tuple[Iterator[ChunkedArray[StringScalar]], StringScalar]: + # Ensure `chunked_arrays` are either all `string` or all `large_string`. 
+ dtype = ( + pa.string() + if not any(pa.types.is_large_string(obj.type) for obj in arrays) + else pa.large_string() + ) + return (obj.cast(dtype) for obj in arrays), pa.scalar(separator, dtype) + + def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: # NOTE: PR that fixed these the overloads was closed # https://github.com/zen-xu/pyarrow-stubs/pull/208 diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index dbe539a728..e81d0f4a2b 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -160,11 +160,17 @@ def mean_horizontal( return self._scalar.from_native(result, name, self.version) return self._expr.from_native(result, name, self.version) - # TODO @dangotbanned: Impl `concat_str` def concat_str( self, node: FunctionExpr[ConcatHorizontal], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: - raise NotImplementedError + exprs = (self._expr.from_ir(e, frame, name) for e in node.input) + aligned = (ser.native for ser in self._expr.align(exprs)) + separator = node.function.separator + ignore_nulls = node.function.ignore_nulls + result = fn.concat_str(*aligned, separator=separator, ignore_nulls=ignore_nulls) + if isinstance(result, pa.Scalar): + return self._scalar.from_native(result, name, self.version) + return self._expr.from_native(result, name, self.version) def int_range( self, node: RangeExpr[IntRange], frame: ArrowDataFrame, name: str diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 81d0703f81..e5a86b11d4 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -9,8 +9,11 @@ if TYPE_CHECKING: import pyarrow as pa import pyarrow.compute as pc + from pyarrow.lib import LargeStringType, StringType from typing_extensions import TypeAlias + StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]" + ScalarT = TypeVar("ScalarT", bound="pa.Scalar[Any]", default="pa.Scalar[Any]") ScalarPT_contra = TypeVar( diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 53a27aba61..3a0305bf13 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -36,6 +36,8 @@ def data_small() -> dict[str, Any]: "k": [42, 10, None], "l": [4, 5, 6], "m": [0, 1, 2], + "n": ["dogs", "cats", None], + "o": ["play", "swim", "walk"], } @@ -303,6 +305,35 @@ def _ids_ir(expr: DummyExpr | Any) -> str: {"j": [54.1, 19.0, 11.0]}, id="sum_horizontal-null", ), + pytest.param( + nwd.concat_str(nwd.col("b") * 2, "n", nwd.col("o"), separator=" "), + {"b": ["2 dogs play", "4 cats swim", None]}, + id="concat_str-preserve_nulls", + ), + pytest.param( + nwd.concat_str( + nwd.col("b") * 2, "n", nwd.col("o"), separator=" ", ignore_nulls=True + ), + {"b": ["2 dogs play", "4 cats swim", "6 walk"]}, + id="concat_str-ignore_nulls", + ), + pytest.param( + nwd.concat_str("a", nwd.lit("a")), + {"a": ["Aa", "Ba", "Aa"]}, + id="concat_str-lit", + ), + pytest.param( + nwd.concat_str( + nwd.lit("a"), + nwd.lit("b"), + nwd.lit("c"), + nwd.lit("d"), + nwd.col("e").last() + 13, + separator="|", + ), + {"literal": ["a|b|c|d|20"]}, + id="concat_str-all-lit", + ), ], ids=_ids_ir, ) From 3db1e65d8ad60b7a3e0d22a6d3818366d4f4a5d3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 16 Jul 2025 20:21:58 +0000 Subject: [PATCH 334/368] prep for complex nodes --- narwhals/_plan/arrow/expr.py | 56 +++++++++++++++++++++++++++++------- tests/plan/compliant_test.py | 4 +-- 2 files changed, 48 insertions(+), 12 
deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index e822017506..008687bc8d 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -41,7 +41,15 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull - from narwhals._plan.expr import BinaryExpr, FunctionExpr, Ternary + from narwhals._plan.expr import ( + AnonymousExpr, + BinaryExpr, + FunctionExpr, + OrderedWindowExpr, + RollingExpr, + Ternary, + WindowExpr, + ) from narwhals._plan.functions import FillNull, Pow from narwhals.typing import IntoDType, PythonLiteral @@ -132,6 +140,16 @@ def is_null( ) -> StoresNativeT_co: return self._unary_function(fn.is_null)(node, frame, name) + def binary_expr( + self, node: BinaryExpr, frame: ArrowDataFrame, name: str + ) -> StoresNativeT_co: + lhs, rhs = ( + self._dispatch(node.left, frame, name), + self._dispatch(node.right, frame, name), + ) + result = fn.binary(lhs.native, node.op.__class__, rhs.native) + return self._with_native(result, name) + def ternary_expr( self, node: Ternary, frame: ArrowDataFrame, name: str ) -> StoresNativeT_co: @@ -148,6 +166,7 @@ class ArrowExpr( # type: ignore[misc] EagerExpr["ArrowDataFrame", ArrowSeries], ): _evaluated: ArrowSeries + _version: Version @property def name(self) -> str: @@ -304,15 +323,26 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: result: NativeScalar = fn.min_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def binary_expr( # type: ignore[override] - self, node: BinaryExpr, frame: ArrowDataFrame, name: str - ) -> ArrowScalar | Self: - lhs, rhs = ( - self._dispatch(node.left, frame, name), - self._dispatch(node.right, frame, name), - ) - result = fn.binary(lhs.native, node.op.__class__, rhs.native) - return self._with_native(result, name) + # TODO @dangotbanned: top-level, complex-ish nodes + # - All are fairly complex + # - `over`/`_ordered` (with partitions) requires `group_by`, `join` + # - `over_ordered` alone should be possible w/ the current API + # - `map_batches` is defined in `EagerExpr`, might be simpler here than on main + # - `rolling_expr` has 4 variants + + def over(self, node: WindowExpr, frame: ArrowDataFrame, name: str) -> Self: + raise NotImplementedError + + def over_ordered( + self, node: OrderedWindowExpr, frame: ArrowDataFrame, name: str + ) -> Self: + raise NotImplementedError + + def map_batches(self, node: AnonymousExpr, frame: ArrowDataFrame, name: str) -> Self: + raise NotImplementedError + + def rolling_expr(self, node: RollingExpr, frame: ArrowDataFrame, name: str) -> Self: + raise NotImplementedError class ArrowScalar( @@ -321,6 +351,8 @@ class ArrowScalar( EagerScalar["ArrowDataFrame", ArrowSeries], ): _evaluated: NativeScalar + _version: Version + _name: str @classmethod def from_native( @@ -415,3 +447,7 @@ def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: return self._with_native(pa.scalar(1 if native.is_valid else 0), name) filter = not_implemented() + over = not_implemented() + over_ordered = not_implemented() + map_batches = not_implemented() + rolling_expr = not_implemented() diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 3a0305bf13..22c7db7f1e 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -361,7 +361,7 @@ def test_protocol_expr() -> None: 
pytest.importorskip("pyarrow") from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar - expr = ArrowExpr() # type: ignore[abstract] - scalar = ArrowScalar() # type: ignore[abstract] + expr = ArrowExpr() + scalar = ArrowScalar() assert expr assert scalar From a5f110f241f536eb5bdd7487e70ac95c75240776 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 16 Jul 2025 22:39:39 +0000 Subject: [PATCH 335/368] feat(pyarrow): Impl `map_batches` --- narwhals/_plan/arrow/dataframe.py | 9 +++----- narwhals/_plan/arrow/expr.py | 25 ++++++++++++++++------- narwhals/_plan/arrow/functions.py | 9 ++++++++ narwhals/_plan/arrow/series.py | 34 ++++++++++++++++++++++++++++++- narwhals/_plan/protocols.py | 18 +++++++++++++++- 5 files changed, 80 insertions(+), 15 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index d27784b765..d73232df98 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -7,6 +7,7 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype +from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR from narwhals._plan.dummy import DummyFrame @@ -16,7 +17,7 @@ if t.TYPE_CHECKING: from collections.abc import Iterable, Iterator - from typing_extensions import Self, TypeIs + from typing_extensions import Self from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan.arrow.namespace import ArrowNamespace @@ -28,10 +29,6 @@ from narwhals.schema import Schema -def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: - return isinstance(obj, ArrowSeries) - - class ArrowDataFrame(DummyCompliantFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace import ArrowNamespace @@ -60,7 +57,7 @@ def to_narwhals(self) -> DummyFrame[pa.Table, ChunkedArrayAny]: def from_series( cls, series: t.Iterable[ArrowSeries] | ArrowSeries, *more_series: ArrowSeries ) -> Self: - lhs = (series,) if is_series(series) else series + lhs = (series,) if fn.is_series(series) else series it = chain(lhs, more_series) if more_series else lhs return cls.from_dict({s.name: s.native for s in it}) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 008687bc8d..6d05d487ca 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -51,7 +51,7 @@ WindowExpr, ) from narwhals._plan.functions import FillNull, Pow - from narwhals.typing import IntoDType, PythonLiteral + from narwhals.typing import Into1DArray, IntoDType, PythonLiteral BACKEND_VERSION = Implementation.PYARROW._backend_version() @@ -324,11 +324,10 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: return self._with_native(result, name) # TODO @dangotbanned: top-level, complex-ish nodes - # - All are fairly complex - # - `over`/`_ordered` (with partitions) requires `group_by`, `join` - # - `over_ordered` alone should be possible w/ the current API - # - `map_batches` is defined in `EagerExpr`, might be simpler here than on main - # - `rolling_expr` has 4 variants + # - [ ] `over`/`_ordered` (with partitions) requires `group_by`, `join` + # - [ ] `over_ordered` alone should be possible w/ the current API + # - [x] `map_batches` is defined in `EagerExpr`, might be simpler here than on main + # - [ ] `rolling_expr` has 4 variants def over(self, 
node: WindowExpr, frame: ArrowDataFrame, name: str) -> Self: raise NotImplementedError @@ -338,8 +337,20 @@ def over_ordered( ) -> Self: raise NotImplementedError + # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` def map_batches(self, node: AnonymousExpr, frame: ArrowDataFrame, name: str) -> Self: - raise NotImplementedError + if node.is_scalar: + # NOTE: Just trying to avoid redoing the whole API for `ArrowSeries` + msg = "Only elementwise is currently supported" + raise NotImplementedError(msg) + series = self._dispatch_expr(node.input[0], frame, name) + udf = node.function.function + result: ArrowSeries | Into1DArray = udf(series) + if not fn.is_series(result): + result = ArrowSeries.from_numpy(result, name, version=self.version) + if dtype := node.function.return_dtype: + result = result.cast(dtype) + return self.from_series(result) def rolling_expr(self, node: RollingExpr, frame: ArrowDataFrame, name: str) -> Self: raise NotImplementedError diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 186beaa3a3..6381877d0f 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -18,12 +18,15 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, Sequence + from typing_extensions import TypeIs + from narwhals._arrow.typing import ( ArrayAny, ArrayOrScalar, ChunkedArrayAny, Incomplete, ) + from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.arrow.typing import ( BinaryComp, BinaryLogical, @@ -214,3 +217,9 @@ def chunked_array( arr: ArrayOrScalar | list[Iterable[Any]], dtype: DataType | None = None, / ) -> ChunkedArrayAny: return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) + + +def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: + from narwhals._plan.arrow.series import ArrowSeries + + return isinstance(obj, ArrowSeries) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 0b4abf0787..2870ddf0b4 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -2,12 +2,20 @@ from typing import TYPE_CHECKING, Any -from narwhals._arrow.utils import native_to_narwhals_dtype +from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype +from narwhals._plan.arrow import functions as fn from narwhals._plan.protocols import DummyCompliantSeries +from narwhals._utils import Version +from narwhals.dependencies import is_numpy_array_1d if TYPE_CHECKING: + from collections.abc import Iterable + + from typing_extensions import Self + from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401 from narwhals.dtypes import DType + from narwhals.typing import Into1DArray, IntoDType class ArrowSeries(DummyCompliantSeries["ChunkedArrayAny"]): @@ -20,3 +28,27 @@ def __len__(self) -> int: @property def dtype(self) -> DType: return native_to_narwhals_dtype(self.native.type, self._version) + + @classmethod + def from_numpy( + cls, data: Into1DArray, name: str = "", /, *, version: Version = Version.MAIN + ) -> Self: + return cls.from_iterable( + data if is_numpy_array_1d(data) else [data], name=name, version=version + ) + + @classmethod + def from_iterable( + cls, + data: Iterable[Any], + *, + version: Version, + name: str = "", + dtype: IntoDType | None = None, + ) -> Self: + dtype_pa = narwhals_to_native_dtype(dtype, version) if dtype else None + return cls.from_native(fn.chunked_array([data], dtype_pa), name, version=version) + + def cast(self, dtype: IntoDType) -> Self: + 
dtype_pa = narwhals_to_native_dtype(dtype, self.version) + return self._with_native(fn.cast(self.native, dtype_pa)) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 3d860c5bf3..ed4dadcaac 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -25,7 +25,7 @@ from narwhals._plan.schema import FrozenSchema from narwhals.dtypes import DType from narwhals.schema import Schema - from narwhals.typing import IntoDType, NonNestedLiteral, PythonLiteral + from narwhals.typing import Into1DArray, IntoDType, NonNestedLiteral, PythonLiteral T = TypeVar("T") R_co = TypeVar("R_co", covariant=True) @@ -696,12 +696,28 @@ def from_native( obj._version = version return obj + @classmethod + def from_numpy( + cls, data: Into1DArray, name: str = "", /, *, version: Version = Version.MAIN + ) -> Self: ... + + @classmethod + def from_iterable( + cls, + data: Iterable[Any], + *, + version: Version, + name: str = "", + dtype: IntoDType | None = None, + ) -> Self: ... + def _with_native(self, native: NativeSeriesT) -> Self: return self.from_native(native, self.name, version=self.version) def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) + def cast(self, dtype: IntoDType) -> Self: ... def __len__(self) -> int: return len(self.native) From 747a5ae2a5d3d87f212627aa65b80883e0f674aa Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 17 Jul 2025 12:49:04 +0000 Subject: [PATCH 336/368] test: Add `map_batches` tests --- narwhals/_plan/arrow/series.py | 5 +++- narwhals/_plan/protocols.py | 9 ++++++- tests/plan/compliant_test.py | 45 ++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 2870ddf0b4..f622e3ed84 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -15,13 +15,16 @@ from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401 from narwhals.dtypes import DType - from narwhals.typing import Into1DArray, IntoDType + from narwhals.typing import Into1DArray, IntoDType, _1DArray class ArrowSeries(DummyCompliantSeries["ChunkedArrayAny"]): def to_list(self) -> list[Any]: return self.native.to_pylist() + def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: + return self.native.to_numpy() + def __len__(self) -> int: return self.native.length() diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index ed4dadcaac..c5ce58bba3 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -25,7 +25,13 @@ from narwhals._plan.schema import FrozenSchema from narwhals.dtypes import DType from narwhals.schema import Schema - from narwhals.typing import Into1DArray, IntoDType, NonNestedLiteral, PythonLiteral + from narwhals.typing import ( + Into1DArray, + IntoDType, + NonNestedLiteral, + PythonLiteral, + _1DArray, + ) T = TypeVar("T") R_co = TypeVar("R_co", covariant=True) @@ -722,3 +728,4 @@ def __len__(self) -> int: return len(self.native) def to_list(self) -> list[Any]: ... + def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... 
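For context, the protocol methods above exist to support `Expr.map_batches`. The sketch below is illustrative only: it assumes the `narwhals._plan.arrow` modules from this series are importable and that `ArrowSeries.from_numpy` and `cast` behave as defined in the previous hunks.

from collections.abc import Callable
from typing import Any

from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow.series import ArrowSeries
from narwhals._utils import Version


def apply_udf(
    series: ArrowSeries, udf: Callable[[ArrowSeries], Any], return_dtype: Any = None
) -> ArrowSeries:
    # The UDF may return an ArrowSeries or a 1-D numpy array.
    result = udf(series)
    if not fn.is_series(result):
        # Re-wrap numpy output as a compliant series under the same name.
        result = ArrowSeries.from_numpy(result, series.name, version=Version.MAIN)
    # An explicit return_dtype triggers a cast, mirroring `map_batches(..., return_dtype)`.
    return result.cast(return_dtype) if return_dtype is not None else result

The tests added below exercise this path with series-returning, numpy-returning, and selector-based UDFs; the `returns_scalar=True` case remains marked xfail.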
diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 22c7db7f1e..c095b0023a 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -6,11 +6,14 @@ import pytest pytest.importorskip("pyarrow") +pytest.importorskip("numpy") +import numpy as np import pyarrow as pa import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.common import is_expr +from narwhals._utils import Version from narwhals.exceptions import ComputeError from tests.utils import assert_equal_data @@ -334,6 +337,48 @@ def _ids_ir(expr: DummyExpr | Any) -> str: {"literal": ["a|b|c|d|20"]}, id="concat_str-all-lit", ), + pytest.param( + [ + nwd.col("a") + .alias("...") + .map_batches( + lambda s: s.from_iterable( + [*((len(s) - 1) * [type(s.dtype).__name__.lower()]), "last"], + version=Version.MAIN, + name="funky", + ), + is_elementwise=True, + ), + nwd.col("a"), + ], + {"funky": ["string", "string", "last"], "a": ["A", "B", "A"]}, + id="map_batches-series", + ), + pytest.param( + nwd.col("b") + .map_batches(lambda s: s.to_numpy() + 1, nw.Float64(), is_elementwise=True) + .sum(), + {"b": [9.0]}, + id="map_batches-numpy", + ), + pytest.param( + ndcs.by_name("b", "c", "d") + .map_batches(lambda s: np.append(s.to_numpy(), [10, 2]), is_elementwise=True) + .sort(), + {"b": [1, 2, 2, 3, 10], "c": [2, 2, 4, 9, 10], "d": [2, 7, 8, 8, 10]}, + id="map_batches-selector", + ), + pytest.param( + nwd.col("j", "k") + .fill_null(15) + .map_batches(lambda s: (s.to_numpy().max()), returns_scalar=True), + {"j": [15], "k": [42]}, + id="map_batches-return_scalar", + marks=pytest.mark.xfail( + reason="not implemented `map_batches(returns_scalar=True)` for `pyarrow`", + raises=NotImplementedError, + ), + ), ], ids=_ids_ir, ) From e85af02d1258f3725da20516ce8698f86f29afc0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 17 Jul 2025 14:25:50 +0000 Subject: [PATCH 337/368] refactor(pyarrow): Split out `int_range` Needing it for `over_ordered`, `with_row_index` --- narwhals/_plan/arrow/functions.py | 18 ++++++++++++++++++ narwhals/_plan/arrow/namespace.py | 6 +++--- narwhals/_plan/arrow/typing.py | 15 ++++++++++++++- 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 6381877d0f..8df52d806a 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -37,6 +37,8 @@ ChunkedOrScalarAny, DataType, DataTypeT, + IntegerScalar, + IntegerType, NativeScalar, Scalar, ScalarAny, @@ -197,6 +199,22 @@ def _cast_to_comparable_string_types( return (obj.cast(dtype) for obj in arrays), pa.scalar(separator, dtype) +def int_range( + start: int = 0, + end: int | None = None, + step: int = 1, + /, + *, + dtype: IntegerType = pa.int64(), # noqa: B008 +) -> ChunkedArray[IntegerScalar]: + import numpy as np # ignore-banned-import + + if end is None: + end = start + start = 0 + return pa.chunked_array([pa.array(np.arange(start, end, step), dtype)]) + + def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: # NOTE: PR that fixed these the overloads was closed # https://github.com/zen-xu/pyarrow-stubs/pull/208 diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index e81d0f4a2b..adf9fd8165 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -196,10 +196,10 @@ def int_range( ) raise InvalidOperationError(msg) if isinstance(start_, int) and 
isinstance(end_, int): - import numpy as np # ignore-banned-import - pa_dtype = narwhals_to_native_dtype(dtype, self.version) - native = fn.chunked_array(fn.array(np.arange(start_, end_, step), pa_dtype)) + if not pa.types.is_integer(pa_dtype): + raise TypeError(pa_dtype) + native = fn.int_range(start_, end_, step, dtype=pa_dtype) return self._expr.from_native(native, name, self.version) else: diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index e5a86b11d4..e633e6560e 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -9,10 +9,23 @@ if TYPE_CHECKING: import pyarrow as pa import pyarrow.compute as pc - from pyarrow.lib import LargeStringType, StringType + from pyarrow.lib import ( + Int8Type, + Int16Type, + Int32Type, + Int64Type, + LargeStringType, + StringType, + Uint8Type, + Uint16Type, + Uint32Type, + Uint64Type, + ) from typing_extensions import TypeAlias StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]" + IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type" + IntegerScalar: TypeAlias = "Scalar[IntegerType]" ScalarT = TypeVar("ScalarT", bound="pa.Scalar[Any]", default="pa.Scalar[Any]") From 088a48a28704614c3abd3b16c7b9b15388ecad76 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 17 Jul 2025 15:51:01 +0000 Subject: [PATCH 338/368] feat(DRAFT): Add `with_columns` Still have a bug somewhere lol --- narwhals/_plan/dummy.py | 15 ++++++- narwhals/_plan/protocols.py | 3 ++ narwhals/_plan/schema.py | 18 ++++++++- tests/plan/compliant_test.py | 78 ++++++++++++++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 4bad96a7ee..1ae5fd159d 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -28,7 +28,7 @@ from narwhals._plan.selectors import by_name from narwhals._plan.typing import NativeFrameT, NativeSeriesT from narwhals._plan.window import Over -from narwhals._utils import Version +from narwhals._utils import Version, generate_repr from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.schema import Schema @@ -851,6 +851,9 @@ def schema(self) -> Schema: def columns(self) -> list[str]: return self._compliant.columns + def __repr__(self) -> str: # pragma: no cover + return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) + # NOTE: Gave up on trying to get typing working for now @classmethod def from_native( @@ -920,6 +923,16 @@ def select(self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any) -> ) return self._from_compliant(self._compliant.select(named_irs, schema_projected)) + def with_columns( + self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any + ) -> Self: + named_irs, schema_projected = self._project( + exprs, named_exprs, ExprContext.WITH_COLUMNS + ) + return self._from_compliant( + self._compliant.with_columns(named_irs, schema_projected) + ) + def sort( self, by: str | Iterable[str], diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index c5ce58bba3..795c611198 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -661,6 +661,9 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[SeriesT def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: return 
self.from_series(self._evaluate_irs(irs)) + def with_columns(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: + return self.from_series(self._evaluate_irs(irs)) + def sort( self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema ) -> Self: ... diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 30246993a8..c42b25d321 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -1,9 +1,11 @@ from __future__ import annotations +from collections import deque from collections.abc import Mapping from functools import lru_cache +from itertools import chain, repeat from types import MappingProxyType -from typing import TYPE_CHECKING, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeVar, overload from narwhals._plan.common import _IMMUTABLE_HASH_NAME, Immutable, NamedIR from narwhals.dtypes import Unknown @@ -44,7 +46,7 @@ def project( if context.is_select(): return exprs, self._select(exprs) if context.is_with_columns(): - raise NotImplementedError(context) + return self._with_columns(exprs) raise TypeError(context) def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: @@ -61,6 +63,18 @@ def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: default = Unknown() return freeze_schema((name, self.get(name, default)) for name in names) + def _with_columns(self, exprs: Seq[NamedIR]) -> tuple[Seq[NamedIR], FrozenSchema]: + exprs_out = deque[NamedIR]() + named: dict[str, NamedIR[Any]] = {e.name: e for e in exprs} + items: IntoFrozenSchema + for name in self: + exprs_out.append(named.pop(name, NamedIR.from_name(name))) + if named: + items = chain(self.items(), zip(named, repeat(Unknown(), len(named)))) + else: + items = self + return tuple(exprs_out), freeze_schema(items) + @property def __immutable_hash__(self) -> int: if hasattr(self, _IMMUTABLE_HASH_NAME): diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index c095b0023a..48c7037050 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -44,6 +44,13 @@ def data_small() -> dict[str, Any]: } +@pytest.fixture +def data_smaller(data_small: dict[str, Any]) -> dict[str, Any]: + """Use only columns `"a"-"f"`.""" + keep = {"a", "b", "c", "d", "e", "f"} + return {k: v for k, v in data_small.items() if k in keep} + + def _ids_ir(expr: DummyExpr | Any) -> str: if is_expr(expr): return repr(expr._ir) @@ -395,6 +402,77 @@ def test_select( assert_equal_data(result, expected) +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + ["d", nwd.col("a"), "b", nwd.col("e")], + { + "a": ["A", "B", "A"], + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [None, 9, 7], + "f": [True, False, None], + }, + ), + ( + ndcs.numeric().cast(nw.String), + { + "a": ["A", "B", "A"], + "b": ["1", "2", "3"], + "c": ["9", "2", "4"], + "d": ["8", "7", "8"], + "e": [None, "9", "7"], + "f": [True, False, None], + }, + ), + ( + [ + nwd.col("e").fill_null(nwd.col("e").last()), + nwd.col("f").sort(), + nwd.nth(1).max(), + ], + { + "a": ["A", "B", "A"], + "b": [3, 3, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [7, 9, 7], + "f": [None, False, True], + }, + ), + pytest.param( + [nwd.col("a").alias("a?")], + { + "a": ["A", "B", "A"], + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [None, 9, 7], + "f": [True, False, None], + "a?": ["A", "B", "A"], + }, + id="with_columns-extend", + marks=pytest.mark.xfail( + reason="Non-replacing exprs are being silently dropped?" 
+ ), + ), + ], +) +def test_with_columns( + expr: DummyExpr | Sequence[DummyExpr], + expected: dict[str, Any], + data_smaller: dict[str, Any], +) -> None: + from narwhals._plan.dummy import DummyFrame + + frame = pa.table(data_smaller) + df = DummyFrame.from_native(frame) + result = df.with_columns(expr).to_dict(as_series=False) + assert_equal_data(result, expected) + + if TYPE_CHECKING: def test_protocol_expr() -> None: From cf2f96cb6c2ff50e87bc8059cd8ad30cf37af476 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Jul 2025 13:18:39 +0000 Subject: [PATCH 339/368] fix: Extend exprs in `Schema._with_columns` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I guess I resolved this one as well #1868 🥳 --- narwhals/_plan/schema.py | 1 + tests/plan/compliant_test.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index c42b25d321..17b8416285 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -71,6 +71,7 @@ def _with_columns(self, exprs: Seq[NamedIR]) -> tuple[Seq[NamedIR], FrozenSchema exprs_out.append(named.pop(name, NamedIR.from_name(name))) if named: items = chain(self.items(), zip(named, repeat(Unknown(), len(named)))) + exprs_out.extend(named.values()) else: items = self return tuple(exprs_out), freeze_schema(items) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 48c7037050..2ed3c707a7 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -443,20 +443,25 @@ def test_select( }, ), pytest.param( - [nwd.col("a").alias("a?")], + [ + nwd.col("a").alias("a?"), + ndcs.by_name("a"), + nwd.col("b").cast(nw.Float64).name.suffix("_float"), + nwd.col("c").max() + 1, + nwd.sum_horizontal(1, "d", nwd.col("b"), nwd.lit(3)), + ], { "a": ["A", "B", "A"], "b": [1, 2, 3], - "c": [9, 2, 4], + "c": [10, 10, 10], "d": [8, 7, 8], "e": [None, 9, 7], "f": [True, False, None], "a?": ["A", "B", "A"], + "b_float": [1.0, 2.0, 3.0], + "literal": [13, 13, 15], }, id="with_columns-extend", - marks=pytest.mark.xfail( - reason="Non-replacing exprs are being silently dropped?" 
- ), ), ], ) From cd6ade0e04c4c70d781716b4c72900e87b6df950 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Jul 2025 16:30:10 +0000 Subject: [PATCH 340/368] wip: add `concat` Planning to drop `CompliantDataFrame.from_series` --- narwhals/_plan/arrow/functions.py | 28 ++++++++++ narwhals/_plan/arrow/namespace.py | 79 +++++++++++++++++++++++++-- narwhals/_plan/common.py | 4 ++ narwhals/_plan/protocols.py | 89 ++++++++++++++++++++++++++----- 4 files changed, 183 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 8df52d806a..095aeba780 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -14,12 +14,14 @@ floordiv_compat as floordiv, ) from narwhals._plan import operators as ops +from narwhals._utils import Implementation if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, Sequence from typing_extensions import TypeIs + from narwhals._arrow.dataframe import PromoteOptions from narwhals._arrow.typing import ( ArrayAny, ArrayOrScalar, @@ -48,6 +50,8 @@ ) from narwhals.typing import ClosedInterval +BACKEND_VERSION = Implementation.PYARROW._backend_version() + is_null = pc.is_null is_not_null = t.cast("UnaryFunction[ScalarAny,pa.BooleanScalar]", pc.is_valid) is_nan = pc.is_nan @@ -237,6 +241,30 @@ def chunked_array( return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) +def concat_vertical_chunked( + arrays: Iterable[ChunkedArrayAny], dtype: DataType | None = None, / +) -> ChunkedArrayAny: + # NOTE: Overloads are broken, this is legit + v_concat: Incomplete = pa.chunked_array + return v_concat(arrays, dtype) # type: ignore[no-any-return] + + +def concat_vertical_table( + tables: Iterable[pa.Table], /, promote_options: PromoteOptions = "none" +) -> pa.Table: + return pa.concat_tables(tables, promote_options=promote_options) + + +if BACKEND_VERSION >= (14,): + + def concat_diagonal(tables: Iterable[pa.Table]) -> pa.Table: + return pa.concat_tables(tables, promote_options="default") +else: + + def concat_diagonal(tables: Iterable[pa.Table]) -> pa.Table: + return pa.concat_tables(tables, promote=True) + + def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: from narwhals._plan.arrow.series import ArrowSeries diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index adf9fd8165..7a00edf51f 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import reduce -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -9,13 +9,14 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.functions import lit +from narwhals._plan.common import collect, is_tuple_of from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Iterable, Iterator, Sequence from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan import expr, functions as F # noqa: N812 @@ -27,7 +28,7 @@ from narwhals._plan.expr import 
FunctionExpr, RangeExpr from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatHorizontal - from narwhals.typing import NonNestedLiteral, PythonLiteral + from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral class ArrowNamespace( @@ -208,3 +209,75 @@ def int_range( f"{start_!r}\n{end_!r}" ) raise InvalidOperationError(msg) + + @overload + def concat( + self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod + ) -> ArrowDataFrame: ... + + @overload + def concat( + self, items: Iterable[ArrowSeries], *, how: Literal["vertical"] + ) -> ArrowSeries: ... + + def concat( + self, + items: Iterable[ArrowDataFrame] | Iterable[ArrowSeries], + *, + how: ConcatMethod, + ) -> ArrowDataFrame | ArrowSeries: + if how == "vertical": + return self._concat_vertical(items) + if how == "horizontal": + return self._concat_horizontal(items) + it = iter(items) + first = next(it) + if self._is_series(first): + raise TypeError(first) + dfs = cast("Sequence[ArrowDataFrame]", (first, *it)) + return self._concat_diagonal(dfs) + + def _concat_diagonal(self, items: Iterable[ArrowDataFrame]) -> ArrowDataFrame: + return self._dataframe.from_native( + fn.concat_vertical_table(df.native for df in items), self.version + ) + + def _concat_horizontal( + self, items: Iterable[ArrowDataFrame] | Iterable[ArrowSeries] + ) -> ArrowDataFrame: + def gen( + objs: Iterable[ArrowDataFrame | ArrowSeries], + ) -> Iterator[tuple[ChunkedArrayAny, str]]: + for item in objs: + if self._is_series(item): + yield item.native, item.name + else: + yield from zip(item.native.itercolumns(), item.columns) + + arrays, names = zip(*gen(items)) + native = pa.Table.from_arrays(arrays, list(names)) + return self._dataframe.from_native(native, self.version) + + def _concat_vertical( + self, items: Iterable[ArrowDataFrame] | Iterable[ArrowSeries] + ) -> ArrowDataFrame | ArrowSeries: + collected = collect(items) + if is_tuple_of(collected, self._series): + sers = collected + chunked = fn.concat_vertical_chunked(ser.native for ser in sers) + return sers[0]._with_native(chunked) + if is_tuple_of(collected, self._dataframe): + dfs = collected + cols_0 = dfs[0].columns + for i, df in enumerate(dfs[1:], start=1): + cols_current = df.columns + if cols_current != cols_0: + msg = ( + "unable to vstack, column names don't match:\n" + f" - dataframe 0: {cols_0}\n" + f" - dataframe {i}: {cols_current}\n" + ) + raise TypeError(msg) + return df._with_native(fn.concat_vertical_table(df.native for df in dfs)) + else: + raise TypeError(items) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 495a4e18ae..e3b0e3a780 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -503,6 +503,10 @@ def is_horizontal_reduction(obj: FunctionExpr[Any] | Any) -> TypeIs[FunctionExpr return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() +def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: + return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) + + def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 795c611198..214ca0cbec 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -16,7 +16,7 @@ from narwhals._utils import Version, _hasattr_static if TYPE_CHECKING: - from typing_extensions import Self, TypeAlias + from typing_extensions import 
Self, TypeAlias, TypeIs from narwhals._plan.dummy import DummyFrame, DummySeries from narwhals._plan.expr import FunctionExpr, RangeExpr @@ -26,6 +26,7 @@ from narwhals.dtypes import DType from narwhals.schema import Schema from narwhals.typing import ( + ConcatMethod, Into1DArray, IntoDType, NonNestedLiteral, @@ -39,11 +40,20 @@ LengthT = TypeVar("LengthT") NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) +ConcatT1 = TypeVar("ConcatT1") +ConcatT2 = TypeVar("ConcatT2", default=ConcatT1) + ExprAny: TypeAlias = "CompliantExpr[Any, Any]" ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" SeriesAny: TypeAlias = "DummyCompliantSeries[Any]" FrameAny: TypeAlias = "DummyCompliantFrame[Any, Any, Any]" -NamespaceAny: TypeAlias = "CompliantNamespace[Any, Any, Any, Any]" +NamespaceAny: TypeAlias = "CompliantNamespace[Any, Any, Any]" + +EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" +EagerScalarAny: TypeAlias = "EagerScalar[Any, Any]" + +LazyExprAny: TypeAlias = "LazyExpr[Any, Any, Any]" +LazyScalarAny: TypeAlias = "LazyScalar[Any, Any, Any]" ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) ScalarT = TypeVar("ScalarT", bound=ScalarAny) @@ -54,10 +64,11 @@ FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) -EagerExprT_co = TypeVar("EagerExprT_co", bound="EagerExpr[Any, Any]", covariant=True) -EagerScalarT_co = TypeVar( - "EagerScalarT_co", bound="EagerScalar[Any, Any]", covariant=True -) +EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) +EagerScalarT_co = TypeVar("EagerScalarT_co", bound=EagerScalarAny, covariant=True) + +LazyExprT_co = TypeVar("LazyExprT_co", bound=LazyExprAny, covariant=True) +LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True) # NOTE: Unlike the version in `nw._utils`, here `.version` it is public @@ -513,9 +524,34 @@ class LazyScalar( ): ... -class CompliantNamespace( - StoresVersion, Protocol[FrameT, SeriesT_co, ExprT_co, ScalarT_co] -): +# NOTE: `mypy` is wrong +# error: Invariant type variable "ConcatT2" used in protocol where covariant one is expected [misc] +class Concat(Protocol[ConcatT1, ConcatT2]): # type: ignore[misc] + @overload + def concat(self, items: Iterable[ConcatT1], *, how: ConcatMethod) -> ConcatT1: ... + # Series only supports vertical publicly (like in polars) + @overload + def concat( + self, items: Iterable[ConcatT2], *, how: Literal["vertical"] + ) -> ConcatT2: ... + def concat( + self, items: Iterable[ConcatT1] | Iterable[ConcatT2], *, how: ConcatMethod + ) -> ConcatT1 | ConcatT2: ... + + +class EagerConcat(Concat[ConcatT1, ConcatT2], Protocol[ConcatT1, ConcatT2]): # type: ignore[misc] + def _concat_diagonal(self, items: Iterable[ConcatT1], /) -> ConcatT1: ... + # Series can be used here to go from [Series, Series] -> DataFrame + # but that is only available privately + def _concat_horizontal( + self, items: Iterable[ConcatT1] | Iterable[ConcatT2], / + ) -> ConcatT1: ... + def _concat_vertical( + self, items: Iterable[ConcatT1] | Iterable[ConcatT2], / + ) -> ConcatT1 | ConcatT2: ... + + +class CompliantNamespace(StoresVersion, Protocol[FrameT, ExprT_co, ScalarT_co]): """Need to hold `Expr` and `Scalar` types outside of their defs. Likely, re-wrapping the output types will work like: @@ -531,9 +567,7 @@ class CompliantNamespace( """ @property - def _dataframe(self) -> type[FrameT]: ... - @property - def _series(self) -> type[SeriesT_co]: ... 
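+    # NOTE: Generic "any frame" hook; eager namespaces resolve this to
+    # `_dataframe`, lazy ones to `_lazyframe` (see the subclasses below)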
+ def _frame(self) -> type[FrameT]: ... @property def _expr(self) -> type[ExprT_co]: ... @property @@ -570,9 +604,24 @@ def int_range( class EagerNamespace( - CompliantNamespace[FrameT, SeriesT_co, EagerExprT_co, EagerScalarT_co], - Protocol[FrameT, SeriesT_co, EagerExprT_co, EagerScalarT_co], + EagerConcat[FrameT, SeriesT], + CompliantNamespace[FrameT, EagerExprT_co, EagerScalarT_co], + Protocol[FrameT, SeriesT, EagerExprT_co, EagerScalarT_co], ): + @property + def _series(self) -> type[SeriesT]: ... + @property + def _dataframe(self) -> type[FrameT]: ... + @property + def _frame(self) -> type[FrameT]: + return self._dataframe + + def _is_series(self, obj: Any) -> TypeIs[SeriesT]: + return isinstance(obj, self._series) + + def _is_dataframe(self, obj: Any) -> TypeIs[FrameT]: + return isinstance(obj, self._dataframe) + @overload def lit( self, node: expr.Literal[NonNestedLiteral], frame: FrameT, name: str @@ -598,6 +647,18 @@ def len(self, node: expr.Len, frame: FrameT, name: str) -> EagerScalarT_co: ) +class LazyNamespace( + Concat[FrameT, FrameT], + CompliantNamespace[FrameT, LazyExprT_co, LazyScalarT_co], + Protocol[FrameT, LazyExprT_co, LazyScalarT_co], +): + @property + def _lazyframe(self) -> type[FrameT]: ... + @property + def _frame(self) -> type[FrameT]: + return self._lazyframe + + class DummyCompliantFrame(StoresVersion, Protocol[SeriesT, NativeFrameT, NativeSeriesT]): _native: NativeFrameT From 76b9b9a69e59bab7324e962377bda654de85ab90 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Jul 2025 18:04:38 +0000 Subject: [PATCH 341/368] refactor: Split out eager, remove `from_series` Will start removing the `Dummy` prefix soon --- narwhals/_plan/arrow/dataframe.py | 20 ++---- narwhals/_plan/arrow/expr.py | 3 +- narwhals/_plan/dummy.py | 107 +++++++++++++++++------------- narwhals/_plan/protocols.py | 80 +++++++++++++--------- tests/plan/compliant_test.py | 9 +-- 5 files changed, 118 insertions(+), 101 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index d73232df98..11b4a8e3e6 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -1,17 +1,14 @@ from __future__ import annotations import typing as t -from itertools import chain import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype -from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR -from narwhals._plan.dummy import DummyFrame -from narwhals._plan.protocols import DummyCompliantFrame +from narwhals._plan.protocols import DummyEagerDataFrame from narwhals._utils import Version if t.TYPE_CHECKING: @@ -22,6 +19,7 @@ from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR + from narwhals._plan.dummy import DummyDataFrame from narwhals._plan.options import SortMultipleOptions from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import Seq @@ -29,7 +27,7 @@ from narwhals.schema import Schema -class ArrowDataFrame(DummyCompliantFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): +class ArrowDataFrame(DummyEagerDataFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace import ArrowNamespace @@ -50,16 +48,10 @@ 
def schema(self) -> dict[str, DType]: def __len__(self) -> int: return self.native.num_rows - def to_narwhals(self) -> DummyFrame[pa.Table, ChunkedArrayAny]: - return DummyFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) + def to_narwhals(self) -> DummyDataFrame[pa.Table, ChunkedArrayAny]: + from narwhals._plan.dummy import DummyDataFrame - @classmethod - def from_series( - cls, series: t.Iterable[ArrowSeries] | ArrowSeries, *more_series: ArrowSeries - ) -> Self: - lhs = (series,) if fn.is_series(series) else series - it = chain(lhs, more_series) if more_series else lhs - return cls.from_dict({s.name: s.native for s in it}) + return DummyDataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) @classmethod def from_dict( diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 6d05d487ca..4174263dc7 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -239,7 +239,8 @@ def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowE self._dispatch_expr(e, frame, f"_{idx}") for idx, e in enumerate(node.by) ) - df = frame.from_series(series, *by) + ns = self.__narwhals_namespace__() + df = ns._concat_horizontal((series, *by)) names = df.columns[1:] indices = pc.sort_indices(df.native, options=node.options.to_arrow(names)) result: ChunkedArrayAny = df.native.column(0).take(indices) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 1ae5fd159d..f1f97f5fbd 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -44,7 +44,11 @@ from narwhals._plan.lists import ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace - from narwhals._plan.protocols import DummyCompliantFrame, DummyCompliantSeries + from narwhals._plan.protocols import ( + DummyCompliantDataFrame, + DummyCompliantFrame, + DummyCompliantSeries, + ) from narwhals._plan.schema import FrozenSchema from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace @@ -64,7 +68,10 @@ ) -CompliantFrame: TypeAlias = "DummyCompliantFrame[t.Any, NativeFrameT, NativeSeriesT]" +CompliantFrame: TypeAlias = "DummyCompliantFrame[t.Any, NativeFrameT]" +CompliantDataFrame: TypeAlias = ( + "DummyCompliantDataFrame[t.Any, NativeFrameT, NativeSeriesT]" +) # NOTE: Trying to keep consistent logic between `DataFrame.sort` and `Expr.sort_by` @@ -831,18 +838,14 @@ def to_narwhals(self) -> DummyExpr: return DummyExprV1._from_ir(self._ir) -class DummyFrame(Generic[NativeFrameT, NativeSeriesT]): - _compliant: CompliantFrame[NativeFrameT, NativeSeriesT] +class DummyFrame(Generic[NativeFrameT]): + _compliant: CompliantFrame[NativeFrameT] _version: t.ClassVar[Version] = Version.MAIN @property def version(self) -> Version: return self._version - @property - def _series(self) -> type[DummySeries[NativeSeriesT]]: - return DummySeries[NativeSeriesT] - @property def schema(self) -> Schema: return Schema(self._compliant.schema.items()) @@ -854,22 +857,12 @@ def columns(self) -> list[str]: def __repr__(self) -> str: # pragma: no cover return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) - # NOTE: Gave up on trying to get typing working for now @classmethod - def from_native( - cls, native: NativeFrame, / - ) -> DummyFrame[pa.Table, pa.ChunkedArray[t.Any]]: - if is_pyarrow_table(native): - from narwhals._plan.arrow.dataframe import ArrowDataFrame - - return ArrowDataFrame.from_native(native, cls._version).to_narwhals() - - raise 
NotImplementedError(type(native)) + def from_native(cls, native: t.Any, /) -> Self: + raise NotImplementedError @classmethod - def _from_compliant( - cls, compliant: CompliantFrame[NativeFrameT, NativeSeriesT], / - ) -> Self: + def _from_compliant(cls, compliant: CompliantFrame[NativeFrameT], /) -> Self: obj = cls.__new__(cls) obj._compliant = compliant return obj @@ -877,32 +870,6 @@ def _from_compliant( def to_native(self) -> NativeFrameT: return self._compliant.native - @t.overload - def to_dict( - self, *, as_series: t.Literal[True] = ... - ) -> dict[str, DummySeries[NativeSeriesT]]: ... - - @t.overload - def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... - - @t.overload - def to_dict( - self, *, as_series: bool - ) -> dict[str, DummySeries[NativeSeriesT]] | dict[str, list[t.Any]]: ... - - def to_dict( - self, *, as_series: bool = True - ) -> dict[str, DummySeries[NativeSeriesT]] | dict[str, list[t.Any]]: - if as_series: - return { - key: self._series._from_compliant(value) - for key, value in self._compliant.to_dict(as_series=as_series).items() - } - return self._compliant.to_dict(as_series=as_series) - - def __len__(self) -> int: - return len(self._compliant) - def _project( self, exprs: tuple[IntoExpr | Iterable[IntoExpr], ...], @@ -950,6 +917,52 @@ def sort( return self._from_compliant(self._compliant.sort(named_irs, opts, schema_frozen)) +class DummyDataFrame(DummyFrame[NativeFrameT], Generic[NativeFrameT, NativeSeriesT]): + _compliant: CompliantDataFrame[NativeFrameT, NativeSeriesT] + + @property + def _series(self) -> type[DummySeries[NativeSeriesT]]: + return DummySeries[NativeSeriesT] + + # NOTE: Gave up on trying to get typing working for now + @classmethod + def from_native( # type: ignore[override] + cls, native: NativeFrame, / + ) -> DummyDataFrame[pa.Table, pa.ChunkedArray[t.Any]]: + if is_pyarrow_table(native): + from narwhals._plan.arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame.from_native(native, cls._version).to_narwhals() + + raise NotImplementedError(type(native)) + + @t.overload + def to_dict( + self, *, as_series: t.Literal[True] = ... + ) -> dict[str, DummySeries[NativeSeriesT]]: ... + + @t.overload + def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... + + @t.overload + def to_dict( + self, *, as_series: bool + ) -> dict[str, DummySeries[NativeSeriesT]] | dict[str, list[t.Any]]: ... 
+ + def to_dict( + self, *, as_series: bool = True + ) -> dict[str, DummySeries[NativeSeriesT]] | dict[str, list[t.Any]]: + if as_series: + return { + key: self._series._from_compliant(value) + for key, value in self._compliant.to_dict(as_series=as_series).items() + } + return self._compliant.to_dict(as_series=as_series) + + def __len__(self) -> int: + return len(self._compliant) + + class DummySeries(Generic[NativeSeriesT]): _compliant: DummyCompliantSeries[NativeSeriesT] _version: t.ClassVar[Version] = Version.MAIN diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 214ca0cbec..2d85d3e854 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan.dummy import DummyFrame, DummySeries + from narwhals._plan.dummy import DummyDataFrame, DummyFrame, DummySeries from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange @@ -43,14 +43,19 @@ ConcatT1 = TypeVar("ConcatT1") ConcatT2 = TypeVar("ConcatT2", default=ConcatT1) +ColumnT = TypeVar("ColumnT") +ColumnT_co = TypeVar("ColumnT_co", covariant=True) + ExprAny: TypeAlias = "CompliantExpr[Any, Any]" ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" SeriesAny: TypeAlias = "DummyCompliantSeries[Any]" -FrameAny: TypeAlias = "DummyCompliantFrame[Any, Any, Any]" +FrameAny: TypeAlias = "DummyCompliantFrame[Any, Any]" +DataFrameAny: TypeAlias = "DummyCompliantDataFrame[Any, Any, Any]" NamespaceAny: TypeAlias = "CompliantNamespace[Any, Any, Any]" EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" EagerScalarAny: TypeAlias = "EagerScalar[Any, Any]" +EagerDataFrameAny: TypeAlias = "DummyEagerDataFrame[Any, Any, Any]" LazyExprAny: TypeAlias = "LazyExpr[Any, Any, Any]" LazyScalarAny: TypeAlias = "LazyScalar[Any, Any, Any]" @@ -66,6 +71,7 @@ EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) EagerScalarT_co = TypeVar("EagerScalarT_co", bound=EagerScalarAny, covariant=True) +EagerDataFrameT = TypeVar("EagerDataFrameT", bound=EagerDataFrameAny) LazyExprT_co = TypeVar("LazyExprT_co", bound=LazyExprAny, covariant=True) LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True) @@ -604,44 +610,44 @@ def int_range( class EagerNamespace( - EagerConcat[FrameT, SeriesT], - CompliantNamespace[FrameT, EagerExprT_co, EagerScalarT_co], - Protocol[FrameT, SeriesT, EagerExprT_co, EagerScalarT_co], + EagerConcat[EagerDataFrameT, SeriesT], + CompliantNamespace[EagerDataFrameT, EagerExprT_co, EagerScalarT_co], + Protocol[EagerDataFrameT, SeriesT, EagerExprT_co, EagerScalarT_co], ): @property def _series(self) -> type[SeriesT]: ... @property - def _dataframe(self) -> type[FrameT]: ... + def _dataframe(self) -> type[EagerDataFrameT]: ... @property - def _frame(self) -> type[FrameT]: + def _frame(self) -> type[EagerDataFrameT]: return self._dataframe def _is_series(self, obj: Any) -> TypeIs[SeriesT]: return isinstance(obj, self._series) - def _is_dataframe(self, obj: Any) -> TypeIs[FrameT]: + def _is_dataframe(self, obj: Any) -> TypeIs[EagerDataFrameT]: return isinstance(obj, self._dataframe) @overload def lit( - self, node: expr.Literal[NonNestedLiteral], frame: FrameT, name: str + self, node: expr.Literal[NonNestedLiteral], frame: EagerDataFrameT, name: str ) -> EagerScalarT_co: ... 
@overload def lit( - self, node: expr.Literal[DummySeries[Any]], frame: FrameT, name: str + self, node: expr.Literal[DummySeries[Any]], frame: EagerDataFrameT, name: str ) -> EagerExprT_co: ... @overload def lit( self, node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[Any]], - frame: FrameT, + frame: EagerDataFrameT, name: str, ) -> EagerExprT_co | EagerScalarT_co: ... def lit( - self, node: expr.Literal[Any], frame: FrameT, name: str + self, node: expr.Literal[Any], frame: EagerDataFrameT, name: str ) -> EagerExprT_co | EagerScalarT_co: ... - def len(self, node: expr.Len, frame: FrameT, name: str) -> EagerScalarT_co: + def len(self, node: expr.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: return self._scalar.from_python( len(frame), name or node.name, dtype=None, version=frame.version ) @@ -659,19 +665,17 @@ def _frame(self) -> type[FrameT]: return self._lazyframe -class DummyCompliantFrame(StoresVersion, Protocol[SeriesT, NativeFrameT, NativeSeriesT]): +class DummyCompliantFrame(StoresVersion, Protocol[ColumnT_co, NativeFrameT]): _native: NativeFrameT def __narwhals_namespace__(self) -> Any: ... - @property def native(self) -> NativeFrameT: return self._native @property def columns(self) -> list[str]: ... - - def to_narwhals(self) -> DummyFrame[NativeFrameT, NativeSeriesT]: ... + def to_narwhals(self) -> DummyFrame[NativeFrameT]: ... @classmethod def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: @@ -683,13 +687,22 @@ def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: def _with_native(self, native: NativeFrameT) -> Self: return self.from_native(native, self.version) - @classmethod - def from_series( - cls, series: Iterable[SeriesT] | SeriesT, *more_series: SeriesT - ) -> Self: - """Return a new DataFrame, horizontally concatenating multiple Series.""" - ... + @property + def schema(self) -> Mapping[str, DType]: ... + def _evaluate_irs( + self, nodes: Iterable[NamedIR[ExprIR]], / + ) -> Iterator[ColumnT_co]: ... + def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: ... + def with_columns(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: ... + def sort( + self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema + ) -> Self: ... + +class DummyCompliantDataFrame( + DummyCompliantFrame[SeriesT, NativeFrameT], + Protocol[SeriesT, NativeFrameT, NativeSeriesT], +): @classmethod def from_dict( cls, @@ -699,6 +712,8 @@ def from_dict( schema: Mapping[str, DType] | Schema | None = None, ) -> Self: ... + def to_narwhals(self) -> DummyDataFrame[NativeFrameT, NativeSeriesT]: ... + @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... @overload @@ -714,20 +729,19 @@ def to_dict( def __len__(self) -> int: ... - @property - def schema(self) -> Mapping[str, DType]: ... - - def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[SeriesT]: ... +class DummyEagerDataFrame( + DummyCompliantDataFrame[SeriesT, NativeFrameT, NativeSeriesT], + Protocol[SeriesT, NativeFrameT, NativeSeriesT], +): + def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... 
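+    # NOTE: Eager `select`/`with_columns` share one path - evaluate each named IR
+    # to a column, then horizontally concatenate the results via the namespace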
def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: - return self.from_series(self._evaluate_irs(irs)) + ns = self.__narwhals_namespace__() + return ns._concat_horizontal(self._evaluate_irs(irs)) def with_columns(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: - return self.from_series(self._evaluate_irs(irs)) - - def sort( - self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema - ) -> Self: ... + ns = self.__narwhals_namespace__() + return ns._concat_horizontal(self._evaluate_irs(irs)) class DummyCompliantSeries(StoresVersion, Protocol[NativeSeriesT]): diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 2ed3c707a7..fa40c5ea91 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -13,6 +13,7 @@ import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.common import is_expr +from narwhals._plan.dummy import DummyDataFrame from narwhals._utils import Version from narwhals.exceptions import ComputeError from tests.utils import assert_equal_data @@ -394,10 +395,8 @@ def test_select( expected: dict[str, Any], data_small: dict[str, Any], ) -> None: - from narwhals._plan.dummy import DummyFrame - frame = pa.table(data_small) - df = DummyFrame.from_native(frame) + df = DummyDataFrame.from_native(frame) result = df.select(expr).to_dict(as_series=False) assert_equal_data(result, expected) @@ -470,10 +469,8 @@ def test_with_columns( expected: dict[str, Any], data_smaller: dict[str, Any], ) -> None: - from narwhals._plan.dummy import DummyFrame - frame = pa.table(data_smaller) - df = DummyFrame.from_native(frame) + df = DummyDataFrame.from_native(frame) result = df.with_columns(expr).to_dict(as_series=False) assert_equal_data(result, expected) From 250922faf1ccd5b9b2d5ccf739e2d0387b1459af Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 23 Jul 2025 17:04:04 +0000 Subject: [PATCH 342/368] revert: Leave out unused `schema_projected` for now Getting too complicated to work on `over` with this extra unused arg Can add back later if/when it would actually be used for query planning --- narwhals/_plan/arrow/dataframe.py | 7 ++----- narwhals/_plan/dummy.py | 8 +++----- narwhals/_plan/protocols.py | 13 +++++-------- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 11b4a8e3e6..26e0f58c81 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -21,7 +21,6 @@ from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dummy import DummyDataFrame from narwhals._plan.options import SortMultipleOptions - from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import Seq from narwhals.dtypes import DType from narwhals.schema import Schema @@ -94,9 +93,7 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSe # NOTE: Not handling actual expressions yet # `DummyFrame` is typed for just `str` names - def sort( - self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema - ) -> Self: - df_by = self.select(by, projected) + def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: + df_by = self.select(by) indices = pc.sort_indices(df_by.native, options=options.to_arrow(df_by.columns)) return self._with_native(self.native.take(indices)) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 
f1f97f5fbd..75b17255f6 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -888,7 +888,7 @@ def select(self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any) -> named_irs, schema_projected = self._project( exprs, named_exprs, ExprContext.SELECT ) - return self._from_compliant(self._compliant.select(named_irs, schema_projected)) + return self._from_compliant(self._compliant.select(named_irs)) def with_columns( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any @@ -896,9 +896,7 @@ def with_columns( named_irs, schema_projected = self._project( exprs, named_exprs, ExprContext.WITH_COLUMNS ) - return self._from_compliant( - self._compliant.with_columns(named_irs, schema_projected) - ) + return self._from_compliant(self._compliant.with_columns(named_irs)) def sort( self, @@ -914,7 +912,7 @@ def sort( sort, self.schema ) named_irs = expr_expansion.into_named_irs(irs, output_names) - return self._from_compliant(self._compliant.sort(named_irs, opts, schema_frozen)) + return self._from_compliant(self._compliant.sort(named_irs, opts)) class DummyDataFrame(DummyFrame[NativeFrameT], Generic[NativeFrameT, NativeSeriesT]): diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 2d85d3e854..1eacc72b98 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -22,7 +22,6 @@ from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange - from narwhals._plan.schema import FrozenSchema from narwhals.dtypes import DType from narwhals.schema import Schema from narwhals.typing import ( @@ -692,11 +691,9 @@ def schema(self) -> Mapping[str, DType]: ... def _evaluate_irs( self, nodes: Iterable[NamedIR[ExprIR]], / ) -> Iterator[ColumnT_co]: ... - def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: ... - def with_columns(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: ... - def sort( - self, by: Seq[NamedIR], options: SortMultipleOptions, projected: FrozenSchema - ) -> Self: ... + def select(self, irs: Seq[NamedIR]) -> Self: ... + def with_columns(self, irs: Seq[NamedIR]) -> Self: ... + def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... class DummyCompliantDataFrame( @@ -735,11 +732,11 @@ class DummyEagerDataFrame( Protocol[SeriesT, NativeFrameT, NativeSeriesT], ): def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... 
- def select(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: + def select(self, irs: Seq[NamedIR]) -> Self: ns = self.__narwhals_namespace__() return ns._concat_horizontal(self._evaluate_irs(irs)) - def with_columns(self, irs: Seq[NamedIR], projected: FrozenSchema) -> Self: + def with_columns(self, irs: Seq[NamedIR]) -> Self: ns = self.__narwhals_namespace__() return ns._concat_horizontal(self._evaluate_irs(irs)) From 8729ebe5ccade31d82c96a53f3da9e43185e941b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 23 Jul 2025 17:06:54 +0000 Subject: [PATCH 343/368] refactor(typing): Relax `_concat_horizontal` The mixed case is already handled in the only impl --- narwhals/_plan/arrow/namespace.py | 2 +- narwhals/_plan/protocols.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 7a00edf51f..19946d8176 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -243,7 +243,7 @@ def _concat_diagonal(self, items: Iterable[ArrowDataFrame]) -> ArrowDataFrame: ) def _concat_horizontal( - self, items: Iterable[ArrowDataFrame] | Iterable[ArrowSeries] + self, items: Iterable[ArrowDataFrame | ArrowSeries] ) -> ArrowDataFrame: def gen( objs: Iterable[ArrowDataFrame | ArrowSeries], diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 1eacc72b98..ec193003c1 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -548,9 +548,7 @@ class EagerConcat(Concat[ConcatT1, ConcatT2], Protocol[ConcatT1, ConcatT2]): # def _concat_diagonal(self, items: Iterable[ConcatT1], /) -> ConcatT1: ... # Series can be used here to go from [Series, Series] -> DataFrame # but that is only available privately - def _concat_horizontal( - self, items: Iterable[ConcatT1] | Iterable[ConcatT2], / - ) -> ConcatT1: ... + def _concat_horizontal(self, items: Iterable[ConcatT1 | ConcatT2], /) -> ConcatT1: ... def _concat_vertical( self, items: Iterable[ConcatT1] | Iterable[ConcatT2], / ) -> ConcatT1 | ConcatT2: ... 
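
A minimal usage sketch of the `concat` support added in the last few patches. This is not part of the PR; the tables, variable names, and call sites below are illustrative only, and assume `ArrowNamespace.concat` / `ArrowDataFrame.from_native` as they stand at this point in the series:

    import pyarrow as pa

    from narwhals._plan.arrow.dataframe import ArrowDataFrame
    from narwhals._utils import Version

    left = ArrowDataFrame.from_native(pa.table({"a": [1, 2], "b": [10, 20]}), Version.MAIN)
    right = ArrowDataFrame.from_native(pa.table({"a": [3], "b": [30]}), Version.MAIN)
    other = ArrowDataFrame.from_native(pa.table({"c": ["x", "y"]}), Version.MAIN)

    ns = left.__narwhals_namespace__()

    # Row-wise: column names must match, otherwise `_concat_vertical` raises
    tall = ns.concat([left, right], how="vertical")    # 3 rows, columns "a", "b"

    # Column-wise: every column of every input ends up in one new table
    wide = ns.concat([left, other], how="horizontal")  # 2 rows, columns "a", "b", "c"

Series only support `how="vertical"` publicly, mirroring the overloads on the `Concat` protocol above; `how="diagonal"` is routed through `_concat_diagonal` for frames only.
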
From 6ef9375c1a273498decbec963ad5d21f42074cce Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 23 Jul 2025 20:00:11 +0000 Subject: [PATCH 344/368] feat(pyarrow): Impl `over_ordered` New test is one that currently fails in #2528 https://github.com/narwhals-dev/narwhals/pull/2528#discussion_r2222598488 --- narwhals/_plan/arrow/dataframe.py | 14 +++++++++- narwhals/_plan/arrow/expr.py | 31 +++++++++++++++++++--- narwhals/_plan/common.py | 9 +++++++ narwhals/_plan/options.py | 16 ++++++++--- narwhals/_plan/protocols.py | 1 + tests/plan/compliant_test.py | 44 +++++++++++++++++++++++++++++++ 6 files changed, 106 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 26e0f58c81..297ca6d281 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -6,13 +6,14 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype +from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR from narwhals._plan.protocols import DummyEagerDataFrame from narwhals._utils import Version if t.TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Iterable, Iterator, Sequence from typing_extensions import Self @@ -97,3 +98,14 @@ def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: df_by = self.select(by) indices = pc.sort_indices(df_by.native, options=options.to_arrow(df_by.columns)) return self._with_native(self.native.take(indices)) + + def with_row_index(self, name: str) -> Self: + return self._with_native(self.native.add_column(0, name, fn.int_range(len(self)))) + + def get_column(self, name: str) -> ArrowSeries: + chunked = self.native.column(name) + return ArrowSeries.from_native(chunked, name, version=self.version) + + def drop(self, columns: Sequence[str]) -> Self: + to_drop = list(columns) + return self._with_native(self.native.drop(to_drop)) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 4174263dc7..9f78b4f9e4 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -10,9 +10,15 @@ from narwhals._plan.arrow.functions import lit from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.arrow.typing import NativeScalar, StoresNativeT_co -from narwhals._plan.common import ExprIR, into_dtype +from narwhals._plan.common import ExprIR, NamedIR, into_dtype from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch -from narwhals._utils import Implementation, Version, _StoresNative, not_implemented +from narwhals._utils import ( + Implementation, + Version, + _StoresNative, + generate_temporary_column_name, + not_implemented, +) from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: @@ -326,7 +332,7 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: # TODO @dangotbanned: top-level, complex-ish nodes # - [ ] `over`/`_ordered` (with partitions) requires `group_by`, `join` - # - [ ] `over_ordered` alone should be possible w/ the current API + # - [x] `over_ordered` alone should be possible w/ the current API # - [x] `map_batches` is defined in `EagerExpr`, might be simpler here than on main # - [ ] `rolling_expr` has 4 variants @@ -336,7 +342,24 @@ def over(self, node: WindowExpr, frame: ArrowDataFrame, name: str) -> Self: def over_ordered( 
self, node: OrderedWindowExpr, frame: ArrowDataFrame, name: str ) -> Self: - raise NotImplementedError + if node.partition_by: + msg = f"Need to implement `group_by`, `join` for:\n{node!r}" + raise NotImplementedError(msg) + + # NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort` + sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by) + options = node.sort_options.to_multiple(len(node.order_by)) + + idx_name = generate_temporary_column_name(8, frame.columns) + sorted_context = frame.with_row_index(idx_name).sort(sort_by, options) + height = len(sorted_context) + + evaluated = self._dispatch(node.expr, sorted_context.drop([idx_name]), name) + + # NOTE: Might be able to skip this if ^^^ returned a scalar, since len == 1 is already sorted + indices = pc.sort_indices(sorted_context.get_column(idx_name).native) + result = evaluated.broadcast(height).native.take(indices) + return self._with_native(result, name) # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` def map_batches(self, node: AnonymousExpr, frame: ArrowDataFrame, name: str) -> Self: diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index e3b0e3a780..92369cf590 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -320,6 +320,15 @@ def from_name(name: str, /) -> NamedIR[Column]: return NamedIR(expr=col(name), name=name) + @staticmethod + def from_ir(expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: + """Construct from an already expanded `ExprIR`. + + Should be cheap to get the output name from cache, but will raise if used + without care. + """ + return NamedIR(expr=expr, name=expr.meta.output_name(raise_if_undetermined=True)) + def map_ir(self, function: MapIR, /) -> NamedIR[ExprIR]: """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" return self.with_expr(function(self.expr.map_ir(function))) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 63dbd59250..29b7fa4637 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -158,6 +158,15 @@ def to_arrow(self) -> pc.ArraySortOptions: null_placement=("at_end" if self.nulls_last else "at_start"), ) + def to_multiple(self, n_repeat: int = 1, /) -> SortMultipleOptions: + if n_repeat == 1: + desc: Seq[bool] = (self.descending,) + nulls: Seq[bool] = (self.nulls_last,) + else: + desc = tuple(repeat(self.descending, n_repeat)) + nulls = tuple(repeat(self.nulls_last)) + return SortMultipleOptions(descending=desc, nulls_last=nulls) + class SortMultipleOptions(Immutable): __slots__ = ("descending", "nulls_last") @@ -181,12 +190,10 @@ def parse( def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: import pyarrow.compute as pc - if len(self.nulls_last) != 1: + first = self.nulls_last[0] + if len(self.nulls_last) != 1 and any(x != first for x in self.nulls_last[1:]): msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" raise NotImplementedError(msg) - placement: Literal["at_start", "at_end"] = ( - "at_end" if self.nulls_last[0] else "at_start" - ) if len(self.descending) == 1: descending: Iterable[bool] = repeat(self.descending[0], len(by)) else: @@ -195,6 +202,7 @@ def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: (key, "descending" if desc else "ascending") for key, desc in zip(by, descending) ] + placement: Literal["at_start", "at_end"] = "at_end" if first else "at_start" return pc.SortOptions(sort_keys=sorting, null_placement=placement) diff --git a/narwhals/_plan/protocols.py 
b/narwhals/_plan/protocols.py index ec193003c1..bd0ee7fb66 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -723,6 +723,7 @@ def to_dict( ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... def __len__(self) -> int: ... + def with_row_index(self, name: str) -> Self: ... class DummyEagerDataFrame( diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index fa40c5ea91..f702bcd460 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -22,6 +22,7 @@ from collections.abc import Sequence from narwhals._plan.dummy import DummyExpr + from narwhals.typing import PythonLiteral @pytest.fixture @@ -52,6 +53,18 @@ def data_smaller(data_small: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in data_small.items() if k in keep} +@pytest.fixture +def data_indexed() -> dict[str, Any]: + """Used in https://github.com/narwhals-dev/narwhals/pull/2528.""" + return { + "a": [8, 2, 1, None], + "b": [58, 5, 6, 12], + "c": [2.5, 1.0, 3.0, 0.9], + "d": [2, 1, 4, 3], + "idx": [0, 1, 2, 3], + } + + def _ids_ir(expr: DummyExpr | Any) -> str: if is_expr(expr): return repr(expr._ir) @@ -475,6 +488,37 @@ def test_with_columns( assert_equal_data(result, expected) +def first(*names: str) -> DummyExpr: + return nwd.col(*names).first() + + +def last(*names: str) -> DummyExpr: + return nwd.col(*names).last() + + +@pytest.mark.parametrize( + ("agg", "expected"), + [ + (first("a"), 8), + (first("b"), 58), + (first("c"), 2.5), + (last("a"), None), + (last("b"), 12), + (last("c"), 0.9), + ], +) +def test_first_last_expr_with_columns( + data_indexed: dict[str, Any], agg: DummyExpr, expected: PythonLiteral +) -> None: + """Related https://github.com/narwhals-dev/narwhals/pull/2528#discussion_r2225930065.""" + height = len(next(iter(data_indexed.values()))) + expected_broadcast = height * [expected] + frame = DummyDataFrame.from_native(pa.table(data_indexed)) + expr = agg.over(order_by="idx").alias("result") + result = frame.with_columns(expr).select("result").to_dict(as_series=False) + assert_equal_data(result, {"result": expected_broadcast}) + + if TYPE_CHECKING: def test_protocol_expr() -> None: From d3bed76ebf854c1e91f64ec7396e98fd0ea88f17 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 23 Jul 2025 20:02:46 +0000 Subject: [PATCH 345/368] oop #2845 --- utils/import_check.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/import_check.py b/utils/import_check.py index cd753215ac..17f6616fa6 100644 --- a/utils/import_check.py +++ b/utils/import_check.py @@ -42,14 +42,14 @@ def __init__(self, file_name: str, lines: list[str]) -> None: else: self.allowed_imports = set() - def visit_If(self, node: ast.If) -> None: # noqa: N802 + def visit_If(self, node: ast.If) -> None: # Check if the condition is `if TYPE_CHECKING` if isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING": # Skip the body of this if statement return self.generic_visit(node) - def visit_Import(self, node: ast.Import) -> None: # noqa: N802 + def visit_Import(self, node: ast.Import) -> None: for alias in node.names: if ( alias.name in BANNED_IMPORTS @@ -63,7 +63,7 @@ def visit_Import(self, node: ast.Import) -> None: # noqa: N802 self.generic_visit(node) - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802 + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if ( node.module in BANNED_IMPORTS and "# ignore-banned-import" not in self.lines[node.lineno - 1] From 
2c8570198582c39767b7b30ff7aa42b69ca7ea0a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 30 Jul 2025 14:53:08 +0000 Subject: [PATCH 346/368] add `_with_columns` --- narwhals/_plan/arrow/dataframe.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 297ca6d281..51445d5024 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -18,6 +18,7 @@ from typing_extensions import Self from narwhals._arrow.typing import ChunkedArrayAny + from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dummy import DummyDataFrame @@ -109,3 +110,18 @@ def get_column(self, name: str) -> ArrowSeries: def drop(self, columns: Sequence[str]) -> Self: to_drop = list(columns) return self._with_native(self.native.drop(to_drop)) + + # NOTE: Use instead of `with_columns` for trivial cases + def _with_columns(self, exprs: Iterable[ArrowExpr | ArrowScalar], /) -> Self: + native = self.native + columns = self.columns + height = len(self) + for into_series in exprs: + name = into_series.name + chunked = into_series.broadcast(height).native + if name in columns: + i = columns.index(name) + native = native.set_column(i, name, chunked) + else: + native = native.append_column(name, chunked) + return self._with_native(native) From 94d533d7d3f3e0fbe41ff1bad609d9693a893054 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 30 Jul 2025 15:41:49 +0000 Subject: [PATCH 347/368] perf: Add early return path for `over_ordered` Related #2528 --- narwhals/_plan/arrow/expr.py | 14 ++++++++------ narwhals/_plan/protocols.py | 4 +++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 9f78b4f9e4..295b91431e 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -341,7 +341,7 @@ def over(self, node: WindowExpr, frame: ArrowDataFrame, name: str) -> Self: def over_ordered( self, node: OrderedWindowExpr, frame: ArrowDataFrame, name: str - ) -> Self: + ) -> Self | ArrowScalar: if node.partition_by: msg = f"Need to implement `group_by`, `join` for:\n{node!r}" raise NotImplementedError(msg) @@ -349,15 +349,17 @@ def over_ordered( # NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort` sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by) options = node.sort_options.to_multiple(len(node.order_by)) - idx_name = generate_temporary_column_name(8, frame.columns) sorted_context = frame.with_row_index(idx_name).sort(sort_by, options) - height = len(sorted_context) - evaluated = self._dispatch(node.expr, sorted_context.drop([idx_name]), name) - - # NOTE: Might be able to skip this if ^^^ returned a scalar, since len == 1 is already sorted + if isinstance(evaluated, ArrowScalar): + # NOTE: We're already sorted, defer broadcasting to the outer context + # Wouldn't be suitable for partitions, but will be fine here + # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/2ae42458cae91f4473e01270919815fcd7cb9667 + # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/b8066c4c57d4b0b6c38d58a0f5de05eefc2cae70 + return self._with_native(evaluated.native, name) indices = pc.sort_indices(sorted_context.get_column(idx_name).native) + height = 
len(sorted_context) result = evaluated.broadcast(height).native.take(indices) return self._with_native(result, name) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index bd0ee7fb66..c66dca1d3d 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -337,9 +337,11 @@ def ternary_expr( self, node: expr.Ternary, frame: FrameT_contra, name: str ) -> Self: ... def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... + # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` + # e.g. `nw.col("a").first().over(order_by="b")` def over_ordered( self, node: expr.OrderedWindowExpr, frame: FrameT_contra, name: str - ) -> Self: ... + ) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ... def map_batches( self, node: expr.AnonymousExpr, frame: FrameT_contra, name: str ) -> Self: ... From 25549de09a392394488bb253ba8b001b6fc91952 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 15 Aug 2025 17:48:00 +0000 Subject: [PATCH 348/368] style(ruff): re-run updated config --- narwhals/_plan/arrow/expr.py | 7 +++---- narwhals/_plan/arrow/namespace.py | 13 +++++-------- narwhals/_plan/common.py | 4 ++-- narwhals/_plan/expr.py | 3 +-- narwhals/_plan/expr_expansion.py | 7 +++---- narwhals/_plan/temporal.py | 2 +- tests/plan/utils.py | 3 +-- 7 files changed, 16 insertions(+), 23 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 295b91431e..ef6e0164ed 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -426,13 +426,12 @@ def from_python( def from_series(cls, series: ArrowSeries) -> Self: if len(series) == 1: return cls.from_native(series.native[0], series.name, series.version) - elif len(series) == 0: + if len(series) == 0: return cls.from_python( None, series.name, dtype=series.dtype, version=series.version ) - else: - msg = f"Too long {len(series)!r}" - raise InvalidOperationError(msg) + msg = f"Too long {len(series)!r}" + raise InvalidOperationError(msg) def _dispatch_expr( self, node: ExprIR, frame: ArrowDataFrame, name: str diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 19946d8176..68a1339300 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -203,12 +203,10 @@ def int_range( native = fn.int_range(start_, end_, step, dtype=pa_dtype) return self._expr.from_native(native, name, self.version) - else: - msg = ( - f"All inputs for `int_range()` resolve to int, but got \n" - f"{start_!r}\n{end_!r}" - ) - raise InvalidOperationError(msg) + msg = ( + f"All inputs for `int_range()` resolve to int, but got \n{start_!r}\n{end_!r}" + ) + raise InvalidOperationError(msg) @overload def concat( @@ -279,5 +277,4 @@ def _concat_vertical( ) raise TypeError(msg) return df._with_native(fn.concat_vertical_table(df.native for df in dfs)) - else: - raise TypeError(items) + raise TypeError(items) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 92369cf590..fb2b639851 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -121,7 +121,7 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: if self is other: return True - elif type(self) is not type(other): + if type(self) is not type(other): return False return all( getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ @@ -162,7 +162,7 @@ def _field_str(name: str, value: Any) -> str: if isinstance(value, tuple): 
inner = ", ".join(f"{v}" for v in value) return f"{name}=[{inner}]" - elif isinstance(value, str): + if isinstance(value, str): return f"{name}={value!r}" return f"{name}={value}" diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 561aa3d96e..786af72b20 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -432,8 +432,7 @@ def __repr__(self) -> str: if len(self.input) >= 2: return f"{first!r}.{self.function!r}({list(self.input[1:])!r})" return f"{first!r}.{self.function!r}()" - else: - return f"{self.function!r}()" + return f"{self.function!r}()" def iter_left(self) -> t.Iterator[ExprIR]: for e in self.input: diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 1075dc885e..ada5bcb755 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -470,14 +470,13 @@ def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: roots = parent.meta.root_names() alias = next(iter(roots)) return Alias(expr=parent, name=alias) - elif isinstance(origin, RenameAlias): + if isinstance(origin, RenameAlias): parent = origin.expr leaf_name_or_err = meta.get_single_leaf_name(parent) if not isinstance(leaf_name_or_err, str): raise leaf_name_or_err alias = origin.function(leaf_name_or_err) return Alias(expr=parent, name=alias) - else: - msg = "`keep`, `suffix`, `prefix` should be last expression" - raise InvalidOperationError(msg) + msg = "`keep`, `suffix`, `prefix` should be last expression" + raise InvalidOperationError(msg) return origin diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 301bfc95a4..11956622b7 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -28,7 +28,7 @@ def __repr__(self) -> str: tp = type(self) if tp is TemporalFunction: return tp.__name__ - elif tp is Timestamp: + if tp is Timestamp: tu = cast("Timestamp", self).time_unit return f"dt.timestamp[{tu!r}]" m: dict[type[TemporalFunction], str] = { diff --git a/tests/plan/utils.py b/tests/plan/utils.py index d284fb4821..e02604871a 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -17,8 +17,7 @@ def _unwrap_ir(obj: DummyExpr | ExprIR | NamedIR) -> ExprIR: return obj if isinstance(obj, NamedIR): return obj.expr - else: - raise NotImplementedError(type(obj)) + raise NotImplementedError(type(obj)) def assert_expr_ir_equal( From a12a98560fde4ca4641f47d9249bdaf29df85132 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 19 Aug 2025 14:07:48 +0000 Subject: [PATCH 349/368] chore(ruff): re-run updated config x2 --- narwhals/_plan/arrow/namespace.py | 2 +- narwhals/_plan/demo.py | 2 +- narwhals/_plan/dummy.py | 2 +- narwhals/_plan/protocols.py | 8 +------- tests/plan/compliant_test.py | 1 - tests/plan/expr_parsing_test.py | 7 +------ 6 files changed, 5 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 68a1339300..62dfbeac2f 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -19,7 +19,7 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from narwhals._arrow.typing import ChunkedArrayAny - from narwhals._plan import expr, functions as F # noqa: N812 + from narwhals._plan import expr, functions as F from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.series import ArrowSeries diff --git a/narwhals/_plan/demo.py 
b/narwhals/_plan/demo.py index e0f3f106da..9913b084d9 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -7,7 +7,7 @@ aggregation as agg, boolean, expr_parsing as parse, - functions as F, # noqa: N812 + functions as F, ) from narwhals._plan.common import ( into_dtype, diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 75b17255f6..a67e268751 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -12,7 +12,7 @@ expr, expr_expansion, expr_parsing as parse, - functions as F, # noqa: N812 + functions as F, operators as ops, ) from narwhals._plan.common import NamedIR, into_dtype, is_column, is_expr, is_series diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index c66dca1d3d..233fe270ec 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -3,13 +3,7 @@ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload -from narwhals._plan import ( # noqa: N812 - aggregation as agg, - boolean, - expr, - functions as F, - strings, -) +from narwhals._plan import aggregation as agg, boolean, expr, functions as F, strings from narwhals._plan.common import ExprIR, Function, NamedIR, flatten_hash_safe from narwhals._plan.typing import NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index f702bcd460..b6a868ccf4 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -1,6 +1,5 @@ from __future__ import annotations -# ruff: noqa: FBT003 from typing import TYPE_CHECKING, Any import pytest diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 79b1ff0be2..7b65848597 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -11,12 +11,7 @@ import narwhals as nw import narwhals._plan.demo as nwd -from narwhals._plan import ( - boolean, - expr, - functions as F, # noqa: N812 - operators as ops, -) +from narwhals._plan import boolean, expr, functions as F, operators as ops from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import DummyExpr, DummySeries from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr From 62030ae241da66307ad12f459494f35ec3a512ad Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:47:52 +0000 Subject: [PATCH 350/368] refactor(expr-ir): Rename `Dummy*` everything (#3014) --- narwhals/_plan/arrow/dataframe.py | 14 ++-- narwhals/_plan/arrow/namespace.py | 8 +- narwhals/_plan/arrow/series.py | 4 +- narwhals/_plan/boolean.py | 8 +- narwhals/_plan/categorical.py | 4 +- narwhals/_plan/common.py | 50 ++++++------- narwhals/_plan/demo.py | 48 ++++++------ narwhals/_plan/dummy.py | 119 ++++++++++++------------------ narwhals/_plan/expr_expansion.py | 4 +- narwhals/_plan/lists.py | 4 +- narwhals/_plan/literal.py | 12 +-- narwhals/_plan/name.py | 14 ++-- narwhals/_plan/protocols.py | 36 ++++----- narwhals/_plan/selectors.py | 20 ++--- narwhals/_plan/strings.py | 32 ++++---- narwhals/_plan/struct.py | 4 +- narwhals/_plan/temporal.py | 46 ++++++------ narwhals/_plan/typing.py | 8 +- narwhals/_plan/when_then.py | 28 +++---- tests/plan/compliant_test.py | 26 +++---- tests/plan/expr_expansion_test.py | 18 ++--- tests/plan/expr_parsing_test.py | 20 ++--- tests/plan/expr_rewrites_test.py | 4 +- tests/plan/meta_test.py | 8 +- 
tests/plan/utils.py | 8 +- 25 files changed, 250 insertions(+), 297 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 51445d5024..214fe908ff 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -9,7 +9,7 @@ from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR -from narwhals._plan.protocols import DummyEagerDataFrame +from narwhals._plan.protocols import EagerDataFrame from narwhals._utils import Version if t.TYPE_CHECKING: @@ -21,14 +21,14 @@ from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.dummy import DummyDataFrame + from narwhals._plan.dummy import DataFrame from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import Seq from narwhals.dtypes import DType from narwhals.schema import Schema -class ArrowDataFrame(DummyEagerDataFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): +class ArrowDataFrame(EagerDataFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace import ArrowNamespace @@ -49,10 +49,10 @@ def schema(self) -> dict[str, DType]: def __len__(self) -> int: return self.native.num_rows - def to_narwhals(self) -> DummyDataFrame[pa.Table, ChunkedArrayAny]: - from narwhals._plan.dummy import DummyDataFrame + def to_narwhals(self) -> DataFrame[pa.Table, ChunkedArrayAny]: + from narwhals._plan.dummy import DataFrame - return DummyDataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) + return DataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) @classmethod def from_dict( @@ -94,7 +94,7 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSe yield from ns._expr.align(from_named_ir(e, self) for e in nodes) # NOTE: Not handling actual expressions yet - # `DummyFrame` is typed for just `str` names + # `BaseFrame` is typed for just `str` names def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: df_by = self.select(by) indices = pc.sort_indices(df_by.native, options=options.to_arrow(df_by.columns)) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 62dfbeac2f..e941727d6b 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -24,7 +24,7 @@ from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.boolean import AllHorizontal, AnyHorizontal - from narwhals._plan.dummy import DummySeries + from narwhals._plan.dummy import Series from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatHorizontal @@ -74,7 +74,7 @@ def lit( @overload def lit( self, - node: expr.Literal[DummySeries[ChunkedArrayAny]], + node: expr.Literal[Series[ChunkedArrayAny]], frame: ArrowDataFrame, name: str, ) -> ArrowExpr: ... @@ -82,14 +82,14 @@ def lit( @overload def lit( self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], + node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[ChunkedArrayAny]], frame: ArrowDataFrame, name: str, ) -> ArrowExpr | ArrowScalar: ... 
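    # NOTE: Per the overloads above, a plain (non-nested) literal resolves to an
    # `ArrowScalar` (a broadcastable value), while a `Series` literal resolves to
    # an `ArrowExpr` (a full column); the combined signature below accepts either
    # kind of `Literal` node and may return either type.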
def lit( self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[ChunkedArrayAny]], + node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[ChunkedArrayAny]], frame: ArrowDataFrame, name: str, ) -> ArrowExpr | ArrowScalar: diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index f622e3ed84..d7941fa681 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -4,7 +4,7 @@ from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn -from narwhals._plan.protocols import DummyCompliantSeries +from narwhals._plan.protocols import CompliantSeries from narwhals._utils import Version from narwhals.dependencies import is_numpy_array_1d @@ -18,7 +18,7 @@ from narwhals.typing import Into1DArray, IntoDType, _1DArray -class ArrowSeries(DummyCompliantSeries["ChunkedArrayAny"]): +class ArrowSeries(CompliantSeries["ChunkedArrayAny"]): def to_list(self) -> list[Any]: return self.native.to_pylist() diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 5e24c7c329..a9ec2aeda7 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -12,7 +12,7 @@ from typing_extensions import Self from narwhals._plan.common import ExprIR - from narwhals._plan.dummy import DummySeries + from narwhals._plan.dummy import Series from narwhals._plan.expr import FunctionExpr, Literal # noqa: F401 from narwhals._plan.typing import NativeSeriesT, Seq # noqa: F401 from narwhals.typing import ClosedInterval @@ -126,11 +126,9 @@ def from_iterable(cls, other: t.Iterable[t.Any], /) -> IsInSeq: # NOTE: Shouldn't be allowed for lazy backends (maybe besides `polars`) -class IsInSeries(IsIn["Literal[DummySeries[NativeSeriesT]]"]): +class IsInSeries(IsIn["Literal[Series[NativeSeriesT]]"]): @classmethod - def from_series( - cls, other: DummySeries[NativeSeriesT], / - ) -> IsInSeries[NativeSeriesT]: + def from_series(cls, other: Series[NativeSeriesT], /) -> IsInSeries[NativeSeriesT]: from narwhals._plan.literal import SeriesLiteral return IsInSeries(other=SeriesLiteral(value=other).to_literal()) diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 0f8d490fba..cea0274c6e 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -6,7 +6,7 @@ from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr class CategoricalFunction(Function): ... 
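# --- illustrative sketch ---------------------------------------------------
# A minimal, runnable sketch of the delegation pattern used by
# `ExprCatNamespace.get_categories` in the next hunk (and by the `str`, `dt`
# and `struct` namespaces later in this patch): the namespace method builds a
# Function node, attaches the parent expression's IR as input via
# `to_function_expr`, and wraps the result back into an `Expr`. The class
# names below are simplified stand-ins, not the actual narwhals classes.
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ExprIR:
    """Stand-in for any expression node."""


@dataclass(frozen=True)
class Column(ExprIR):
    name: str


@dataclass(frozen=True)
class FunctionExpr(ExprIR):
    function: str
    input: tuple[ExprIR, ...]


@dataclass(frozen=True)
class GetCategories:
    def to_function_expr(self, *input: ExprIR) -> FunctionExpr:
        return FunctionExpr("cat.get_categories", input)


@dataclass(frozen=True)
class Expr:
    _ir: ExprIR

    @property
    def cat(self) -> ExprCatNamespace:
        return ExprCatNamespace(self)


@dataclass(frozen=True)
class ExprCatNamespace:
    _expr: Expr

    def get_categories(self) -> Expr:
        # Build the Function node, feed it the parent IR, wrap it again.
        return Expr(GetCategories().to_function_expr(self._expr._ir))


assert Expr(Column("a")).cat.get_categories() == Expr(
    FunctionExpr("cat.get_categories", (Column("a"),))
)
# ---------------------------------------------------------------------------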
@@ -33,7 +33,7 @@ class ExprCatNamespace(ExprNamespace[IRCatNamespace]): def _ir_namespace(self) -> type[IRCatNamespace]: return IRCatNamespace - def get_categories(self) -> DummyExpr: + def get_categories(self) -> Expr: return self._to_narwhals( self._ir.get_categories().to_function_expr(self._expr._ir) ) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index fb2b639851..69ce6aed6e 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -27,7 +27,7 @@ from typing_extensions import Never, Self, TypeIs, dataclass_transform from narwhals._plan import expr - from narwhals._plan.dummy import DummyExpr, DummySelector, DummySeries + from narwhals._plan.dummy import Expr, Selector, Series from narwhals._plan.expr import ( AggExpr, BinaryExpr, @@ -38,7 +38,7 @@ ) from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.options import FunctionOptions - from narwhals._plan.protocols import DummyCompliantSeries + from narwhals._plan.protocols import CompliantSeries from narwhals.typing import NonNestedDType, NonNestedLiteral else: @@ -170,12 +170,12 @@ def _field_str(name: str, value: Any) -> str: class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" - def to_narwhals(self, version: Version = Version.MAIN) -> DummyExpr: + def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import dummy if version is Version.MAIN: - return dummy.DummyExpr._from_ir(self) - return dummy.DummyExprV1._from_ir(self) + return dummy.Expr._from_ir(self) + return dummy.ExprV1._from_ir(self) @property def is_scalar(self) -> bool: @@ -277,12 +277,12 @@ def _repr_html_(self) -> str: class SelectorIR(ExprIR): - def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector: + def to_narwhals(self, version: Version = Version.MAIN) -> Selector: from narwhals._plan import dummy if version is Version.MAIN: - return dummy.DummySelector._from_ir(self) - return dummy.DummySelectorV1._from_ir(self) + return dummy.Selector._from_ir(self) + return dummy.SelectorV1._from_ir(self) def matches_column(self, name: str, dtype: DType) -> bool: """Return True if we can select this column. @@ -370,13 +370,13 @@ class IRNamespace(Immutable): _ir: ExprIR @classmethod - def from_expr(cls, expr: DummyExpr, /) -> Self: + def from_expr(cls, expr: Expr, /) -> Self: return cls(_ir=expr._ir) class ExprNamespace(Immutable, Generic[IRNamespaceT]): __slots__ = ("_expr",) - _expr: DummyExpr + _expr: Expr @property def _ir_namespace(self) -> type[IRNamespaceT]: @@ -386,7 +386,7 @@ def _ir_namespace(self) -> type[IRNamespaceT]: def _ir(self) -> IRNamespaceT: return self._ir_namespace.from_expr(self._expr) - def _to_narwhals(self, ir: ExprIR, /) -> DummyExpr: + def _to_narwhals(self, ir: ExprIR, /) -> Expr: return self._expr._from_ir(ir) @@ -433,13 +433,13 @@ def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) -def is_expr(obj: Any) -> TypeIs[DummyExpr]: - from narwhals._plan.dummy import DummyExpr +def is_expr(obj: Any) -> TypeIs[Expr]: + from narwhals._plan.dummy import Expr - return isinstance(obj, DummyExpr) + return isinstance(obj, Expr) -def is_column(obj: Any) -> TypeIs[DummyExpr]: +def is_column(obj: Any) -> TypeIs[Expr]: """Indicate if the given object is a basic/unaliased column. https://github.com/pola-rs/polars/blob/a3d6a3a7863b4d42e720a05df69ff6b6f5fc551f/py-polars/polars/_utils/various.py#L164-L168. 
@@ -447,26 +447,22 @@ def is_column(obj: Any) -> TypeIs[DummyExpr]: return is_expr(obj) and obj.meta.is_column() -def is_series( - obj: DummySeries[NativeSeriesT] | Any, -) -> TypeIs[DummySeries[NativeSeriesT]]: - from narwhals._plan.dummy import DummySeries +def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: + from narwhals._plan.dummy import Series - return isinstance(obj, DummySeries) + return isinstance(obj, Series) def is_compliant_series( - obj: DummyCompliantSeries[NativeSeriesT] | Any, -) -> TypeIs[DummyCompliantSeries[NativeSeriesT]]: + obj: CompliantSeries[NativeSeriesT] | Any, +) -> TypeIs[CompliantSeries[NativeSeriesT]]: return _hasattr_static(obj, "__narwhals_series__") -def is_iterable_reject( - obj: Any, -) -> TypeIs[str | bytes | DummySeries | DummyCompliantSeries]: - from narwhals._plan.dummy import DummySeries +def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]: + from narwhals._plan.dummy import Series - return isinstance(obj, (str, bytes, DummySeries)) or is_compliant_series(obj) + return isinstance(obj, (str, bytes, Series)) or is_compliant_series(obj) def is_regex_projection(name: str) -> bool: diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 9913b084d9..1dd28617f6 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -26,14 +26,14 @@ if t.TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan.dummy import DummyExpr, DummySeries + from narwhals._plan.dummy import Expr, Series from narwhals._plan.expr import SortBy from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT from narwhals.dtypes import IntegerType from narwhals.typing import IntoDType, NonNestedLiteral -def col(*names: str | t.Iterable[str]) -> DummyExpr: +def col(*names: str | t.Iterable[str]) -> Expr: flat_names = tuple(flatten(names)) node = ( Column(name=flat_names[0]) @@ -43,7 +43,7 @@ def col(*names: str | t.Iterable[str]) -> DummyExpr: return node.to_narwhals() -def nth(*indices: int | t.Sequence[int]) -> DummyExpr: +def nth(*indices: int | t.Sequence[int]) -> Expr: flat_indices = tuple(flatten(indices)) node = ( Nth(index=flat_indices[0]) @@ -54,8 +54,8 @@ def nth(*indices: int | t.Sequence[int]) -> DummyExpr: def lit( - value: NonNestedLiteral | DummySeries[NativeSeriesT], dtype: IntoDType | None = None -) -> DummyExpr: + value: NonNestedLiteral | Series[NativeSeriesT], dtype: IntoDType | None = None +) -> Expr: if is_series(value): return SeriesLiteral(value=value).to_literal().to_narwhals() if not is_non_nested_literal(value): @@ -68,64 +68,64 @@ def lit( return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() -def len() -> DummyExpr: +def len() -> Expr: return Len().to_narwhals() -def all() -> DummyExpr: +def all() -> Expr: return All().to_narwhals() -def exclude(*names: str | t.Iterable[str]) -> DummyExpr: +def exclude(*names: str | t.Iterable[str]) -> Expr: return all().exclude(*names) -def max(*columns: str) -> DummyExpr: +def max(*columns: str) -> Expr: return col(columns).max() -def mean(*columns: str) -> DummyExpr: +def mean(*columns: str) -> Expr: return col(columns).mean() -def min(*columns: str) -> DummyExpr: +def min(*columns: str) -> Expr: return col(columns).min() -def median(*columns: str) -> DummyExpr: +def median(*columns: str) -> Expr: return col(columns).median() -def sum(*columns: str) -> DummyExpr: +def sum(*columns: str) -> Expr: return col(columns).sum() -def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: 
+def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = parse.parse_into_seq_of_expr_ir(*exprs) return boolean.AllHorizontal().to_function_expr(*it).to_narwhals() -def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: +def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = parse.parse_into_seq_of_expr_ir(*exprs) return boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() -def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: +def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = parse.parse_into_seq_of_expr_ir(*exprs) return F.SumHorizontal().to_function_expr(*it).to_narwhals() -def min_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: +def min_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = parse.parse_into_seq_of_expr_ir(*exprs) return F.MinHorizontal().to_function_expr(*it).to_narwhals() -def max_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: +def max_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = parse.parse_into_seq_of_expr_ir(*exprs) return F.MaxHorizontal().to_function_expr(*it).to_narwhals() -def mean_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr: +def mean_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = parse.parse_into_seq_of_expr_ir(*exprs) return F.MeanHorizontal().to_function_expr(*it).to_narwhals() @@ -135,7 +135,7 @@ def concat_str( *more_exprs: IntoExpr, separator: str = "", ignore_nulls: bool = False, -) -> DummyExpr: +) -> Expr: it = parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) return ( ConcatHorizontal(separator=separator, ignore_nulls=ignore_nulls) @@ -162,11 +162,11 @@ def when( ... .otherwise(4) ... ) >>> when_then_many - Narwhals DummyExpr (main): + nw._plan.Expr(main): .when([(col('x')) == (lit(str: a))]).then(lit(int: 1)).otherwise(.when([(col('x')) == (lit(str: b))]).then(lit(int: 2)).otherwise(.when([(col('x')) == (lit(str: c))]).then(lit(int: 3)).otherwise(lit(int: 4)))) >>> >>> nwd.when(nwd.col("y") == "b").then(1) - Narwhals DummyExpr (main): + nw._plan.Expr(main): .when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null)) """ condition = parse.parse_predicates_constraints_into_expr_ir( @@ -182,7 +182,7 @@ def int_range( *, dtype: IntegerType | type[IntegerType] = Version.MAIN.dtypes.Int64, eager: bool = False, -) -> DummyExpr: +) -> Expr: if end is None: end = start start = 0 @@ -219,7 +219,7 @@ def _order_dependent_error(node: agg.OrderableAggExpr) -> OrderDependentExprErro return OrderDependentExprError(msg) -def ensure_orderable_rules(*exprs: DummyExpr) -> tuple[DummyExpr, ...]: +def ensure_orderable_rules(*exprs: Expr) -> tuple[Expr, ...]: for expr in exprs: node = expr._ir if isinstance(node, agg.OrderableAggExpr): diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index a67e268751..5752e2f157 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -37,7 +37,7 @@ from collections.abc import Iterable, Sequence import pyarrow as pa - from typing_extensions import Never, Self, TypeAlias + from typing_extensions import Never, Self from narwhals._plan.categorical import ExprCatNamespace from narwhals._plan.common import ExprIR @@ -45,9 +45,9 @@ from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace from narwhals._plan.protocols import ( - DummyCompliantDataFrame, - DummyCompliantFrame, - DummyCompliantSeries, + CompliantBaseFrame, + CompliantDataFrame, + 
CompliantSeries, ) from narwhals._plan.schema import FrozenSchema from narwhals._plan.strings import ExprStringNamespace @@ -68,12 +68,6 @@ ) -CompliantFrame: TypeAlias = "DummyCompliantFrame[t.Any, NativeFrameT]" -CompliantDataFrame: TypeAlias = ( - "DummyCompliantDataFrame[t.Any, NativeFrameT, NativeSeriesT]" -) - - # NOTE: Trying to keep consistent logic between `DataFrame.sort` and `Expr.sort_by` def _parse_sort_by( by: IntoExpr | Iterable[IntoExpr] = (), @@ -91,16 +85,16 @@ def _parse_sort_by( # NOTE: Overly simplified placeholders for mocking typing # Entirely ignoring namespace + function binding -class DummyExpr: +class Expr: _ir: ExprIR _version: t.ClassVar[Version] = Version.MAIN def __repr__(self) -> str: - return f"Narwhals DummyExpr ({self.version.name.lower()}):\n{self._ir!r}" + return f"nw._plan.Expr({self.version.name.lower()}):\n{self._ir!r}" def __str__(self) -> str: """Use `print(self)` for formatting.""" - return f"Narwhals DummyExpr ({self.version.name.lower()}):\n{self._ir!s}" + return f"nw._plan.Expr({self.version.name.lower()}):\n{self._ir!s}" def _repr_html_(self) -> str: return self._ir._repr_html_() @@ -693,25 +687,25 @@ def str(self) -> ExprStringNamespace: return ExprStringNamespace(_expr=self) -class DummySelector(DummyExpr): +class Selector(Expr): """Selectors placeholder. Examples: >>> from narwhals._plan import selectors as ncs >>> >>> (ncs.matches("[^z]a") & ncs.string()) | ncs.datetime("us", None) - Narwhals DummySelector (main): + nw._plan.Selector(main): [([(ncs.matches(pattern='[^z]a')) & (ncs.string())]) | (ncs.datetime(time_unit=['us'], time_zone=[None]))] >>> >>> ~(ncs.boolean() | ncs.matches(r"is_.*")) - Narwhals DummySelector (main): + nw._plan.Selector(main): ~[(ncs.boolean()) | (ncs.matches(pattern='is_.*'))] """ _ir: expr.SelectorIR def __repr__(self) -> str: - return f"Narwhals DummySelector ({self.version.name.lower()}):\n{self._ir!r}" + return f"nw._plan.Selector({self.version.name.lower()}):\n{self._ir!r}" @classmethod def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] @@ -719,14 +713,14 @@ def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] obj._ir = ir return obj - def _to_expr(self) -> DummyExpr: + def _to_expr(self) -> Expr: return self._ir.to_narwhals(self.version) @t.overload # type: ignore[override] def __or__(self, other: Self) -> Self: ... @t.overload - def __or__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... - def __or__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: + def __or__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __or__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if isinstance(other, type(self)): op = ops.Or() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) @@ -735,8 +729,8 @@ def __or__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: @t.overload # type: ignore[override] def __and__(self, other: Self) -> Self: ... @t.overload - def __and__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... - def __and__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: + def __and__(self, other: IntoExprColumn | int | bool) -> Expr: ... 
+ def __and__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): other = by_name(name) if isinstance(other, type(self)): @@ -747,8 +741,8 @@ def __and__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: @t.overload # type: ignore[override] def __sub__(self, other: Self) -> Self: ... @t.overload - def __sub__(self, other: IntoExpr) -> DummyExpr: ... - def __sub__(self, other: IntoExpr) -> Self | DummyExpr: + def __sub__(self, other: IntoExpr) -> Expr: ... + def __sub__(self, other: IntoExpr) -> Self | Expr: if isinstance(other, type(self)): op = ops.Sub() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) @@ -757,8 +751,8 @@ def __sub__(self, other: IntoExpr) -> Self | DummyExpr: @t.overload # type: ignore[override] def __xor__(self, other: Self) -> Self: ... @t.overload - def __xor__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... - def __xor__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: + def __xor__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if isinstance(other, type(self)): op = ops.ExclusiveOr() return self._from_ir(op.to_binary_selector(self._ir, other._ir)) @@ -767,7 +761,7 @@ def __xor__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: def __invert__(self) -> Self: return self._from_ir(expr.InvertSelector(selector=self._ir)) - def __add__(self, other: t.Any) -> DummyExpr: # type: ignore[override] + def __add__(self, other: t.Any) -> Expr: # type: ignore[override] if isinstance(other, type(self)): msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" raise TypeError(msg) @@ -784,8 +778,8 @@ def __rsub__(self, other: t.Any) -> Never: @t.overload # type: ignore[override] def __rand__(self, other: Self) -> Self: ... @t.overload - def __rand__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... - def __rand__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: + def __rand__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __rand__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) & self return self._to_expr().__rand__(other) @@ -793,8 +787,8 @@ def __rand__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: @t.overload # type: ignore[override] def __ror__(self, other: Self) -> Self: ... @t.overload - def __ror__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... - def __ror__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: + def __ror__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __ror__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) | self return self._to_expr().__ror__(other) @@ -802,44 +796,23 @@ def __ror__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: @t.overload # type: ignore[override] def __rxor__(self, other: Self) -> Self: ... @t.overload - def __rxor__(self, other: IntoExprColumn | int | bool) -> DummyExpr: ... - def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | DummyExpr: + def __rxor__(self, other: IntoExprColumn | int | bool) -> Expr: ... 
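    # NOTE: As in `__rand__`/`__ror__` above, a bare column on the other side is
    # first rewritten as `by_name(<its output name>)`, so e.g.
    # `nwd.col("a") ^ ncs.numeric()` stays a selector; anything else falls back
    # to the plain `Expr` behaviour via `_to_expr()`.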
+ def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) ^ self return self._to_expr().__rxor__(other) -class DummyExprV1(DummyExpr): +class ExprV1(Expr): _version: t.ClassVar[Version] = Version.V1 -class DummySelectorV1(DummySelector): +class SelectorV1(Selector): _version: t.ClassVar[Version] = Version.V1 -class DummyCompliantExpr: - _ir: ExprIR - _version: Version - - @property - def version(self) -> Version: - return self._version - - @classmethod - def _from_ir(cls, ir: ExprIR, /, version: Version) -> Self: - obj = cls.__new__(cls) - obj._ir = ir - obj._version = version - return obj - - def to_narwhals(self) -> DummyExpr: - if self.version is Version.MAIN: - return DummyExpr._from_ir(self._ir) - return DummyExprV1._from_ir(self._ir) - - -class DummyFrame(Generic[NativeFrameT]): - _compliant: CompliantFrame[NativeFrameT] +class BaseFrame(Generic[NativeFrameT]): + _compliant: CompliantBaseFrame[t.Any, NativeFrameT] _version: t.ClassVar[Version] = Version.MAIN @property @@ -862,7 +835,9 @@ def from_native(cls, native: t.Any, /) -> Self: raise NotImplementedError @classmethod - def _from_compliant(cls, compliant: CompliantFrame[NativeFrameT], /) -> Self: + def _from_compliant( + cls, compliant: CompliantBaseFrame[t.Any, NativeFrameT], / + ) -> Self: obj = cls.__new__(cls) obj._compliant = compliant return obj @@ -915,18 +890,18 @@ def sort( return self._from_compliant(self._compliant.sort(named_irs, opts)) -class DummyDataFrame(DummyFrame[NativeFrameT], Generic[NativeFrameT, NativeSeriesT]): - _compliant: CompliantDataFrame[NativeFrameT, NativeSeriesT] +class DataFrame(BaseFrame[NativeFrameT], Generic[NativeFrameT, NativeSeriesT]): + _compliant: CompliantDataFrame[t.Any, NativeFrameT, NativeSeriesT] @property - def _series(self) -> type[DummySeries[NativeSeriesT]]: - return DummySeries[NativeSeriesT] + def _series(self) -> type[Series[NativeSeriesT]]: + return Series[NativeSeriesT] # NOTE: Gave up on trying to get typing working for now @classmethod def from_native( # type: ignore[override] cls, native: NativeFrame, / - ) -> DummyDataFrame[pa.Table, pa.ChunkedArray[t.Any]]: + ) -> DataFrame[pa.Table, pa.ChunkedArray[t.Any]]: if is_pyarrow_table(native): from narwhals._plan.arrow.dataframe import ArrowDataFrame @@ -937,7 +912,7 @@ def from_native( # type: ignore[override] @t.overload def to_dict( self, *, as_series: t.Literal[True] = ... - ) -> dict[str, DummySeries[NativeSeriesT]]: ... + ) -> dict[str, Series[NativeSeriesT]]: ... @t.overload def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... @@ -945,11 +920,11 @@ def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... @t.overload def to_dict( self, *, as_series: bool - ) -> dict[str, DummySeries[NativeSeriesT]] | dict[str, list[t.Any]]: ... + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[t.Any]]: ... 
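    # NOTE: Per the overloads above, `as_series=True` (the default) returns the
    # columns wrapped in the public `Series`, while `as_series=False` returns
    # plain Python lists, which is the form the tests later in this patch
    # compare against via `to_dict(as_series=False)`.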
def to_dict( self, *, as_series: bool = True - ) -> dict[str, DummySeries[NativeSeriesT]] | dict[str, list[t.Any]]: + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[t.Any]]: if as_series: return { key: self._series._from_compliant(value) @@ -961,8 +936,8 @@ def __len__(self) -> int: return len(self._compliant) -class DummySeries(Generic[NativeSeriesT]): - _compliant: DummyCompliantSeries[NativeSeriesT] +class Series(Generic[NativeSeriesT]): + _compliant: CompliantSeries[NativeSeriesT] _version: t.ClassVar[Version] = Version.MAIN @property @@ -981,7 +956,7 @@ def name(self) -> str: @classmethod def from_native( cls, native: NativeSeries, name: str = "", / - ) -> DummySeries[pa.ChunkedArray[t.Any]]: + ) -> Series[pa.ChunkedArray[t.Any]]: if is_pyarrow_chunked_array(native): from narwhals._plan.arrow.series import ArrowSeries @@ -992,7 +967,7 @@ def from_native( raise NotImplementedError(type(native)) @classmethod - def _from_compliant(cls, compliant: DummyCompliantSeries[NativeSeriesT], /) -> Self: + def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self: obj = cls.__new__(cls) obj._compliant = compliant return obj @@ -1007,5 +982,5 @@ def __iter__(self) -> t.Iterator[t.Any]: yield from self.to_native() -class DummySeriesV1(DummySeries[NativeSeriesT]): +class SeriesV1(Series[NativeSeriesT]): _version: t.ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index ada5bcb755..0ca3d7afa1 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -81,7 +81,7 @@ from typing_extensions import TypeAlias - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr from narwhals._plan.typing import Seq from narwhals.dtypes import DType @@ -152,7 +152,7 @@ def from_ir(ir: ExprIR, /) -> ExpansionFlags: ) @classmethod - def from_expr(cls, expr: DummyExpr, /) -> ExpansionFlags: + def from_expr(cls, expr: Expr, /) -> ExpansionFlags: return cls.from_ir(expr._ir) def with_multiple_columns(self) -> ExpansionFlags: diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index f64b98235c..ea6bf382e0 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -6,7 +6,7 @@ from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr class ListFunction(Function): ... @@ -33,5 +33,5 @@ class ExprListNamespace(ExprNamespace[IRListNamespace]): def _ir_namespace(self) -> type[IRListNamespace]: return IRListNamespace - def len(self) -> DummyExpr: + def len(self) -> Expr: return self._to_narwhals(self._ir.len().to_function_expr(self._expr._ir)) diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index a2579b72fb..7349ea24cc 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan.dummy import DummySeries + from narwhals._plan.dummy import Series from narwhals._plan.expr import Literal from narwhals.dtypes import DType @@ -55,14 +55,14 @@ def unwrap(self) -> NonNestedLiteralT: return self.value -class SeriesLiteral(LiteralValue["DummySeries[NativeSeriesT]"]): +class SeriesLiteral(LiteralValue["Series[NativeSeriesT]"]): """We already need this. 
https://github.com/narwhals-dev/narwhals/blob/e51eba891719a5eb1f7ce91c02a477af39c0baee/narwhals/_expression_parsing.py#L96-L97 """ __slots__ = ("value",) - value: DummySeries[NativeSeriesT] + value: Series[NativeSeriesT] @property def dtype(self) -> DType: @@ -75,7 +75,7 @@ def name(self) -> str: def __repr__(self) -> str: return "Series" - def unwrap(self) -> DummySeries[NativeSeriesT]: + def unwrap(self) -> Series[NativeSeriesT]: return self.value @@ -104,6 +104,6 @@ def is_literal_scalar( def is_literal_series( - obj: Literal[DummySeries[NativeSeriesT]] | Any, -) -> TypeIs[Literal[DummySeries[NativeSeriesT]]]: + obj: Literal[Series[NativeSeriesT]] | Any, +) -> TypeIs[Literal[Series[NativeSeriesT]]]: return is_literal(obj) and _is_series(obj.value) diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 3566574b42..4a71e00995 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -10,7 +10,7 @@ from typing_extensions import Self from narwhals._compliant.typing import AliasName - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr from narwhals._plan.typing import MapIR @@ -117,25 +117,25 @@ class ExprNameNamespace(ExprNamespace[IRNameNamespace]): def _ir_namespace(self) -> type[IRNameNamespace]: return IRNameNamespace - def keep(self) -> DummyExpr: + def keep(self) -> Expr: return self._to_narwhals(self._ir.keep()) - def map(self, function: AliasName) -> DummyExpr: + def map(self, function: AliasName) -> Expr: """Define an alias by mapping a function over the original root column name.""" return self._to_narwhals(self._ir.map(function)) - def prefix(self, prefix: str) -> DummyExpr: + def prefix(self, prefix: str) -> Expr: """Add a prefix to the root column name.""" return self._to_narwhals(self._ir.prefix(prefix)) - def suffix(self, suffix: str) -> DummyExpr: + def suffix(self, suffix: str) -> Expr: """Add a suffix to the root column name.""" return self._to_narwhals(self._ir.suffix(suffix)) - def to_lowercase(self) -> DummyExpr: + def to_lowercase(self) -> Expr: """Update the root column name to use lowercase characters.""" return self._to_narwhals(self._ir.to_lowercase()) - def to_uppercase(self) -> DummyExpr: + def to_uppercase(self) -> Expr: """Update the root column name to use uppercase characters.""" return self._to_narwhals(self._ir.to_uppercase()) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 233fe270ec..38bf13510d 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan.dummy import DummyDataFrame, DummyFrame, DummySeries + from narwhals._plan.dummy import BaseFrame, DataFrame, Series from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange @@ -41,14 +41,14 @@ ExprAny: TypeAlias = "CompliantExpr[Any, Any]" ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" -SeriesAny: TypeAlias = "DummyCompliantSeries[Any]" -FrameAny: TypeAlias = "DummyCompliantFrame[Any, Any]" -DataFrameAny: TypeAlias = "DummyCompliantDataFrame[Any, Any, Any]" +SeriesAny: TypeAlias = "CompliantSeries[Any]" +FrameAny: TypeAlias = "CompliantBaseFrame[Any, Any]" +DataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any]" NamespaceAny: TypeAlias = "CompliantNamespace[Any, Any, Any]" EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" EagerScalarAny: TypeAlias = "EagerScalar[Any, Any]" -EagerDataFrameAny: 
TypeAlias = "DummyEagerDataFrame[Any, Any, Any]" +EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any]" LazyExprAny: TypeAlias = "LazyExpr[Any, Any, Any]" LazyScalarAny: TypeAlias = "LazyScalar[Any, Any, Any]" @@ -627,12 +627,12 @@ def lit( ) -> EagerScalarT_co: ... @overload def lit( - self, node: expr.Literal[DummySeries[Any]], frame: EagerDataFrameT, name: str + self, node: expr.Literal[Series[Any]], frame: EagerDataFrameT, name: str ) -> EagerExprT_co: ... @overload def lit( self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries[Any]], + node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[Any]], frame: EagerDataFrameT, name: str, ) -> EagerExprT_co | EagerScalarT_co: ... @@ -658,7 +658,7 @@ def _frame(self) -> type[FrameT]: return self._lazyframe -class DummyCompliantFrame(StoresVersion, Protocol[ColumnT_co, NativeFrameT]): +class CompliantBaseFrame(StoresVersion, Protocol[ColumnT_co, NativeFrameT]): _native: NativeFrameT def __narwhals_namespace__(self) -> Any: ... @@ -668,7 +668,7 @@ def native(self) -> NativeFrameT: @property def columns(self) -> list[str]: ... - def to_narwhals(self) -> DummyFrame[NativeFrameT]: ... + def to_narwhals(self) -> BaseFrame[NativeFrameT]: ... @classmethod def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: @@ -690,8 +690,8 @@ def with_columns(self, irs: Seq[NamedIR]) -> Self: ... def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... -class DummyCompliantDataFrame( - DummyCompliantFrame[SeriesT, NativeFrameT], +class CompliantDataFrame( + CompliantBaseFrame[SeriesT, NativeFrameT], Protocol[SeriesT, NativeFrameT, NativeSeriesT], ): @classmethod @@ -703,7 +703,7 @@ def from_dict( schema: Mapping[str, DType] | Schema | None = None, ) -> Self: ... - def to_narwhals(self) -> DummyDataFrame[NativeFrameT, NativeSeriesT]: ... + def to_narwhals(self) -> DataFrame[NativeFrameT, NativeSeriesT]: ... @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... @@ -722,8 +722,8 @@ def __len__(self) -> int: ... def with_row_index(self, name: str) -> Self: ... -class DummyEagerDataFrame( - DummyCompliantDataFrame[SeriesT, NativeFrameT, NativeSeriesT], +class EagerDataFrame( + CompliantDataFrame[SeriesT, NativeFrameT, NativeSeriesT], Protocol[SeriesT, NativeFrameT, NativeSeriesT], ): def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... @@ -736,7 +736,7 @@ def with_columns(self, irs: Seq[NamedIR]) -> Self: return ns._concat_horizontal(self._evaluate_irs(irs)) -class DummyCompliantSeries(StoresVersion, Protocol[NativeSeriesT]): +class CompliantSeries(StoresVersion, Protocol[NativeSeriesT]): _native: NativeSeriesT _name: str @@ -754,10 +754,10 @@ def dtype(self) -> DType: ... 
def name(self) -> str: return self._name - def to_narwhals(self) -> DummySeries[NativeSeriesT]: - from narwhals._plan.dummy import DummySeries + def to_narwhals(self) -> Series[NativeSeriesT]: + from narwhals._plan.dummy import Series - return DummySeries[NativeSeriesT]._from_compliant(self) + return Series[NativeSeriesT]._from_compliant(self) @classmethod def from_native( diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 2d82bbeccd..3cd7666ddc 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -17,7 +17,7 @@ from datetime import timezone from typing import TypeVar - from narwhals._plan.dummy import DummySelector + from narwhals._plan import dummy from narwhals._plan.expr import RootSelector from narwhals.dtypes import DType from narwhals.typing import TimeUnit @@ -154,32 +154,32 @@ def matches_column(self, name: str, dtype: DType) -> bool: return isinstance(dtype, dtypes.String) -def all() -> DummySelector: +def all() -> dummy.Selector: return All().to_selector().to_narwhals() def by_dtype( *dtypes: DType | type[DType] | Iterable[DType | type[DType]], -) -> DummySelector: +) -> dummy.Selector: return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals() -def by_name(*names: str | Iterable[str]) -> DummySelector: +def by_name(*names: str | Iterable[str]) -> dummy.Selector: return Matches.from_names(*names).to_selector().to_narwhals() -def boolean() -> DummySelector: +def boolean() -> dummy.Selector: return Boolean().to_selector().to_narwhals() -def categorical() -> DummySelector: +def categorical() -> dummy.Selector: return Categorical().to_selector().to_narwhals() def datetime( time_unit: TimeUnit | Iterable[TimeUnit] | None = None, time_zone: str | timezone | Iterable[str | timezone | None] | None = ("*", None), -) -> DummySelector: +) -> dummy.Selector: return ( Datetime.from_time_unit_and_time_zone(time_unit, time_zone) .to_selector() @@ -187,13 +187,13 @@ def datetime( ) -def matches(pattern: str) -> DummySelector: +def matches(pattern: str) -> dummy.Selector: return Matches.from_string(pattern).to_selector().to_narwhals() -def numeric() -> DummySelector: +def numeric() -> dummy.Selector: return Numeric().to_selector().to_narwhals() -def string() -> DummySelector: +def string() -> dummy.Selector: return String().to_selector().to_narwhals() diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index c13cdf4486..f0eec5f138 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -6,7 +6,7 @@ from narwhals._plan.options import FunctionFlags, FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr class StringFunction(Function): @@ -197,68 +197,66 @@ class ExprStringNamespace(ExprNamespace[IRStringNamespace]): def _ir_namespace(self) -> type[IRStringNamespace]: return IRStringNamespace - def len_chars(self) -> DummyExpr: + def len_chars(self) -> Expr: return self._to_narwhals(self._ir.len_chars().to_function_expr(self._expr._ir)) def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> DummyExpr: + ) -> Expr: return self._to_narwhals( self._ir.replace(pattern, value, literal=literal, n=n).to_function_expr( self._expr._ir ) ) - def replace_all( - self, pattern: str, value: str, *, literal: bool = False - ) -> DummyExpr: + def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Expr: return self._to_narwhals( self._ir.replace_all(pattern, value, literal=literal).to_function_expr( 
self._expr._ir ) ) - def strip_chars(self, characters: str | None = None) -> DummyExpr: + def strip_chars(self, characters: str | None = None) -> Expr: return self._to_narwhals( self._ir.strip_chars(characters).to_function_expr(self._expr._ir) ) - def starts_with(self, prefix: str) -> DummyExpr: + def starts_with(self, prefix: str) -> Expr: return self._to_narwhals( self._ir.starts_with(prefix).to_function_expr(self._expr._ir) ) - def ends_with(self, suffix: str) -> DummyExpr: + def ends_with(self, suffix: str) -> Expr: return self._to_narwhals( self._ir.ends_with(suffix).to_function_expr(self._expr._ir) ) - def contains(self, pattern: str, *, literal: bool = False) -> DummyExpr: + def contains(self, pattern: str, *, literal: bool = False) -> Expr: return self._to_narwhals( self._ir.contains(pattern, literal=literal).to_function_expr(self._expr._ir) ) - def slice(self, offset: int, length: int | None = None) -> DummyExpr: + def slice(self, offset: int, length: int | None = None) -> Expr: return self._to_narwhals( self._ir.slice(offset, length).to_function_expr(self._expr._ir) ) - def head(self, n: int = 5) -> DummyExpr: + def head(self, n: int = 5) -> Expr: return self._to_narwhals(self._ir.head(n).to_function_expr(self._expr._ir)) - def tail(self, n: int = 5) -> DummyExpr: + def tail(self, n: int = 5) -> Expr: return self._to_narwhals(self._ir.tail(n).to_function_expr(self._expr._ir)) - def split(self, by: str) -> DummyExpr: + def split(self, by: str) -> Expr: return self._to_narwhals(self._ir.split(by).to_function_expr(self._expr._ir)) - def to_datetime(self, format: str | None = None) -> DummyExpr: + def to_datetime(self, format: str | None = None) -> Expr: return self._to_narwhals( self._ir.to_datetime(format).to_function_expr(self._expr._ir) ) - def to_lowercase(self) -> DummyExpr: + def to_lowercase(self) -> Expr: return self._to_narwhals(self._ir.to_lowercase().to_function_expr(self._expr._ir)) - def to_uppercase(self) -> DummyExpr: + def to_uppercase(self) -> Expr: return self._to_narwhals(self._ir.to_uppercase().to_function_expr(self._expr._ir)) diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index d9dd42a730..1f9abaf014 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -6,7 +6,7 @@ from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr class StructFunction(Function): ... 
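# --- illustrative sketch ---------------------------------------------------
# A small, self-contained sketch of the idea behind `map_ir` (see the `MapIR`
# alias in `typing.py` and the `test_map_ir_recursive` test later in this
# patch): nodes are immutable, so a rewrite rebuilds the tree bottom-up by
# mapping a function over every node. The node classes here are simplified
# stand-ins, not the actual narwhals classes.
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Callable, Union

Node = Union["Column", "Alias", "FunctionExpr"]


@dataclass(frozen=True)
class Column:
    name: str


@dataclass(frozen=True)
class Alias:
    expr: Node
    name: str


@dataclass(frozen=True)
class FunctionExpr:
    function: str
    input: tuple[Node, ...]


def map_ir(node: Node, fn: Callable[[Node], Node]) -> Node:
    # Rewrite the children first, then let `fn` rewrite the rebuilt node.
    if isinstance(node, Alias):
        node = replace(node, expr=map_ir(node.expr, fn))
    elif isinstance(node, FunctionExpr):
        node = replace(node, input=tuple(map_ir(i, fn) for i in node.input))
    return fn(node)


def rename_a_to_b(node: Node) -> Node:
    return Column("b") if isinstance(node, Column) and node.name == "a" else node


plan = Alias(FunctionExpr("str.to_lowercase", (Column("a"),)), "lower")
assert map_ir(plan, rename_a_to_b) == Alias(
    FunctionExpr("str.to_lowercase", (Column("b"),)), "lower"
)
# ---------------------------------------------------------------------------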
@@ -36,5 +36,5 @@ class ExprStructNamespace(ExprNamespace[IRStructNamespace]): def _ir_namespace(self) -> type[IRStructNamespace]: return IRStructNamespace - def field(self, name: str) -> DummyExpr: + def field(self, name: str) -> Expr: return self._to_narwhals(self._ir.field(name).to_function_expr(self._expr._ir)) diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 11956622b7..0ff0bb086c 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -9,7 +9,7 @@ from typing_extensions import TypeAlias, TypeIs from narwhals._duration import Interval, IntervalUnit - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr from narwhals.typing import TimeUnit PolarsTimeUnit: TypeAlias = Literal["ns", "us", "ms"] @@ -229,92 +229,92 @@ class ExprDateTimeNamespace(ExprNamespace[IRDateTimeNamespace]): def _ir_namespace(self) -> type[IRDateTimeNamespace]: return IRDateTimeNamespace - def date(self) -> DummyExpr: + def date(self) -> Expr: return self._to_narwhals(self._ir.date().to_function_expr(self._expr._ir)) - def year(self) -> DummyExpr: + def year(self) -> Expr: return self._to_narwhals(self._ir.year().to_function_expr(self._expr._ir)) - def month(self) -> DummyExpr: + def month(self) -> Expr: return self._to_narwhals(self._ir.month().to_function_expr(self._expr._ir)) - def day(self) -> DummyExpr: + def day(self) -> Expr: return self._to_narwhals(self._ir.day().to_function_expr(self._expr._ir)) - def hour(self) -> DummyExpr: + def hour(self) -> Expr: return self._to_narwhals(self._ir.hour().to_function_expr(self._expr._ir)) - def minute(self) -> DummyExpr: + def minute(self) -> Expr: return self._to_narwhals(self._ir.minute().to_function_expr(self._expr._ir)) - def second(self) -> DummyExpr: + def second(self) -> Expr: return self._to_narwhals(self._ir.second().to_function_expr(self._expr._ir)) - def millisecond(self) -> DummyExpr: + def millisecond(self) -> Expr: return self._to_narwhals(self._ir.millisecond().to_function_expr(self._expr._ir)) - def microsecond(self) -> DummyExpr: + def microsecond(self) -> Expr: return self._to_narwhals(self._ir.microsecond().to_function_expr(self._expr._ir)) - def nanosecond(self) -> DummyExpr: + def nanosecond(self) -> Expr: return self._to_narwhals(self._ir.nanosecond().to_function_expr(self._expr._ir)) - def ordinal_day(self) -> DummyExpr: + def ordinal_day(self) -> Expr: return self._to_narwhals(self._ir.ordinal_day().to_function_expr(self._expr._ir)) - def weekday(self) -> DummyExpr: + def weekday(self) -> Expr: return self._to_narwhals(self._ir.weekday().to_function_expr(self._expr._ir)) - def total_minutes(self) -> DummyExpr: + def total_minutes(self) -> Expr: return self._to_narwhals( self._ir.total_minutes().to_function_expr(self._expr._ir) ) - def total_seconds(self) -> DummyExpr: + def total_seconds(self) -> Expr: return self._to_narwhals( self._ir.total_seconds().to_function_expr(self._expr._ir) ) - def total_milliseconds(self) -> DummyExpr: + def total_milliseconds(self) -> Expr: return self._to_narwhals( self._ir.total_milliseconds().to_function_expr(self._expr._ir) ) - def total_microseconds(self) -> DummyExpr: + def total_microseconds(self) -> Expr: return self._to_narwhals( self._ir.total_microseconds().to_function_expr(self._expr._ir) ) - def total_nanoseconds(self) -> DummyExpr: + def total_nanoseconds(self) -> Expr: return self._to_narwhals( self._ir.total_nanoseconds().to_function_expr(self._expr._ir) ) - def to_string(self, format: str) -> DummyExpr: + def to_string(self, 
format: str) -> Expr: return self._to_narwhals( self._ir.to_string(format=format).to_function_expr(self._expr._ir) ) - def replace_time_zone(self, time_zone: str | None) -> DummyExpr: + def replace_time_zone(self, time_zone: str | None) -> Expr: return self._to_narwhals( self._ir.replace_time_zone(time_zone=time_zone).to_function_expr( self._expr._ir ) ) - def convert_time_zone(self, time_zone: str) -> DummyExpr: + def convert_time_zone(self, time_zone: str) -> Expr: return self._to_narwhals( self._ir.convert_time_zone(time_zone=time_zone).to_function_expr( self._expr._ir ) ) - def timestamp(self, time_unit: TimeUnit = "us") -> DummyExpr: + def timestamp(self, time_unit: TimeUnit = "us") -> Expr: return self._to_narwhals( self._ir.timestamp(time_unit=time_unit).to_function_expr(self._expr._ir) ) - def truncate(self, every: str) -> DummyExpr: + def truncate(self, every: str) -> Expr: return self._to_narwhals( self._ir.truncate(every=every).to_function_expr(self._expr._ir) ) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 6dfdaa8d21..41861220a0 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -10,7 +10,7 @@ from narwhals import dtypes from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR - from narwhals._plan.dummy import DummyExpr, DummySeries + from narwhals._plan.dummy import Expr, Series from narwhals._plan.functions import RollingWindow from narwhals._plan.ranges import RangeFunction from narwhals.typing import ( @@ -71,9 +71,7 @@ ) NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries") NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame", default="NativeFrame") -LiteralT = TypeVar( - "LiteralT", bound="NonNestedLiteral | DummySeries[t.Any]", default=t.Any -) +LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | Series[t.Any]", default=t.Any) MapIR: TypeAlias = "t.Callable[[ExprIR], ExprIR]" """A function to apply to all nodes in this tree.""" @@ -88,5 +86,5 @@ Udf: TypeAlias = "t.Callable[[t.Any], t.Any]" """Placeholder for `map_batches(function=...)`.""" -IntoExprColumn: TypeAlias = "DummyExpr | DummySeries[t.Any] | str" +IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index ef7ab9f6b3..429f8ed045 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any from narwhals._plan.common import Immutable, is_expr -from narwhals._plan.dummy import DummyExpr +from narwhals._plan.dummy import Expr from narwhals._plan.expr_parsing import ( parse_into_expr_ir, parse_predicates_constraints_into_expr_ir, @@ -25,7 +25,7 @@ def then(self, expr: IntoExpr, /) -> Then: return Then(condition=self.condition, statement=parse_into_expr_ir(expr)) @staticmethod - def _from_expr(expr: DummyExpr, /) -> When: + def _from_expr(expr: Expr, /) -> When: return When(condition=expr._ir) @staticmethod @@ -33,7 +33,7 @@ def _from_ir(ir: ExprIR, /) -> When: return When(condition=ir) -class Then(Immutable, DummyExpr): +class Then(Immutable, Expr): __slots__ = ("condition", "statement") condition: ExprIR statement: ExprIR @@ -46,7 +46,7 @@ def when( conditions=(self.condition, condition), statements=(self.statement,) ) - def otherwise(self, statement: IntoExpr, /) -> DummyExpr: + def otherwise(self, statement: IntoExpr, /) -> Expr: return 
self._from_ir(self._otherwise(statement)) def _otherwise(self, statement: IntoExpr = None, /) -> ExprIR: @@ -57,12 +57,12 @@ def _ir(self) -> ExprIR: # type: ignore[override] return self._otherwise() @classmethod - def _from_ir(cls, ir: ExprIR, /) -> DummyExpr: # type: ignore[override] - return DummyExpr._from_ir(ir) + def _from_ir(cls, ir: ExprIR, /) -> Expr: # type: ignore[override] + return Expr._from_ir(ir) - def __eq__(self, value: object) -> DummyExpr | bool: # type: ignore[override] + def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] if is_expr(value): - return super(DummyExpr, self).__eq__(value) + return super(Expr, self).__eq__(value) return super().__eq__(value) @@ -78,7 +78,7 @@ def then(self, statement: IntoExpr, /) -> ChainedThen: ) -class ChainedThen(Immutable, DummyExpr): +class ChainedThen(Immutable, Expr): """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130.""" __slots__ = ("conditions", "statements") @@ -93,7 +93,7 @@ def when( conditions=(*self.conditions, condition), statements=self.statements ) - def otherwise(self, statement: IntoExpr, /) -> DummyExpr: + def otherwise(self, statement: IntoExpr, /) -> Expr: return self._from_ir(self._otherwise(statement)) def _otherwise(self, statement: IntoExpr = None, /) -> ExprIR: @@ -109,12 +109,12 @@ def _ir(self) -> ExprIR: # type: ignore[override] return self._otherwise() @classmethod - def _from_ir(cls, ir: ExprIR, /) -> DummyExpr: # type: ignore[override] - return DummyExpr._from_ir(ir) + def _from_ir(cls, ir: ExprIR, /) -> Expr: # type: ignore[override] + return Expr._from_ir(ir) - def __eq__(self, value: object) -> DummyExpr | bool: # type: ignore[override] + def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] if is_expr(value): - return super(DummyExpr, self).__eq__(value) + return super(Expr, self).__eq__(value) return super().__eq__(value) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index b6a868ccf4..7d7f1f6248 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -12,7 +12,7 @@ import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs from narwhals._plan.common import is_expr -from narwhals._plan.dummy import DummyDataFrame +from narwhals._plan.dummy import DataFrame from narwhals._utils import Version from narwhals.exceptions import ComputeError from tests.utils import assert_equal_data @@ -20,7 +20,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr from narwhals.typing import PythonLiteral @@ -64,7 +64,7 @@ def data_indexed() -> dict[str, Any]: } -def _ids_ir(expr: DummyExpr | Any) -> str: +def _ids_ir(expr: Expr | Any) -> str: if is_expr(expr): return repr(expr._ir) return repr(expr) @@ -403,12 +403,10 @@ def _ids_ir(expr: DummyExpr | Any) -> str: ids=_ids_ir, ) def test_select( - expr: DummyExpr | Sequence[DummyExpr], - expected: dict[str, Any], - data_small: dict[str, Any], + expr: Expr | Sequence[Expr], expected: dict[str, Any], data_small: dict[str, Any] ) -> None: frame = pa.table(data_small) - df = DummyDataFrame.from_native(frame) + df = DataFrame.from_native(frame) result = df.select(expr).to_dict(as_series=False) assert_equal_data(result, expected) @@ -477,21 +475,19 @@ def test_select( ], ) def test_with_columns( - expr: DummyExpr | Sequence[DummyExpr], - expected: dict[str, Any], - data_smaller: dict[str, Any], + expr: 
Expr | Sequence[Expr], expected: dict[str, Any], data_smaller: dict[str, Any] ) -> None: frame = pa.table(data_smaller) - df = DummyDataFrame.from_native(frame) + df = DataFrame.from_native(frame) result = df.with_columns(expr).to_dict(as_series=False) assert_equal_data(result, expected) -def first(*names: str) -> DummyExpr: +def first(*names: str) -> Expr: return nwd.col(*names).first() -def last(*names: str) -> DummyExpr: +def last(*names: str) -> Expr: return nwd.col(*names).last() @@ -507,12 +503,12 @@ def last(*names: str) -> DummyExpr: ], ) def test_first_last_expr_with_columns( - data_indexed: dict[str, Any], agg: DummyExpr, expected: PythonLiteral + data_indexed: dict[str, Any], agg: Expr, expected: PythonLiteral ) -> None: """Related https://github.com/narwhals-dev/narwhals/pull/2528#discussion_r2225930065.""" height = len(next(iter(data_indexed.values()))) expected_broadcast = height * [expected] - frame = DummyDataFrame.from_native(pa.table(data_indexed)) + frame = DataFrame.from_native(pa.table(data_indexed)) expr = agg.over(order_by="idx").alias("result") result = frame.with_columns(expr).select("result").to_dict(as_series=False) assert_equal_data(result, {"result": expected_broadcast}) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 93ef769dd7..8396db6c74 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -22,7 +22,7 @@ from collections.abc import Iterable, Sequence from narwhals._plan.common import ExprIR - from narwhals._plan.dummy import DummyExpr, DummySelector + from narwhals._plan.dummy import Expr, Selector from narwhals._plan.typing import IntoExpr, MapIR from narwhals.dtypes import DType @@ -110,7 +110,7 @@ def udf_name_map(name: str) -> str: ), ], ) -def test_rewrite_special_aliases_single(expr: DummyExpr, expected: str) -> None: +def test_rewrite_special_aliases_single(expr: Expr, expected: str) -> None: # NOTE: We can't use `output_name()` without resolving these rewrites # Once they're done, `output_name()` just peeks into `Alias(name=...)` ir_input = expr._ir @@ -180,7 +180,7 @@ def fn(ir: ExprIR) -> ExprIR: ), ], ) -def test_map_ir_recursive(expr: DummyExpr, function: MapIR, expected: DummyExpr) -> None: +def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: actual = expr._ir.map_ir(function) assert_expr_ir_equal(actual, expected) @@ -248,9 +248,7 @@ def test_map_ir_recursive(expr: DummyExpr, function: MapIR, expected: DummyExpr) ], ) def test_replace_selector( - expr: DummySelector | DummyExpr, - expected: DummyExpr | ExprIR, - schema_1: dict[str, DType], + expr: Selector | Expr, expected: Expr | ExprIR, schema_1: dict[str, DType] ) -> None: actual = replace_selector(expr._ir, schema=freeze_schema(**schema_1)) assert_expr_ir_equal(actual, expected) @@ -424,7 +422,7 @@ def test_replace_selector( ) def test_prepare_projection( into_exprs: IntoExpr | Sequence[IntoExpr], - expected: Sequence[DummyExpr], + expected: Sequence[Expr], schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) @@ -449,9 +447,7 @@ def test_prepare_projection( *MULTI_OUTPUT_EXPRS, ], ) -def test_prepare_projection_duplicate( - expr: DummyExpr, schema_1: dict[str, DType] -) -> None: +def test_prepare_projection_duplicate(expr: Expr, schema_1: dict[str, DType]) -> None: irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) pattern = re.compile(r"\.alias\(.dupe.\)") with pytest.raises(DuplicateError, match=pattern): @@ -551,7 +547,7 @@ def 
test_prepare_projection_column_not_found( ) def test_prepare_projection_horizontal_alias( into_exprs: IntoExpr | Iterable[IntoExpr], - function: Callable[..., DummyExpr], + function: Callable[..., Expr], schema_1: dict[str, DType], ) -> None: # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139965411 diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 7b65848597..52068da571 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -13,7 +13,7 @@ import narwhals._plan.demo as nwd from narwhals._plan import boolean, expr, functions as F, operators as ops from narwhals._plan.common import ExprIR, Function -from narwhals._plan.dummy import DummyExpr, DummySeries +from narwhals._plan.dummy import Expr, Series from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals._plan.literal import SeriesLiteral @@ -83,14 +83,14 @@ def test_parsing( ], ) def test_function_expr_horizontal( - function: Callable[..., DummyExpr], + function: Callable[..., Expr], ir_node: type[Function], args: Seq[IntoExpr | Iterable[IntoExpr]], ) -> None: variadic = function(*args) sequence = function(args) - assert isinstance(variadic, DummyExpr) - assert isinstance(sequence, DummyExpr) + assert isinstance(variadic, Expr) + assert isinstance(sequence, Expr) variadic_node = variadic._ir sequence_node = sequence._ir unrelated_node = nwd.lit(1)._ir @@ -247,7 +247,7 @@ def test_invalid_binary_expr_length_changing() -> None: a.map_batches(lambda x: x) / b.gather_every(1, 0) -def _is_expr_ir_binary_expr(expr: DummyExpr) -> bool: +def _is_expr_ir_binary_expr(expr: Expr) -> bool: return isinstance(expr._ir, BinaryExpr) @@ -300,7 +300,7 @@ def test_is_in_series() -> None: import pyarrow as pa native = pa.chunked_array([pa.array([1, 2, 3])]) - other = DummySeries.from_native(native) + other = Series.from_native(native) expr = nwd.col("a").is_in(other) ir = expr._ir assert isinstance(ir, FunctionExpr) @@ -393,7 +393,7 @@ def test_lit_series_roundtrip() -> None: data = ["a", "b", "c"] native = pa.chunked_array([pa.array(data)]) - series = DummySeries.from_native(native) + series = Series.from_native(native) lit_series = nwd.lit(series) assert lit_series.meta.is_literal() ir = lit_series._ir @@ -401,7 +401,7 @@ def test_lit_series_roundtrip() -> None: assert isinstance(ir.dtype, nw.String) assert isinstance(ir.value, SeriesLiteral) unwrapped = ir.unwrap() - assert isinstance(unwrapped, DummySeries) + assert isinstance(unwrapped, Series) assert isinstance(unwrapped.to_native(), pa.ChunkedArray) assert unwrapped.to_list() == data @@ -443,8 +443,8 @@ def test_operators_left_right( } result_1 = function(arg_1, arg_2) result_2 = function(arg_2, arg_1) - assert isinstance(result_1, DummyExpr) - assert isinstance(result_2, DummyExpr) + assert isinstance(result_1, Expr) + assert isinstance(result_2, Expr) ir_1 = result_1._ir ir_2 = result_2._ir if op in {ops.Eq, ops.NotEq}: diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 9f67b800e2..8e5dd0f29c 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -18,7 +18,7 @@ from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr from narwhals._plan.typing import IntoExpr from narwhals.dtypes import DType @@ -81,7 +81,7 @@ def 
test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: assert_expr_ir_equal(lhs, rhs) -def named_ir(name: str, expr: DummyExpr | ExprIR, /) -> NamedIR[ExprIR]: +def named_ir(name: str, expr: Expr | ExprIR, /) -> NamedIR[ExprIR]: """Helper constructor for test compare.""" ir = expr._ir if is_expr(expr) else expr return NamedIR(expr=ir, name=name) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 6d1f20fc68..e783e55c31 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -8,7 +8,7 @@ from tests.utils import POLARS_VERSION if TYPE_CHECKING: - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr pytest.importorskip("polars") import polars as pl @@ -51,9 +51,7 @@ (nwd.all().mean().sort_by("d"), pl.all().mean().sort_by("d"), ["d"]), ], ) -def test_meta_root_names( - nw_expr: DummyExpr, pl_expr: pl.Expr, expected: list[str] -) -> None: +def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) -> None: pl_result = pl_expr.meta.root_names() nw_result = nw_expr.meta.root_names() assert nw_result == expected @@ -181,7 +179,7 @@ def test_meta_root_names( ), ], ) -def test_meta_output_name(nw_expr: DummyExpr, pl_expr: pl.Expr, expected: str) -> None: +def test_meta_output_name(nw_expr: Expr, pl_expr: pl.Expr, expected: str) -> None: pl_result = pl_expr.meta.output_name() nw_result = nw_expr.meta.output_name() assert nw_result == expected diff --git a/tests/plan/utils.py b/tests/plan/utils.py index e02604871a..6b818f82df 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -7,10 +7,10 @@ if TYPE_CHECKING: from typing_extensions import LiteralString - from narwhals._plan.dummy import DummyExpr + from narwhals._plan.dummy import Expr -def _unwrap_ir(obj: DummyExpr | ExprIR | NamedIR) -> ExprIR: +def _unwrap_ir(obj: Expr | ExprIR | NamedIR) -> ExprIR: if is_expr(obj): return obj._ir if isinstance(obj, ExprIR): @@ -21,9 +21,7 @@ def _unwrap_ir(obj: DummyExpr | ExprIR | NamedIR) -> ExprIR: def assert_expr_ir_equal( - actual: DummyExpr | ExprIR | NamedIR, - expected: DummyExpr | ExprIR | NamedIR | LiteralString, - /, + actual: Expr | ExprIR | NamedIR, expected: Expr | ExprIR | NamedIR | LiteralString, / ) -> None: """Assert that `actual` is equivalent to `expected`. 
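A minimal usage sketch of the renamed wrappers, mirroring the updated tests above (assumes pyarrow is installed; these are the same internal modules the tests import, so treat it as illustrative rather than public API):

import pyarrow as pa

from narwhals._plan import demo as nwd
from narwhals._plan.dummy import DataFrame

# Wrap a native pyarrow table in the plan-level DataFrame, as the tests do.
df = DataFrame.from_native(pa.table({"idx": [0, 1, 2], "a": [10, 20, 30]}))
# `first()` is order-dependent, so the tests pair it with `over(order_by=...)`.
expr = nwd.col("a").first().over(order_by="idx").alias("result")
print(df.with_columns(expr).select("result").to_dict(as_series=False))
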
From 25b744d5ea1d86ab50fe26dbd15d159445bfa7f2 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 21 Aug 2025 09:26:12 +0000 Subject: [PATCH 351/368] refactor(expr-ir): Heavily sugar `Function` defs (#3017) --- narwhals/_plan/boolean.py | 112 +++------------ narwhals/_plan/categorical.py | 11 +- narwhals/_plan/common.py | 65 ++++++++- narwhals/_plan/functions.py | 260 ++++------------------------------ narwhals/_plan/lists.py | 11 +- narwhals/_plan/options.py | 6 + narwhals/_plan/ranges.py | 13 +- narwhals/_plan/strings.py | 58 +------- narwhals/_plan/struct.py | 10 +- narwhals/_plan/temporal.py | 42 +----- narwhals/_plan/typing.py | 4 + 11 files changed, 138 insertions(+), 454 deletions(-) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index a9ec2aeda7..7a3902cb6b 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -5,7 +5,7 @@ import typing as t from narwhals._plan.common import Function -from narwhals._plan.options import FunctionFlags, FunctionOptions +from narwhals._plan.options import FunctionOptions from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: @@ -21,99 +21,47 @@ ExprT = TypeVar("ExprT", bound="ExprIR", default="ExprIR") -class BooleanFunction(Function): - def __repr__(self) -> str: - tp = type(self) - if tp in {BooleanFunction, IsIn}: - return tp.__name__ - if isinstance(self, IsIn): - return "is_in" - m: dict[type[BooleanFunction], str] = { - All: "all", - Any: "any", - AllHorizontal: "all_horizontal", - AnyHorizontal: "any_horizontal", - IsBetween: "is_between", - IsDuplicated: "is_duplicated", - IsFinite: "is_finite", - IsNan: "is_nan", - IsNull: "is_null", - IsFirstDistinct: "is_first_distinct", - IsLastDistinct: "is_last_distinct", - IsUnique: "is_unique", - Not: "not", - } - return m[tp] - - -class All(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.aggregation() - - -class AllHorizontal(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise().with_flags( - FunctionFlags.INPUT_WILDCARD_EXPANSION - ) +class BooleanFunction(Function): ... -class Any(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.aggregation() +class All(BooleanFunction, options=FunctionOptions.aggregation): ... -class AnyHorizontal(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise().with_flags( - FunctionFlags.INPUT_WILDCARD_EXPANSION - ) +class AllHorizontal(BooleanFunction, options=FunctionOptions.horizontal): ... + + +class Any(BooleanFunction, options=FunctionOptions.aggregation): ... -class IsBetween(BooleanFunction): +class AnyHorizontal(BooleanFunction, options=FunctionOptions.horizontal): ... + + +class IsBetween(BooleanFunction, options=FunctionOptions.elementwise): """N-ary (expr, lower_bound, upper_bound).""" __slots__ = ("closed",) closed: ClosedInterval - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, ExprIR]: expr, lower_bound, upper_bound = node.input return expr, lower_bound, upper_bound -class IsDuplicated(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.length_preserving() +class IsDuplicated(BooleanFunction, options=FunctionOptions.length_preserving): ... 
-class IsFinite(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() +class IsFinite(BooleanFunction, options=FunctionOptions.elementwise): ... -class IsFirstDistinct(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.length_preserving() +class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... -class IsIn(BooleanFunction, t.Generic[OtherT]): +class IsIn(BooleanFunction, t.Generic[OtherT], options=FunctionOptions.elementwise): __slots__ = ("other",) other: OtherT - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() + def __repr__(self) -> str: + return "is_in" class IsInSeq(IsIn["Seq[t.Any]"]): @@ -144,33 +92,17 @@ def __init__(self, *, other: ExprT) -> None: raise NotImplementedError(msg) -class IsLastDistinct(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.length_preserving() +class IsLastDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... -class IsNan(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() +class IsNan(BooleanFunction, options=FunctionOptions.elementwise): ... -class IsNull(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() +class IsNull(BooleanFunction, options=FunctionOptions.elementwise): ... -class IsUnique(BooleanFunction): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.length_preserving() +class IsUnique(BooleanFunction, options=FunctionOptions.length_preserving): ... -class Not(BooleanFunction): +class Not(BooleanFunction, options=FunctionOptions.elementwise): """`__invert__`.""" - - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index cea0274c6e..2525548e60 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -9,19 +9,12 @@ from narwhals._plan.dummy import Expr -class CategoricalFunction(Function): ... +class CategoricalFunction(Function, accessor="cat"): ... 
-class GetCategories(CategoricalFunction): +class GetCategories(CategoricalFunction, options=FunctionOptions.groupwise): """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/cat.rs#L7.""" - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.groupwise() - - def __repr__(self) -> str: - return "cat.get_categories" - class IRCatNamespace(IRNamespace): def get_categories(self) -> GetCategories: diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 69ce6aed6e..2397113b78 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -1,11 +1,13 @@ from __future__ import annotations import datetime as dt +import re from collections.abc import Iterable from decimal import Decimal -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, overload from narwhals._plan.typing import ( + Accessor, DTypeT, ExprIRT, ExprIRT2, @@ -390,6 +392,12 @@ def _to_narwhals(self, ir: ExprIR, /) -> Expr: return self._expr._from_ir(ir) +def _function_options_default() -> FunctionOptions: + from narwhals._plan.options import FunctionOptions + + return FunctionOptions.default() + + class Function(Immutable): """Shared by expr functions and namespace functions. @@ -398,11 +406,16 @@ class Function(Immutable): https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 """ + _accessor: ClassVar[Accessor | None] = None + """Namespace accessor name, if any.""" + + _function_options: ClassVar[staticmethod[[], FunctionOptions]] = staticmethod( + _function_options_default + ) + @property def function_options(self) -> FunctionOptions: - from narwhals._plan.options import FunctionOptions - - return FunctionOptions.default() + return self._function_options() @property def is_scalar(self) -> bool: @@ -411,11 +424,53 @@ def is_scalar(self) -> bool: def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: from narwhals._plan.expr import FunctionExpr - # NOTE: Still need to figure out if using a closure is needed options = self.function_options # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L442-L450. return FunctionExpr(input=inputs, function=self, options=options) + def __init_subclass__( + cls, + *args: Any, + accessor: Accessor | None = None, + options: Callable[[], FunctionOptions] | None = None, + **kwds: Any, + ) -> None: + # NOTE: Hook for defining namespaced functions + # All subclasses will use the prefix in `accessor` for their repr + super().__init_subclass__(*args, **kwds) + if accessor: + cls._accessor = accessor + if options: + cls._function_options = staticmethod(options) + + def __repr__(self) -> str: + return _function_repr(type(self)) + + +# TODO @dangotbanned: Add caching strategy? +def _function_repr(tp: type[Function], /) -> str: + name = _pascal_to_snake_case(tp.__name__) + return f"{ns_name}.{name}" if (ns_name := tp._accessor) else name + + +def _pascal_to_snake_case(s: str) -> str: + """Convert a PascalCase, camelCase string to snake_case. 
+ + Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 + """ + # Handle the sequence of uppercase letters followed by a lowercase letter + snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) + # Insert an underscore between a lowercase letter and an uppercase letter + return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() + + +_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") +_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") + + +def _re_repl_snake(match: re.Match[str], /) -> str: + return f"{match.group(1)}_{match.group(2)}" + _NON_NESTED_LITERAL_TPS = ( int, diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 5021aaa2eb..1b7865fc20 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -21,25 +21,15 @@ from narwhals.typing import FillNullStrategy -class Abs(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "abs" +class Abs(Function, options=FunctionOptions.elementwise): ... -class Hist(Function): +class Hist(Function, options=FunctionOptions.groupwise): """Only supported for `Series` so far.""" __slots__ = ("include_breakpoint",) include_breakpoint: bool - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.groupwise() - def __repr__(self) -> str: return "hist" @@ -66,83 +56,37 @@ def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> N object.__setattr__(self, "include_breakpoint", include_breakpoint) -class NullCount(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.aggregation() - - def __repr__(self) -> str: - return "null_count" +class NullCount(Function, options=FunctionOptions.aggregation): ... -class Log(Function): +class Log(Function, options=FunctionOptions.elementwise): __slots__ = ("base",) base: float - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "log" - -class Exp(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "exp" +class Exp(Function, options=FunctionOptions.elementwise): ... -class Pow(Function): +class Pow(Function, options=FunctionOptions.elementwise): """N-ary (base, exponent).""" - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "pow" - def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: base, exponent = node.input return base, exponent -class Sqrt(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "sqrt" +class Sqrt(Function, options=FunctionOptions.elementwise): ... 
-class Kurtosis(Function): +class Kurtosis(Function, options=FunctionOptions.aggregation): __slots__ = ("bias", "fisher") fisher: bool bias: bool - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.aggregation() - def __repr__(self) -> str: - return "kurtosis" - - -class FillNull(Function): +class FillNull(Function, options=FunctionOptions.elementwise): """N-ary (expr, value).""" - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "fill_null" - def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: expr, value = node.input return expr, value @@ -168,113 +112,38 @@ def function_options(self) -> FunctionOptions: else FunctionOptions.groupwise() ) - def __repr__(self) -> str: - return "fill_null_with_strategy" - -class Shift(Function): +class Shift(Function, options=FunctionOptions.length_preserving): __slots__ = ("n",) n: int - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.length_preserving() - def __repr__(self) -> str: - return "shift" +class DropNulls(Function, options=FunctionOptions.row_separable): ... -class DropNulls(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.row_separable() +class Mode(Function, options=FunctionOptions.groupwise): ... - def __repr__(self) -> str: - return "drop_nulls" - -class Mode(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.groupwise() - - def __repr__(self) -> str: - return "mode" +class Skew(Function, options=FunctionOptions.aggregation): ... -class Skew(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.aggregation() - - def __repr__(self) -> str: - return "skew" - - -class Rank(Function): +class Rank(Function, options=FunctionOptions.groupwise): __slots__ = ("options",) options: RankOptions - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.groupwise() - def __repr__(self) -> str: - return "rank" +class Clip(Function, options=FunctionOptions.elementwise): ... 
-class Clip(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "clip" - - -class CumAgg(Function): +class CumAgg(Function, options=FunctionOptions.length_preserving): __slots__ = ("reverse",) reverse: bool - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.length_preserving() - def __repr__(self) -> str: - tp = type(self) - if tp is CumAgg: - return tp.__name__ - m: dict[type[CumAgg], str] = { - CumCount: "count", - CumMin: "min", - CumMax: "max", - CumProd: "prod", - CumSum: "sum", - } - return f"cum_{m[tp]}" - - -class RollingWindow(Function): +class RollingWindow(Function, options=FunctionOptions.length_preserving): __slots__ = ("options",) options: RollingOptionsFixedWindow - @property - def function_options(self) -> FunctionOptions: - """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/mod.rs#L1276.""" - return FunctionOptions.length_preserving() - - def __repr__(self) -> str: - tp = type(self) - if tp is RollingWindow: - return tp.__name__ - m: dict[type[RollingWindow], str] = { - RollingSum: "sum", - RollingMean: "mean", - RollingVar: "var", - RollingStd: "std", - } - return f"rolling_{m[tp]}" - def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: from narwhals._plan.expr import RollingExpr @@ -309,118 +178,46 @@ class RollingVar(RollingWindow): ... class RollingStd(RollingWindow): ... -class Diff(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.length_preserving() +class Diff(Function, options=FunctionOptions.length_preserving): ... - def __repr__(self) -> str: - return "diff" - -class Unique(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.groupwise() - - def __repr__(self) -> str: - return "unique" +class Unique(Function, options=FunctionOptions.groupwise): ... -class Round(Function): +class Round(Function, options=FunctionOptions.elementwise): __slots__ = ("decimals",) decimals: int - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "round" - - -class SumHorizontal(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise().with_flags( - FunctionFlags.INPUT_WILDCARD_EXPANSION - ) - - def __repr__(self) -> str: - return "sum_horizontal" - - -class MinHorizontal(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise().with_flags( - FunctionFlags.INPUT_WILDCARD_EXPANSION - ) - def __repr__(self) -> str: - return "min_horizontal" +class SumHorizontal(Function, options=FunctionOptions.horizontal): ... -class MaxHorizontal(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise().with_flags( - FunctionFlags.INPUT_WILDCARD_EXPANSION - ) +class MinHorizontal(Function, options=FunctionOptions.horizontal): ... - def __repr__(self) -> str: - return "max_horizontal" +class MaxHorizontal(Function, options=FunctionOptions.horizontal): ... 
-class MeanHorizontal(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise().with_flags( - FunctionFlags.INPUT_WILDCARD_EXPANSION - ) - def __repr__(self) -> str: - return "mean_horizontal" +class MeanHorizontal(Function, options=FunctionOptions.horizontal): ... -class EwmMean(Function): +class EwmMean(Function, options=FunctionOptions.length_preserving): __slots__ = ("options",) options: EWMOptions - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.length_preserving() - def __repr__(self) -> str: - return "ewm_mean" - - -class ReplaceStrict(Function): +class ReplaceStrict(Function, options=FunctionOptions.elementwise): __slots__ = ("new", "old", "return_dtype") old: Seq[Any] new: Seq[Any] return_dtype: DType | None - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "replace_strict" - -class GatherEvery(Function): +class GatherEvery(Function, options=FunctionOptions.groupwise): __slots__ = ("n", "offset") n: int offset: int - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.groupwise() - - def __repr__(self) -> str: - return "gather_every" - class MapBatches(Function): __slots__ = ("function", "is_elementwise", "return_dtype", "returns_scalar") @@ -439,9 +236,6 @@ def function_options(self) -> FunctionOptions: options = options.with_flags(FunctionFlags.RETURNS_SCALAR) return options - def __repr__(self) -> str: - return "map_batches" - def to_function_expr(self, *inputs: ExprIR) -> AnonymousExpr: from narwhals._plan.expr import AnonymousExpr diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index ea6bf382e0..9bc4d594d1 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -9,19 +9,12 @@ from narwhals._plan.dummy import Expr -class ListFunction(Function): ... +class ListFunction(Function, accessor="list"): ... 
-class Len(ListFunction): +class Len(ListFunction, options=FunctionOptions.elementwise): """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/list.rs#L32.""" - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "list.len" - class IRListNamespace(IRNamespace): def len(self) -> Len: diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 29b7fa4637..80d1dbcde6 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -136,6 +136,12 @@ def groupwise() -> FunctionOptions: def aggregation() -> FunctionOptions: return FunctionOptions.groupwise().with_flags(FunctionFlags.RETURNS_SCALAR) + @staticmethod + def horizontal() -> FunctionOptions: + return FunctionOptions.elementwise().with_flags( + FunctionFlags.INPUT_WILDCARD_EXPANSION + ) + class SortOptions(Immutable): __slots__ = ("descending", "nulls_last") diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py index cfa54df748..14500cadc4 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/ranges.py @@ -13,20 +13,13 @@ class RangeFunction(Function): - def __repr__(self) -> str: - tp = type(self) - if tp is RangeFunction: - return tp.__name__ - m: dict[type[RangeFunction], str] = {IntRange: "int_range"} - return m[tp] - def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: from narwhals._plan.expr import RangeExpr return RangeExpr(input=inputs, function=self, options=self.function_options) -class IntRange(RangeFunction): +class IntRange(RangeFunction, options=FunctionOptions.row_separable): """N-ary (start, end). Not implemented yet, but might push forward [#2722]. @@ -47,10 +40,6 @@ class IntRange(RangeFunction): step: int dtype: IntegerType - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.row_separable() - def unwrap_input(self, node: RangeExpr[Self], /) -> tuple[ExprIR, ExprIR]: start, end = node.input return start, end diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index f0eec5f138..9fb93749d2 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -3,56 +3,35 @@ from typing import TYPE_CHECKING from narwhals._plan.common import ExprNamespace, Function, IRNamespace -from narwhals._plan.options import FunctionFlags, FunctionOptions +from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr -class StringFunction(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - return "StringFunction" +class StringFunction(Function, accessor="str", options=FunctionOptions.elementwise): ... 
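A quick sketch of what the new `FunctionOptions.horizontal` helper composes, restated from the options.py hunk above (the imports match the ones strings.py used before this patch):

from narwhals._plan.options import FunctionFlags, FunctionOptions

# What each *Horizontal function previously spelled out per-class ...
manual = FunctionOptions.elementwise().with_flags(FunctionFlags.INPUT_WILDCARD_EXPANSION)
# ... and the single helper that replaces that boilerplate after this patch.
sugar = FunctionOptions.horizontal()
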
-class ConcatHorizontal(StringFunction): +class ConcatHorizontal(StringFunction, options=FunctionOptions.horizontal): """`nw.functions.concat_str`.""" __slots__ = ("ignore_nulls", "separator") separator: str ignore_nulls: bool - @property - def function_options(self) -> FunctionOptions: - return super().function_options.with_flags(FunctionFlags.INPUT_WILDCARD_EXPANSION) - - def __repr__(self) -> str: - return "str.concat_horizontal" - class Contains(StringFunction): __slots__ = ("literal", "pattern") pattern: str literal: bool - def __repr__(self) -> str: - return "str.contains" - class EndsWith(StringFunction): __slots__ = ("suffix",) suffix: str - def __repr__(self) -> str: - return "str.ends_with" - -class LenChars(StringFunction): - def __repr__(self) -> str: - return "str.len_chars" +class LenChars(StringFunction): ... class Replace(StringFunction): @@ -62,9 +41,6 @@ class Replace(StringFunction): literal: bool n: int - def __repr__(self) -> str: - return "str.replace" - class ReplaceAll(StringFunction): """`polars` uses a single node for this and `Replace`. @@ -77,9 +53,6 @@ class ReplaceAll(StringFunction): value: str literal: bool - def __repr__(self) -> str: - return "str.replace_all" - class Slice(StringFunction): """We're using for `Head`, `Tail` as well. @@ -91,33 +64,21 @@ class Slice(StringFunction): offset: int length: int | None - def __repr__(self) -> str: - return "str.slice" - class Split(StringFunction): __slots__ = ("by",) by: str - def __repr__(self) -> str: - return "str.split" - class StartsWith(StringFunction): __slots__ = ("prefix",) prefix: str - def __repr__(self) -> str: - return "str.starts_with" - class StripChars(StringFunction): __slots__ = ("characters",) characters: str | None - def __repr__(self) -> str: - return "str.strip_chars" - class ToDatetime(StringFunction): """`polars` uses `Strptime`. @@ -130,18 +91,11 @@ class ToDatetime(StringFunction): __slots__ = ("format",) format: str | None - def __repr__(self) -> str: - return "str.to_datetime" - -class ToLowercase(StringFunction): - def __repr__(self) -> str: - return "str.to_lowercase" +class ToLowercase(StringFunction): ... -class ToUppercase(StringFunction): - def __repr__(self) -> str: - return "str.to_uppercase" +class ToUppercase(StringFunction): ... class IRStringNamespace(IRNamespace): diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index 1f9abaf014..1896a6f953 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -9,21 +9,17 @@ from narwhals._plan.dummy import Expr -class StructFunction(Function): ... +class StructFunction(Function, accessor="struct"): ... 
-class FieldByName(StructFunction): +class FieldByName(StructFunction, options=FunctionOptions.elementwise): """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/struct_.rs#L11.""" __slots__ = ("name",) name: str - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - def __repr__(self) -> str: - return f"struct.field_by_name({self.name!r})" + return f"{super().__repr__()}({self.name!r})" class IRStructNamespace(IRNamespace): diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 0ff0bb086c..02196247ba 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionOptions @@ -19,42 +19,7 @@ def _is_polars_time_unit(obj: Any) -> TypeIs[PolarsTimeUnit]: return obj in {"ns", "us", "ms"} -class TemporalFunction(Function): - @property - def function_options(self) -> FunctionOptions: - return FunctionOptions.elementwise() - - def __repr__(self) -> str: - tp = type(self) - if tp is TemporalFunction: - return tp.__name__ - if tp is Timestamp: - tu = cast("Timestamp", self).time_unit - return f"dt.timestamp[{tu!r}]" - m: dict[type[TemporalFunction], str] = { - Year: "year", - Month: "month", - WeekDay: "weekday", - Day: "day", - OrdinalDay: "ordinal_day", - Date: "date", - Hour: "hour", - Minute: "minute", - Second: "second", - Millisecond: "millisecond", - Microsecond: "microsecond", - Nanosecond: "nanosecond", - TotalMinutes: "total_minutes", - TotalSeconds: "total_seconds", - TotalMilliseconds: "total_milliseconds", - TotalMicroseconds: "total_microseconds", - TotalNanoseconds: "total_nanoseconds", - ToString: "to_string", - ConvertTimeZone: "convert_time_zone", - ReplaceTimeZone: "replace_time_zone", - Truncate: "truncate", - } - return f"dt.{m[tp]}" +class TemporalFunction(Function, accessor="dt", options=FunctionOptions.elementwise): ... class Date(TemporalFunction): ... 
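A self-contained toy sketch of the `__init_subclass__` pattern this patch rolls out across the Function hierarchy: class keywords (`accessor`, `options`) are captured once on the base class, and the repr is derived from the PascalCase class name. The names below are illustrative only, not the narwhals implementation (options are plain strings here to keep the sketch standalone):

from __future__ import annotations

import re
from typing import Callable, ClassVar


def pascal_to_snake(name: str) -> str:
    # e.g. "OrdinalDay" -> "ordinal_day"
    snake = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", name)
    return re.sub(r"([a-z])([A-Z])", r"\1_\2", snake).lower()


class Function:
    _accessor: ClassVar[str | None] = None
    _options: ClassVar[Callable[[], str]] = staticmethod(lambda: "default")

    def __init_subclass__(
        cls, *, accessor: str | None = None, options: Callable[[], str] | None = None
    ) -> None:
        # Class keywords are consumed here, so subclasses stay one-liners.
        if accessor is not None:
            cls._accessor = accessor
        if options is not None:
            cls._options = staticmethod(options)

    def __repr__(self) -> str:
        name = pascal_to_snake(type(self).__name__)
        return f"{self._accessor}.{name}" if self._accessor else name


class TemporalFunction(Function, accessor="dt", options=lambda: "elementwise"): ...


class OrdinalDay(TemporalFunction): ...


print(repr(OrdinalDay()))       # dt.ordinal_day
print(OrdinalDay()._options())  # elementwise
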
@@ -139,6 +104,9 @@ def from_time_unit(time_unit: TimeUnit, /) -> Timestamp: raise ValueError(msg) return Timestamp(time_unit=time_unit) + def __repr__(self) -> str: + return f"{super().__repr__()}[{self.time_unit!r}]" + class Truncate(TemporalFunction): __slots__ = ("multiple", "unit") diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 41861220a0..56f2e65879 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -62,6 +62,10 @@ "SelectorOperatorT", bound="ops.SelectorOperator", default="ops.SelectorOperator" ) IRNamespaceT = TypeVar("IRNamespaceT", bound="IRNamespace") +Accessor: TypeAlias = t.Literal[ + "arr", "cat", "dt", "list", "meta", "name", "str", "bin", "struct" +] +"""Namespace accessor property name.""" DTypeT = TypeVar("DTypeT", bound="dtypes.DType") NonNestedDTypeT = TypeVar("NonNestedDTypeT", bound="NonNestedDType") From de884c586c3a0991dfd18603f9090ff508cd4fd0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 21 Aug 2025 09:38:02 +0000 Subject: [PATCH 352/368] refactor(typing): Align `NativeDataFrame` #2944 --- narwhals/_plan/dummy.py | 6 +++--- narwhals/_plan/protocols.py | 12 ++++++------ narwhals/_plan/typing.py | 4 ++++ 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 5752e2f157..afa2d4411e 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -26,7 +26,7 @@ SortOptions, ) from narwhals._plan.selectors import by_name -from narwhals._plan.typing import NativeFrameT, NativeSeriesT +from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT from narwhals._plan.window import Over from narwhals._utils import Version, generate_repr from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table @@ -890,8 +890,8 @@ def sort( return self._from_compliant(self._compliant.sort(named_irs, opts)) -class DataFrame(BaseFrame[NativeFrameT], Generic[NativeFrameT, NativeSeriesT]): - _compliant: CompliantDataFrame[t.Any, NativeFrameT, NativeSeriesT] +class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): + _compliant: CompliantDataFrame[t.Any, NativeDataFrameT, NativeSeriesT] @property def _series(self) -> type[Series[NativeSeriesT]]: diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 38bf13510d..abb12f4ca8 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -5,7 +5,7 @@ from narwhals._plan import aggregation as agg, boolean, expr, functions as F, strings from narwhals._plan.common import ExprIR, Function, NamedIR, flatten_hash_safe -from narwhals._plan.typing import NativeFrameT, NativeSeriesT, Seq +from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar from narwhals._utils import Version, _hasattr_static @@ -691,8 +691,8 @@ def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... class CompliantDataFrame( - CompliantBaseFrame[SeriesT, NativeFrameT], - Protocol[SeriesT, NativeFrameT, NativeSeriesT], + CompliantBaseFrame[SeriesT, NativeDataFrameT], + Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): @classmethod def from_dict( @@ -703,7 +703,7 @@ def from_dict( schema: Mapping[str, DType] | Schema | None = None, ) -> Self: ... - def to_narwhals(self) -> DataFrame[NativeFrameT, NativeSeriesT]: ... + def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ... 
@overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... @@ -723,8 +723,8 @@ def with_row_index(self, name: str) -> Self: ... class EagerDataFrame( - CompliantDataFrame[SeriesT, NativeFrameT, NativeSeriesT], - Protocol[SeriesT, NativeFrameT, NativeSeriesT], + CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], + Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... def select(self, irs: Seq[NamedIR]) -> Self: diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 56f2e65879..0b8b884b5f 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -14,6 +14,7 @@ from narwhals._plan.functions import RollingWindow from narwhals._plan.ranges import RangeFunction from narwhals.typing import ( + NativeDataFrame, NativeFrame, NativeSeries, NonNestedDType, @@ -75,6 +76,9 @@ ) NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries") NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame", default="NativeFrame") +NativeDataFrameT = TypeVar( + "NativeDataFrameT", bound="NativeDataFrame", default="NativeDataFrame" +) LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | Series[t.Any]", default=t.Any) MapIR: TypeAlias = "t.Callable[[ExprIR], ExprIR]" """A function to apply to all nodes in this tree.""" From f5ecc6ed3d50dae8169c8925e921f55412828d90 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 25 Aug 2025 21:15:38 +0000 Subject: [PATCH 353/368] refactor(expr-ir): Add more builder sugar (#3040) --- narwhals/_plan/aggregation.py | 9 +- narwhals/_plan/categorical.py | 4 +- narwhals/_plan/common.py | 9 ++ narwhals/_plan/demo.py | 62 +------- narwhals/_plan/dummy.py | 258 +++++++++++-------------------- narwhals/_plan/expr.py | 13 +- narwhals/_plan/expr_expansion.py | 12 +- narwhals/_plan/expr_parsing.py | 4 +- narwhals/_plan/lists.py | 2 +- narwhals/_plan/options.py | 12 +- narwhals/_plan/strings.py | 48 ++---- narwhals/_plan/struct.py | 2 +- narwhals/_plan/temporal.py | 68 +++----- 13 files changed, 178 insertions(+), 325 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index 17bdaa0c8d..aefea50741 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import ExprIR +from narwhals._plan.common import ExprIR, _pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: @@ -23,12 +23,7 @@ def is_scalar(self) -> bool: return True def __repr__(self) -> str: - tp = type(self) - if tp in {AggExpr, OrderableAggExpr}: - return tp.__name__ - m = {ArgMin: "arg_min", ArgMax: "arg_max", NUnique: "n_unique"} - name = m.get(tp, tp.__name__.lower()) - return f"{self.expr!r}.{name}()" + return f"{self.expr!r}.{_pascal_to_snake_case(type(self).__name__)}()" def iter_left(self) -> Iterator[ExprIR]: yield from self.expr.iter_left() diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 2525548e60..19217e9826 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -27,6 +27,4 @@ def _ir_namespace(self) -> type[IRCatNamespace]: return IRCatNamespace def get_categories(self) -> Expr: - return self._to_narwhals( - self._ir.get_categories().to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.get_categories()) diff --git a/narwhals/_plan/common.py 
b/narwhals/_plan/common.py index 2397113b78..65f57b1170 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -32,6 +32,7 @@ from narwhals._plan.dummy import Expr, Selector, Series from narwhals._plan.expr import ( AggExpr, + Alias, BinaryExpr, Cast, Column, @@ -274,6 +275,11 @@ def cast(self, dtype: DType) -> Cast: return Cast(expr=self, dtype=dtype) + def alias(self, name: str) -> Alias: + from narwhals._plan.expr import Alias + + return Alias(expr=self, name=name) + def _repr_html_(self) -> str: return self.__repr__() @@ -391,6 +397,9 @@ def _ir(self) -> IRNamespaceT: def _to_narwhals(self, ir: ExprIR, /) -> Expr: return self._expr._from_ir(ir) + def _with_unary(self, function: Function, /) -> Expr: + return self._expr._with_unary(function) + def _function_options_default() -> FunctionOptions: from narwhals._plan.options import FunctionOptions diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 1dd28617f6..a1425543f9 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,53 +3,36 @@ import builtins import typing as t -from narwhals._plan import ( - aggregation as agg, - boolean, - expr_parsing as parse, - functions as F, -) +from narwhals._plan import boolean, expr, expr_parsing as parse, functions as F from narwhals._plan.common import ( into_dtype, is_non_nested_literal, is_series, py_to_narwhals_dtype, ) -from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth +from narwhals._plan.expr import All, Len from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatHorizontal from narwhals._plan.when_then import When from narwhals._utils import Version, flatten -from narwhals.exceptions import InvalidOperationError as OrderDependentExprError if t.TYPE_CHECKING: - from typing_extensions import TypeIs - from narwhals._plan.dummy import Expr, Series - from narwhals._plan.expr import SortBy from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT from narwhals.dtypes import IntegerType from narwhals.typing import IntoDType, NonNestedLiteral def col(*names: str | t.Iterable[str]) -> Expr: - flat_names = tuple(flatten(names)) - node = ( - Column(name=flat_names[0]) - if builtins.len(flat_names) == 1 - else Columns(names=flat_names) - ) + flat = tuple(flatten(names)) + node = expr.col(flat[0]) if builtins.len(flat) == 1 else expr.cols(*flat) return node.to_narwhals() def nth(*indices: int | t.Sequence[int]) -> Expr: - flat_indices = tuple(flatten(indices)) - node = ( - Nth(index=flat_indices[0]) - if builtins.len(flat_indices) == 1 - else IndexColumns(indices=flat_indices) - ) + flat = tuple(flatten(indices)) + node = expr.nth(flat[0]) if builtins.len(flat) == 1 else expr.index_columns(*flat) return node.to_narwhals() @@ -194,36 +177,3 @@ def int_range( .to_function_expr(*parse.parse_into_seq_of_expr_ir(start, end)) .to_narwhals() ) - - -def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]: - """In theory, we could add other nodes to this check.""" - from narwhals._plan.expr import SortBy - - allowed = (SortBy,) - return isinstance(obj, allowed) - - -def _order_dependent_error(node: agg.OrderableAggExpr) -> OrderDependentExprError: - previous = node.expr - method = repr(node).removeprefix(f"{previous!r}.") - msg = ( - f"{method} is order-dependent and requires an ordering operation for lazy backends.\n" - f"Hint:\nInstead of:\n" - f" {node!r}\n\n" - "If you want to aggregate to a single value, try:\n" - f" 
{previous!r}.sort_by(...).{method}\n\n" - "Otherwise, try:\n" - f" {node!r}.over(order_by=...)" - ) - return OrderDependentExprError(msg) - - -def ensure_orderable_rules(*exprs: Expr) -> tuple[Expr, ...]: - for expr in exprs: - node = expr._ir - if isinstance(node, agg.OrderableAggExpr): - previous = node.expr - if not _is_order_enforcing_previous(previous): - raise _order_dependent_error(node) - return exprs diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index afa2d4411e..29d95f9cd1 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -20,10 +20,9 @@ from narwhals._plan.options import ( EWMOptions, RankOptions, - RollingOptionsFixedWindow, - RollingVarParams, SortMultipleOptions, SortOptions, + rolling_options, ) from narwhals._plan.selectors import by_name from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT @@ -40,7 +39,7 @@ from typing_extensions import Never, Self from narwhals._plan.categorical import ExprCatNamespace - from narwhals._plan.common import ExprIR + from narwhals._plan.common import ExprIR, Function from narwhals._plan.lists import ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace @@ -110,7 +109,7 @@ def version(self) -> Version: return self._version def alias(self, name: str) -> Self: - return self._from_ir(expr.Alias(expr=self._ir, name=name)) + return self._from_ir(self._ir.alias(name)) def cast(self, dtype: IntoDType) -> Self: return self._from_ir(self._ir.cast(into_dtype(dtype))) @@ -210,8 +209,11 @@ def filter( by = parse.parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return self._from_ir(expr.Filter(expr=self._ir, by=by)) + def _with_unary(self, function: Function, /) -> Self: + return self._from_ir(function.to_function_expr(self._ir)) + def abs(self) -> Self: - return self._from_ir(F.Abs().to_function_expr(self._ir)) + return self._with_unary(F.Abs()) def hist( self, @@ -232,24 +234,22 @@ def hist( ) else: node = F.HistBinCount(include_breakpoint=include_breakpoint) - return self._from_ir(node.to_function_expr(self._ir)) + return self._with_unary(node) def log(self, base: float = math.e) -> Self: - return self._from_ir(F.Log(base=base).to_function_expr(self._ir)) + return self._with_unary(F.Log(base=base)) def exp(self) -> Self: - return self._from_ir(F.Exp().to_function_expr(self._ir)) + return self._with_unary(F.Exp()) def sqrt(self) -> Self: - return self._from_ir(F.Sqrt().to_function_expr(self._ir)) + return self._with_unary(F.Sqrt()) def kurtosis(self, *, fisher: bool = True, bias: bool = True) -> Self: - return self._from_ir( - F.Kurtosis(fisher=fisher, bias=bias).to_function_expr(self._ir) - ) + return self._with_unary(F.Kurtosis(fisher=fisher, bias=bias)) def null_count(self) -> Self: - return self._from_ir(F.NullCount().to_function_expr(self._ir)) + return self._with_unary(F.NullCount()) def fill_null( self, @@ -260,24 +260,23 @@ def fill_null( if strategy is None: ir = parse.parse_into_expr_ir(value, str_as_lit=True) return self._from_ir(F.FillNull().to_function_expr(self._ir, ir)) - fill = F.FillNullWithStrategy(strategy=strategy, limit=limit) - return self._from_ir(fill.to_function_expr(self._ir)) + return self._with_unary(F.FillNullWithStrategy(strategy=strategy, limit=limit)) def shift(self, n: int) -> Self: - return self._from_ir(F.Shift(n=n).to_function_expr(self._ir)) + return self._with_unary(F.Shift(n=n)) def drop_nulls(self) -> Self: - return self._from_ir(F.DropNulls().to_function_expr(self._ir)) 
+ return self._with_unary(F.DropNulls()) def mode(self) -> Self: - return self._from_ir(F.Mode().to_function_expr(self._ir)) + return self._with_unary(F.Mode()) def skew(self) -> Self: - return self._from_ir(F.Skew().to_function_expr(self._ir)) + return self._with_unary(F.Skew()) def rank(self, method: RankMethod = "average", *, descending: bool = False) -> Self: options = RankOptions(method=method, descending=descending) - return self._from_ir(F.Rank(options=options).to_function_expr(self._ir)) + return self._with_unary(F.Rank(options=options)) def clip( self, @@ -291,47 +290,31 @@ def clip( ) def cum_count(self, *, reverse: bool = False) -> Self: - return self._from_ir(F.CumCount(reverse=reverse).to_function_expr(self._ir)) + return self._with_unary(F.CumCount(reverse=reverse)) def cum_min(self, *, reverse: bool = False) -> Self: - return self._from_ir(F.CumMin(reverse=reverse).to_function_expr(self._ir)) + return self._with_unary(F.CumMin(reverse=reverse)) def cum_max(self, *, reverse: bool = False) -> Self: - return self._from_ir(F.CumMax(reverse=reverse).to_function_expr(self._ir)) + return self._with_unary(F.CumMax(reverse=reverse)) def cum_prod(self, *, reverse: bool = False) -> Self: - return self._from_ir(F.CumProd(reverse=reverse).to_function_expr(self._ir)) + return self._with_unary(F.CumProd(reverse=reverse)) def cum_sum(self, *, reverse: bool = False) -> Self: - return self._from_ir(F.CumSum(reverse=reverse).to_function_expr(self._ir)) + return self._with_unary(F.CumSum(reverse=reverse)) def rolling_sum( self, window_size: int, *, min_samples: int | None = None, center: bool = False ) -> Self: - min_samples = window_size if min_samples is None else min_samples - fn_params = None - options = RollingOptionsFixedWindow( - window_size=window_size, - min_samples=min_samples, - center=center, - fn_params=fn_params, - ) - function = F.RollingSum(options=options) - return self._from_ir(function.to_function_expr(self._ir)) + options = rolling_options(window_size, min_samples, center=center) + return self._with_unary(F.RollingSum(options=options)) def rolling_mean( self, window_size: int, *, min_samples: int | None = None, center: bool = False ) -> Self: - min_samples = window_size if min_samples is None else min_samples - fn_params = None - options = RollingOptionsFixedWindow( - window_size=window_size, - min_samples=min_samples, - center=center, - fn_params=fn_params, - ) - function = F.RollingMean(options=options) - return self._from_ir(function.to_function_expr(self._ir)) + options = rolling_options(window_size, min_samples, center=center) + return self._with_unary(F.RollingMean(options=options)) def rolling_var( self, @@ -341,16 +324,8 @@ def rolling_var( center: bool = False, ddof: int = 1, ) -> Self: - min_samples = window_size if min_samples is None else min_samples - fn_params = RollingVarParams(ddof=ddof) - options = RollingOptionsFixedWindow( - window_size=window_size, - min_samples=min_samples, - center=center, - fn_params=fn_params, - ) - function = F.RollingVar(options=options) - return self._from_ir(function.to_function_expr(self._ir)) + options = rolling_options(window_size, min_samples, center=center, ddof=ddof) + return self._with_unary(F.RollingVar(options=options)) def rolling_std( self, @@ -360,25 +335,17 @@ def rolling_std( center: bool = False, ddof: int = 1, ) -> Self: - min_samples = window_size if min_samples is None else min_samples - fn_params = RollingVarParams(ddof=ddof) - options = RollingOptionsFixedWindow( - window_size=window_size, - 
min_samples=min_samples, - center=center, - fn_params=fn_params, - ) - function = F.RollingStd(options=options) - return self._from_ir(function.to_function_expr(self._ir)) + options = rolling_options(window_size, min_samples, center=center, ddof=ddof) + return self._with_unary(F.RollingStd(options=options)) def diff(self) -> Self: - return self._from_ir(F.Diff().to_function_expr(self._ir)) + return self._with_unary(F.Diff()) def unique(self) -> Self: - return self._from_ir(F.Unique().to_function_expr(self._ir)) + return self._with_unary(F.Unique()) def round(self, decimals: int = 0) -> Self: - return self._from_ir(F.Round(decimals=decimals).to_function_expr(self._ir)) + return self._with_unary(F.Round(decimals=decimals)) def ewm_mean( self, @@ -400,7 +367,7 @@ def ewm_mean( min_samples=min_samples, ignore_nulls=ignore_nulls, ) - return self._from_ir(F.EwmMean(options=options).to_function_expr(self._ir)) + return self._with_unary(F.EwmMean(options=options)) def replace_strict( self, @@ -429,10 +396,10 @@ def replace_strict( if return_dtype is not None: return_dtype = into_dtype(return_dtype) function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) - return self._from_ir(function.to_function_expr(self._ir)) + return self._with_unary(function) def gather_every(self, n: int, offset: int = 0) -> Self: - return self._from_ir(F.GatherEvery(n=n, offset=offset).to_function_expr(self._ir)) + return self._with_unary(F.GatherEvery(n=n, offset=offset)) def map_batches( self, @@ -444,41 +411,41 @@ def map_batches( ) -> Self: if return_dtype is not None: return_dtype = into_dtype(return_dtype) - return self._from_ir( + return self._with_unary( F.MapBatches( function=function, return_dtype=return_dtype, is_elementwise=is_elementwise, returns_scalar=returns_scalar, - ).to_function_expr(self._ir) + ) ) def any(self) -> Self: - return self._from_ir(boolean.Any().to_function_expr(self._ir)) + return self._with_unary(boolean.Any()) def all(self) -> Self: - return self._from_ir(boolean.All().to_function_expr(self._ir)) + return self._with_unary(boolean.All()) def is_duplicated(self) -> Self: - return self._from_ir(boolean.IsDuplicated().to_function_expr(self._ir)) + return self._with_unary(boolean.IsDuplicated()) def is_finite(self) -> Self: - return self._from_ir(boolean.IsFinite().to_function_expr(self._ir)) + return self._with_unary(boolean.IsFinite()) def is_nan(self) -> Self: - return self._from_ir(boolean.IsNan().to_function_expr(self._ir)) + return self._with_unary(boolean.IsNan()) def is_null(self) -> Self: - return self._from_ir(boolean.IsNull().to_function_expr(self._ir)) + return self._with_unary(boolean.IsNull()) def is_first_distinct(self) -> Self: - return self._from_ir(boolean.IsFirstDistinct().to_function_expr(self._ir)) + return self._with_unary(boolean.IsFirstDistinct()) def is_last_distinct(self) -> Self: - return self._from_ir(boolean.IsLastDistinct().to_function_expr(self._ir)) + return self._with_unary(boolean.IsLastDistinct()) def is_unique(self) -> Self: - return self._from_ir(boolean.IsUnique().to_function_expr(self._ir)) + return self._with_unary(boolean.IsUnique()) def is_between( self, @@ -492,137 +459,98 @@ def is_between( ) def is_in(self, other: t.Iterable[t.Any]) -> Self: - node: boolean.IsIn[t.Any] if is_series(other): - node = boolean.IsInSeries.from_series(other) - elif isinstance(other, t.Iterable): - node = boolean.IsInSeq.from_iterable(other) - elif is_expr(other): - node = boolean.IsInExpr(other=other._ir) - else: - msg = f"`is_in` only supports 
iterables, got: {type(other).__name__}" - raise TypeError(msg) - return self._from_ir(node.to_function_expr(self._ir)) + return self._with_unary(boolean.IsInSeries.from_series(other)) + if isinstance(other, t.Iterable): + return self._with_unary(boolean.IsInSeq.from_iterable(other)) + if is_expr(other): + return self._with_unary(boolean.IsInExpr(other=other._ir)) + msg = f"`is_in` only supports iterables, got: {type(other).__name__}" + raise TypeError(msg) + + def _with_binary( + self, + op: type[ops.Operator], + other: IntoExpr, + *, + str_as_lit: bool = False, + reflect: bool = False, + ) -> Self: + other_ir = parse.parse_into_expr_ir(other, str_as_lit=str_as_lit) + args = (self._ir, other_ir) if not reflect else (other_ir, self._ir) + return self._from_ir(op().to_binary_expr(*args)) def __eq__(self, other: IntoExpr) -> Self: # type: ignore[override] - op = ops.Eq() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.Eq, other, str_as_lit=True) def __ne__(self, other: IntoExpr) -> Self: # type: ignore[override] - op = ops.NotEq() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.NotEq, other, str_as_lit=True) def __lt__(self, other: IntoExpr) -> Self: - op = ops.Lt() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.Lt, other, str_as_lit=True) def __le__(self, other: IntoExpr) -> Self: - op = ops.LtEq() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.LtEq, other, str_as_lit=True) def __gt__(self, other: IntoExpr) -> Self: - op = ops.Gt() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.Gt, other, str_as_lit=True) def __ge__(self, other: IntoExpr) -> Self: - op = ops.GtEq() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.GtEq, other, str_as_lit=True) def __add__(self, other: IntoExpr) -> Self: - op = ops.Add() - rhs = parse.parse_into_expr_ir(other, str_as_lit=True) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.Add, other, str_as_lit=True) def __radd__(self, other: IntoExpr) -> Self: - op = ops.Add() - lhs = parse.parse_into_expr_ir(other, str_as_lit=True) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return self._with_binary(ops.Add, other, str_as_lit=True, reflect=True) def __sub__(self, other: IntoExpr) -> Self: - op = ops.Sub() - rhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.Sub, other) def __rsub__(self, other: IntoExpr) -> Self: - op = ops.Sub() - lhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return self._with_binary(ops.Sub, other, reflect=True) def __mul__(self, other: IntoExpr) -> Self: - op = ops.Multiply() - rhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.Multiply, other) def __rmul__(self, other: IntoExpr) -> Self: - op = ops.Multiply() - lhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return 
self._with_binary(ops.Multiply, other, reflect=True) def __truediv__(self, other: IntoExpr) -> Self: - op = ops.TrueDivide() - rhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.TrueDivide, other) def __rtruediv__(self, other: IntoExpr) -> Self: - op = ops.TrueDivide() - lhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return self._with_binary(ops.TrueDivide, other, reflect=True) def __floordiv__(self, other: IntoExpr) -> Self: - op = ops.FloorDivide() - rhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.FloorDivide, other) def __rfloordiv__(self, other: IntoExpr) -> Self: - op = ops.FloorDivide() - lhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return self._with_binary(ops.FloorDivide, other, reflect=True) def __mod__(self, other: IntoExpr) -> Self: - op = ops.Modulus() - rhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.Modulus, other) def __rmod__(self, other: IntoExpr) -> Self: - op = ops.Modulus() - lhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return self._with_binary(ops.Modulus, other, reflect=True) def __and__(self, other: IntoExprColumn | int | bool) -> Self: - op = ops.And() - rhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.And, other) def __rand__(self, other: IntoExprColumn | int | bool) -> Self: - op = ops.And() - lhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return self._with_binary(ops.And, other, reflect=True) def __or__(self, other: IntoExprColumn | int | bool) -> Self: - op = ops.Or() - rhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.Or, other) def __ror__(self, other: IntoExprColumn | int | bool) -> Self: - op = ops.Or() - lhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return self._with_binary(ops.Or, other, reflect=True) def __xor__(self, other: IntoExprColumn | int | bool) -> Self: - op = ops.ExclusiveOr() - rhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(self._ir, rhs)) + return self._with_binary(ops.ExclusiveOr, other) def __rxor__(self, other: IntoExprColumn | int | bool) -> Self: - op = ops.ExclusiveOr() - lhs = parse.parse_into_expr_ir(other) - return self._from_ir(op.to_binary_expr(lhs, self._ir)) + return self._with_binary(ops.ExclusiveOr, other, reflect=True) def __pow__(self, exponent: IntoExprColumn | float) -> Self: exp = parse.parse_into_expr_ir(exponent) @@ -633,7 +561,7 @@ def __rpow__(self, base: IntoExprColumn | float) -> Self: return self._from_ir(F.Pow().to_function_expr(base_, self._ir)) def __invert__(self) -> Self: - return self._from_ir(boolean.Not().to_function_expr(self._ir)) + return self._with_unary(boolean.Not()) @property def meta(self) -> IRMetaNamespace: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 786af72b20..bb6b1018ff 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -72,10 +72,21 @@ def col(name: str, /) -> Column: - """Sugar for a **single** column selection node.""" return Column(name=name) +def cols(*names: str) -> Columns: + return 
Columns(names=names) + + +def nth(index: int, /) -> Nth: + return Nth(index=index) + + +def index_columns(*indices: int) -> IndexColumns: + return IndexColumns(indices=indices) + + class Alias(ExprIR): __slots__ = ("expr", "name") expr: ExprIR diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 0ca3d7afa1..f3bd681d4a 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -66,6 +66,7 @@ RenameAlias, _ColumnSelection, col, + cols, ) from narwhals._plan.schema import ( FrozenColumns, @@ -313,8 +314,8 @@ def selector_matches_column(selector: SelectorIR, name: str, dtype: DType, /) -> @lru_cache(maxsize=100) def expand_selector(selector: SelectorIR, *, schema: FrozenSchema) -> Columns: """Expand `selector` into `Columns`, within the context of `schema`.""" - cols = (k for k, v in schema.items() if selector_matches_column(selector, k, v)) - return Columns(names=tuple(cols)) + matches = selector_matches_column + return cols(*(k for k, v in schema.items() if matches(selector, k, v))) def rewrite_projections( @@ -467,16 +468,13 @@ def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: if meta.has_expr_ir(origin, KeepName, RenameAlias): if isinstance(origin, KeepName): parent = origin.expr - roots = parent.meta.root_names() - alias = next(iter(roots)) - return Alias(expr=parent, name=alias) + return parent.alias(next(iter(parent.meta.root_names()))) if isinstance(origin, RenameAlias): parent = origin.expr leaf_name_or_err = meta.get_single_leaf_name(parent) if not isinstance(leaf_name_or_err, str): raise leaf_name_or_err - alias = origin.function(leaf_name_or_err) - return Alias(expr=parent, name=alias) + return parent.alias(origin.function(leaf_name_or_err)) msg = "`keep`, `suffix`, `prefix` should be last expression" raise InvalidOperationError(msg) return origin diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index fc480a0f66..b1ef22fbc9 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -154,10 +154,8 @@ def _parse_positional_inputs(inputs: Iterable[IntoExpr], /) -> Iterator[ExprIR]: def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR]: - from narwhals._plan.expr import Alias - for name, input in named_inputs.items(): - yield Alias(expr=parse_into_expr_ir(input), name=name) + yield parse_into_expr_ir(input).alias(name) def _parse_constraints(constraints: dict[str, IntoExpr], /) -> Iterator[ExprIR]: diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index 9bc4d594d1..0f5eaa1a8e 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -27,4 +27,4 @@ def _ir_namespace(self) -> type[IRListNamespace]: return IRListNamespace def len(self) -> Expr: - return self._to_narwhals(self._ir.len().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.len()) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 80d1dbcde6..5532c7f727 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -253,6 +253,16 @@ class RollingOptionsFixedWindow(Immutable): __slots__ = ("center", "fn_params", "min_samples", "window_size") window_size: int min_samples: int - """Renamed from `min_periods`, reuses `window_size` if null.""" center: bool fn_params: RollingVarParams | None + + +def rolling_options( + window_size: int, min_samples: int | None, /, *, center: bool, ddof: int | None = None +) -> RollingOptionsFixedWindow: + return RollingOptionsFixedWindow( + window_size=window_size, + 
min_samples=window_size if min_samples is None else min_samples, + center=center, + fn_params=ddof if ddof is None else RollingVarParams(ddof=ddof), + ) diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 9fb93749d2..08e012b9fe 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -152,65 +152,45 @@ def _ir_namespace(self) -> type[IRStringNamespace]: return IRStringNamespace def len_chars(self) -> Expr: - return self._to_narwhals(self._ir.len_chars().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.len_chars()) def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 ) -> Expr: - return self._to_narwhals( - self._ir.replace(pattern, value, literal=literal, n=n).to_function_expr( - self._expr._ir - ) - ) + return self._with_unary(self._ir.replace(pattern, value, literal=literal, n=n)) def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Expr: - return self._to_narwhals( - self._ir.replace_all(pattern, value, literal=literal).to_function_expr( - self._expr._ir - ) - ) + return self._with_unary(self._ir.replace_all(pattern, value, literal=literal)) def strip_chars(self, characters: str | None = None) -> Expr: - return self._to_narwhals( - self._ir.strip_chars(characters).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.strip_chars(characters)) def starts_with(self, prefix: str) -> Expr: - return self._to_narwhals( - self._ir.starts_with(prefix).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.starts_with(prefix)) def ends_with(self, suffix: str) -> Expr: - return self._to_narwhals( - self._ir.ends_with(suffix).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.ends_with(suffix)) def contains(self, pattern: str, *, literal: bool = False) -> Expr: - return self._to_narwhals( - self._ir.contains(pattern, literal=literal).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.contains(pattern, literal=literal)) def slice(self, offset: int, length: int | None = None) -> Expr: - return self._to_narwhals( - self._ir.slice(offset, length).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.slice(offset, length)) def head(self, n: int = 5) -> Expr: - return self._to_narwhals(self._ir.head(n).to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.head(n)) def tail(self, n: int = 5) -> Expr: - return self._to_narwhals(self._ir.tail(n).to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.tail(n)) def split(self, by: str) -> Expr: - return self._to_narwhals(self._ir.split(by).to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.split(by)) def to_datetime(self, format: str | None = None) -> Expr: - return self._to_narwhals( - self._ir.to_datetime(format).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.to_datetime(format)) def to_lowercase(self) -> Expr: - return self._to_narwhals(self._ir.to_lowercase().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.to_lowercase()) def to_uppercase(self) -> Expr: - return self._to_narwhals(self._ir.to_uppercase().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.to_uppercase()) diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index 1896a6f953..d91fef6458 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -33,4 +33,4 @@ def _ir_namespace(self) -> type[IRStructNamespace]: return IRStructNamespace def field(self, name: 
str) -> Expr: - return self._to_narwhals(self._ir.field(name).to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.field(name)) diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 02196247ba..85a71de87e 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -198,91 +198,67 @@ def _ir_namespace(self) -> type[IRDateTimeNamespace]: return IRDateTimeNamespace def date(self) -> Expr: - return self._to_narwhals(self._ir.date().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.date()) def year(self) -> Expr: - return self._to_narwhals(self._ir.year().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.year()) def month(self) -> Expr: - return self._to_narwhals(self._ir.month().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.month()) def day(self) -> Expr: - return self._to_narwhals(self._ir.day().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.day()) def hour(self) -> Expr: - return self._to_narwhals(self._ir.hour().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.hour()) def minute(self) -> Expr: - return self._to_narwhals(self._ir.minute().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.minute()) def second(self) -> Expr: - return self._to_narwhals(self._ir.second().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.second()) def millisecond(self) -> Expr: - return self._to_narwhals(self._ir.millisecond().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.millisecond()) def microsecond(self) -> Expr: - return self._to_narwhals(self._ir.microsecond().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.microsecond()) def nanosecond(self) -> Expr: - return self._to_narwhals(self._ir.nanosecond().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.nanosecond()) def ordinal_day(self) -> Expr: - return self._to_narwhals(self._ir.ordinal_day().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.ordinal_day()) def weekday(self) -> Expr: - return self._to_narwhals(self._ir.weekday().to_function_expr(self._expr._ir)) + return self._with_unary(self._ir.weekday()) def total_minutes(self) -> Expr: - return self._to_narwhals( - self._ir.total_minutes().to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.total_minutes()) def total_seconds(self) -> Expr: - return self._to_narwhals( - self._ir.total_seconds().to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.total_seconds()) def total_milliseconds(self) -> Expr: - return self._to_narwhals( - self._ir.total_milliseconds().to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.total_milliseconds()) def total_microseconds(self) -> Expr: - return self._to_narwhals( - self._ir.total_microseconds().to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.total_microseconds()) def total_nanoseconds(self) -> Expr: - return self._to_narwhals( - self._ir.total_nanoseconds().to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.total_nanoseconds()) def to_string(self, format: str) -> Expr: - return self._to_narwhals( - self._ir.to_string(format=format).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.to_string(format=format)) def replace_time_zone(self, time_zone: str | None) -> Expr: - return self._to_narwhals( - self._ir.replace_time_zone(time_zone=time_zone).to_function_expr( - self._expr._ir - 
) - ) + return self._with_unary(self._ir.replace_time_zone(time_zone=time_zone)) def convert_time_zone(self, time_zone: str) -> Expr: - return self._to_narwhals( - self._ir.convert_time_zone(time_zone=time_zone).to_function_expr( - self._expr._ir - ) - ) + return self._with_unary(self._ir.convert_time_zone(time_zone=time_zone)) def timestamp(self, time_unit: TimeUnit = "us") -> Expr: - return self._to_narwhals( - self._ir.timestamp(time_unit=time_unit).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.timestamp(time_unit=time_unit)) def truncate(self, every: str) -> Expr: - return self._to_narwhals( - self._ir.truncate(every=every).to_function_expr(self._expr._ir) - ) + return self._with_unary(self._ir.truncate(every=every)) From c934d25a5095222f052cc0f1a1bcedf3fe6d59c1 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 29 Aug 2025 20:30:08 +0000 Subject: [PATCH 354/368] refactor(expr-ir): Clearing out the cobwebs (#3053) --- narwhals/_plan/aggregation.py | 6 +-- narwhals/_plan/arrow/dataframe.py | 2 - narwhals/_plan/arrow/functions.py | 9 ----- narwhals/_plan/categorical.py | 3 +- narwhals/_plan/common.py | 10 +---- narwhals/_plan/demo.py | 13 ------- narwhals/_plan/dummy.py | 17 -------- narwhals/_plan/expr.py | 64 +++---------------------------- narwhals/_plan/expr_expansion.py | 18 +-------- narwhals/_plan/functions.py | 17 -------- narwhals/_plan/lists.py | 3 +- narwhals/_plan/literal.py | 5 --- narwhals/_plan/name.py | 7 ---- narwhals/_plan/operators.py | 7 +--- narwhals/_plan/options.py | 5 +-- narwhals/_plan/protocols.py | 26 +------------ narwhals/_plan/ranges.py | 16 +------- narwhals/_plan/strings.py | 17 -------- narwhals/_plan/temporal.py | 12 ++---- narwhals/_plan/when_then.py | 2 - narwhals/_plan/window.py | 5 +-- 21 files changed, 21 insertions(+), 243 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index aefea50741..dfa61a39d5 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -92,12 +92,10 @@ class Var(AggExpr): class OrderableAggExpr(AggExpr): ... -class First(OrderableAggExpr): - """https://github.com/narwhals-dev/narwhals/issues/2526.""" +class First(OrderableAggExpr): ... -class Last(OrderableAggExpr): - """https://github.com/narwhals-dev/narwhals/issues/2526.""" +class Last(OrderableAggExpr): ... class ArgMin(OrderableAggExpr): ... diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 214fe908ff..cbc5f600b4 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -93,8 +93,6 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSe from_named_ir = ns._expr.from_named_ir yield from ns._expr.align(from_named_ir(e, self) for e in nodes) - # NOTE: Not handling actual expressions yet - # `BaseFrame` is typed for just `str` names def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: df_by = self.select(by) indices = pc.sort_indices(df_by.native, options=options.to_arrow(df_by.columns)) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 095aeba780..83ecafae5f 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -113,14 +113,10 @@ def modulus(lhs: Any, rhs: Any) -> Any: def cast( native: Scalar[Any], target_type: DataTypeT, *, safe: bool | None = ... ) -> Scalar[DataTypeT]: ... 
- - @t.overload def cast( native: ChunkedArray[Any], target_type: DataTypeT, *, safe: bool | None = ... ) -> ChunkedArray[Scalar[DataTypeT]]: ... - - @t.overload def cast( native: ChunkedOrScalar[Scalar[Any]], @@ -128,8 +124,6 @@ def cast( *, safe: bool | None = ..., ) -> ChunkedArray[Scalar[DataTypeT]] | Scalar[DataTypeT]: ... - - def cast( native: ChunkedOrScalar[Scalar[Any]], target_type: DataTypeT, @@ -220,8 +214,6 @@ def int_range( def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: - # NOTE: PR that fixed these the overloads was closed - # https://github.com/zen-xu/pyarrow-stubs/pull/208 return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) @@ -244,7 +236,6 @@ def chunked_array( def concat_vertical_chunked( arrays: Iterable[ChunkedArrayAny], dtype: DataType | None = None, / ) -> ChunkedArrayAny: - # NOTE: Overloads are broken, this is legit v_concat: Incomplete = pa.chunked_array return v_concat(arrays, dtype) # type: ignore[no-any-return] diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 19217e9826..7fb58367f9 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -12,8 +12,7 @@ class CategoricalFunction(Function, accessor="cat"): ... -class GetCategories(CategoricalFunction, options=FunctionOptions.groupwise): - """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/cat.rs#L7.""" +class GetCategories(CategoricalFunction, options=FunctionOptions.groupwise): ... class IRCatNamespace(IRNamespace): diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 65f57b1170..ce84eb05ef 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -433,9 +433,7 @@ def is_scalar(self) -> bool: def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: from narwhals._plan.expr import FunctionExpr - options = self.function_options - # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L442-L450. - return FunctionExpr(input=inputs, function=self, options=options) + return FunctionExpr(input=inputs, function=self, options=self.function_options) def __init_subclass__( cls, @@ -444,8 +442,6 @@ def __init_subclass__( options: Callable[[], FunctionOptions] | None = None, **kwds: Any, ) -> None: - # NOTE: Hook for defining namespaced functions - # All subclasses will use the prefix in `accessor` for their repr super().__init_subclass__(*args, **kwds) if accessor: cls._accessor = accessor @@ -529,10 +525,6 @@ def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSerie return isinstance(obj, (str, bytes, Series)) or is_compliant_series(obj) -def is_regex_projection(name: str) -> bool: - return name.startswith("^") and name.endswith("$") - - def is_window_expr(obj: Any) -> TypeIs[WindowExpr]: from narwhals._plan.expr import WindowExpr diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index a1425543f9..85b8ac2ad7 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -135,19 +135,6 @@ def when( Examples: >>> from narwhals._plan import demo as nwd - >>> when_then_many = ( - ... nwd.when(nwd.col("x") == "a") - ... .then(1) - ... .when(nwd.col("x") == "b") - ... .then(2) - ... .when(nwd.col("x") == "c") - ... .then(3) - ... .otherwise(4) - ... 
) - >>> when_then_many - nw._plan.Expr(main): - .when([(col('x')) == (lit(str: a))]).then(lit(int: 1)).otherwise(.when([(col('x')) == (lit(str: b))]).then(lit(int: 2)).otherwise(.when([(col('x')) == (lit(str: c))]).then(lit(int: 3)).otherwise(lit(int: 4)))) - >>> >>> nwd.when(nwd.col("y") == "b").then(1) nw._plan.Expr(main): .when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null)) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 29d95f9cd1..6ef174ec06 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -385,9 +385,6 @@ def replace_strict( before = tuple(old) after = tuple(old.values()) elif isinstance(old, t.Mapping): - # NOTE: polars raises later when this occurs - # TypeError: cannot create expression literal for value of type dict. - # Hint: Pass `allow_object=True` to accept any value and create a literal of type Object. msg = "`new` argument cannot be used if `old` argument is a Mapping type" raise TypeError(msg) else: @@ -616,20 +613,6 @@ def str(self) -> ExprStringNamespace: class Selector(Expr): - """Selectors placeholder. - - Examples: - >>> from narwhals._plan import selectors as ncs - >>> - >>> (ncs.matches("[^z]a") & ncs.string()) | ncs.datetime("us", None) - nw._plan.Selector(main): - [([(ncs.matches(pattern='[^z]a')) & (ncs.string())]) | (ncs.datetime(time_unit=['us'], time_zone=[None]))] - >>> - >>> ~(ncs.boolean() | ncs.matches(r"is_.*")) - nw._plan.Selector(main): - ~[(ncs.boolean()) | (ncs.matches(pattern='is_.*'))] - """ - _ir: expr.SelectorIR def __repr__(self) -> str: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index bb6b1018ff..279430405c 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -7,7 +7,7 @@ import typing as t from narwhals._plan.aggregation import AggExpr, OrderableAggExpr -from narwhals._plan.common import ExprIR, SelectorIR, collect, is_regex_projection +from narwhals._plan.common import ExprIR, SelectorIR, collect from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.typing import ( @@ -152,13 +152,6 @@ def __repr__(self) -> str: class IndexColumns(_ColumnSelection): - """Renamed from `IndexColumn`. - - `Nth` provides the singular variant. - - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L80 - """ - __slots__ = ("indices",) indices: Seq[int] @@ -167,11 +160,6 @@ def __repr__(self) -> str: class All(_ColumnSelection): - """Aka Wildcard (`pl.all()` or `pl.col("*")`). - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L137 - """ - def __repr__(self) -> str: return "all()" @@ -181,19 +169,12 @@ class Exclude(_ColumnSelection): expr: ExprIR """Default is `all()`.""" names: Seq[str] - """Excluded names. - - - We're using a `frozenset` in main. - - Might want to switch to that later. 
- """ + """Excluded names.""" @staticmethod def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: flat = flatten(names) - if any(is_regex_projection(nm) for nm in flat): - msg = f"Using regex in `exclude(...)` is not yet supported.\nnames={flat!r}" - raise NotImplementedError(msg) - return Exclude(expr=expr, names=tuple(flat)) + return Exclude(expr=expr, names=collect(flat)) def __repr__(self) -> str: return f"{self.expr!r}.exclude({list(self.names)!r})" @@ -408,14 +389,7 @@ class FunctionExpr(ExprIR, t.Generic[FunctionT]): __slots__ = ("function", "input", "options") input: Seq[ExprIR] function: FunctionT - """Operation applied to each element of `input`. - - Notes: - [Upstream enum type] is named `FunctionExpr` in `rust`. - Mirroring *exactly* doesn't make much sense in OOP. - - [Upstream enum type]: https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 - """ + """Operation applied to each element of `input`.""" options: FunctionOptions """Combined flags from chained operations.""" @@ -495,10 +469,6 @@ class RangeExpr(FunctionExpr[RangeT]): """E.g. `int_range(...)`. Special-cased as it is only allowed scalar inputs, and is row_separable. - - Contradicts the check in `FunctionExpr`, so we've got something *like* [`ensure_range_bounds_contain_exactly_one_value`]. - - [`ensure_range_bounds_contain_exactly_one_value`]:https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/function_expr/range/int_range.rs#L9-L14 """ def __init__( @@ -557,12 +527,9 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(Filter(expr=expr, by=by)) -# TODO @dangotbanned: Clean up docs/notes class WindowExpr(ExprIR): """A fully specified `.over()`, that occurred after another expression. - I think we want variants for partitioned, ordered, both. - Related: - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136 - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L835-L838 @@ -571,17 +538,9 @@ class WindowExpr(ExprIR): __slots__ = ("expr", "partition_by", "options") # noqa: RUF023 expr: ExprIR - """Renamed from `function`. - - For lazy backends, this should be the only place we allow `rolling_*`, `cum_*`. - """ + """For lazy backends, this should be the only place we allow `rolling_*`, `cum_*`.""" partition_by: Seq[ExprIR] options: Window - """Currently **always** represents over. - - Expr::Window { options: WindowType::Over(WindowMapping) } - Expr::Window { options: WindowType::Rolling(RollingGroupOptions) } - """ def __repr__(self) -> str: return f"{self.expr!r}.over({list(self.partition_by)!r})" @@ -619,16 +578,11 @@ def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: return type(self)(expr=self.expr, partition_by=by, options=self.options) -# TODO @dangotbanned: Reduce repetition from `WindowExpr` class OrderedWindowExpr(WindowExpr): __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 expr: ExprIR partition_by: Seq[ExprIR] order_by: Seq[ExprIR] - """Deviates from the `polars` version. - - - `order_by` starts the same as here, but `polars` reduces into a struct - becoming a single (nested) node. 
- """ sort_options: SortOptions options: Window @@ -727,7 +681,6 @@ class RootSelector(SelectorIR): __slots__ = ("selector",) selector: Selector - """by_dtype, matches, numeric, boolean, string, categorical, datetime, all.""" def __repr__(self) -> str: return f"{self.selector!r}" @@ -744,11 +697,7 @@ class BinarySelector( SelectorIR, t.Generic[LeftSelectorT, SelectorOperatorT, RightSelectorT], ): - """Application of two selector exprs via a set operator. - - Note: - `left` and `right` may also nest other `BinarySelector`s. - """ + """Application of two selector exprs via a set operator.""" def matches_column(self, name: str, dtype: DType) -> bool: left = self.left.matches_column(name, dtype) @@ -762,7 +711,6 @@ def map_ir(self, function: MapIR, /) -> ExprIR: class InvertSelector(SelectorIR, t.Generic[SelectorT]): __slots__ = ("selector",) selector: SelectorT - """`(Root|Binary)Selector`.""" def __repr__(self) -> str: return f"~{self.selector!r}" diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index f3bd681d4a..f62ae90d81 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -184,11 +184,6 @@ def prepare_projection( frozen_schema = freeze_schema(schema) rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) output_names = ensure_valid_exprs(rewritten, frozen_schema) - # TODO @dangotbanned: (Seq[ExprIR], OutputNames) -> (Seq[NamedIR]) - # See `expr_rewrites.rewrite_all` - # TODO @dangotbanned: Return a new schema, with the changes (name only) from projecting exprs - # - `select` (subset from schema, maybe need root names as well?) - # - `with_columns` https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1045-L1079 return rewritten, frozen_schema, output_names @@ -274,10 +269,7 @@ def fn(child: ExprIR, /) -> ExprIR: def replace_with_column( origin: ExprIR, tp: type[_ColumnSelection], /, name: str ) -> ExprIR: - """Expand a single column within a multi-selection using `name`. - - For `Columns`, `IndexColumns`, `All`. - """ + """Expand a single column within a multi-selection using `name`.""" def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, tp): @@ -386,13 +378,7 @@ def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]: def prepare_excluded( origin: ExprIR, /, keys: GroupByKeys, *, has_exclude: bool ) -> Excluded: - """Huge simplification of [`polars_plan::plans::conversion::expr_expansion::prepare_excluded`]. 
- - - `DTypes` are not allowed - - regex in `exclude(...)` is not allowed - - [`polars_plan::plans::conversion::expr_expansion::prepare_excluded`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555 - """ + """Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555.""" exclude: set[str] = set() if has_exclude: exclude.update(_iter_exclude_names(origin)) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 1b7865fc20..4c80849a89 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -49,7 +49,6 @@ def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None class HistBinCount(Hist): __slots__ = ("bin_count", *Hist.__slots__) bin_count: int - """Polars (v1.20) sets `bin_count=10` if neither `bins` or `bin_count` are provided.""" def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> None: object.__setattr__(self, "bin_count", bin_count) @@ -93,25 +92,10 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: class FillNullWithStrategy(Function): - """We don't support this variant in a lot of backends, so worth keeping it split out. - - https://github.com/narwhals-dev/narwhals/pull/2555 - """ - __slots__ = ("limit", "strategy") strategy: FillNullStrategy limit: int | None - @property - def function_options(self) -> FunctionOptions: - # NOTE: We don't support these strategies yet - # but might be good to encode this difference now - return ( - FunctionOptions.elementwise() - if self.strategy in {"one", "zero"} - else FunctionOptions.groupwise() - ) - class Shift(Function, options=FunctionOptions.length_preserving): __slots__ = ("n",) @@ -228,7 +212,6 @@ class MapBatches(Function): @property def function_options(self) -> FunctionOptions: - """https://github.com/narwhals-dev/narwhals/issues/2522.""" options = super().function_options if self.is_elementwise: options = options.with_elementwise() diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index 0f5eaa1a8e..046db5615d 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -12,8 +12,7 @@ class ListFunction(Function, accessor="list"): ... -class Len(ListFunction, options=FunctionOptions.elementwise): - """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/list.rs#L32.""" +class Len(ListFunction, options=FunctionOptions.elementwise): ... class IRListNamespace(IRNamespace): diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 7349ea24cc..e0dba305fa 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -56,11 +56,6 @@ def unwrap(self) -> NonNestedLiteralT: class SeriesLiteral(LiteralValue["Series[NativeSeriesT]"]): - """We already need this. 
- - https://github.com/narwhals-dev/narwhals/blob/e51eba891719a5eb1f7ce91c02a477af39c0baee/narwhals/_expression_parsing.py#L96-L97 - """ - __slots__ = ("value",) value: Series[NativeSeriesT] diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 4a71e00995..b7fb5cb189 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -15,8 +15,6 @@ class KeepName(ExprIR): - """Keep the original root name.""" - __slots__ = ("expr",) expr: ExprIR @@ -92,23 +90,18 @@ def keep(self) -> KeepName: return KeepName(expr=self._ir) def map(self, function: AliasName) -> RenameAlias: - """Define an alias by mapping a function over the original root column name.""" return RenameAlias(expr=self._ir, function=function) def prefix(self, prefix: str) -> RenameAlias: - """Add a prefix to the root column name.""" return self.map(Prefix(prefix=prefix)) def suffix(self, suffix: str) -> RenameAlias: - """Add a suffix to the root column name.""" return self.map(Suffix(suffix=suffix)) def to_lowercase(self) -> RenameAlias: - """Update the root column name to use lowercase characters.""" return self.map(str.lower) def to_uppercase(self) -> RenameAlias: - """Update the root column name to use uppercase characters.""" return self.map(str.upper) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 7c22346900..09d072e7bd 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -82,12 +82,7 @@ def _is_filtration(ir: ExprIR) -> bool: class SelectorOperator(Operator): - """Operators that can *also* be used in selectors. - - Remember that `Or` is named [`meta._selector_add`]! - - [`meta._selector_add`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L113-L124 - """ + """Operators that can *also* be used in selectors.""" def to_binary_selector( self, left: LeftSelectorT, right: RightSelectorT, / diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 5532c7f727..ca6cf91a04 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -71,10 +71,7 @@ def __str__(self) -> str: class FunctionOptions(Immutable): - """ExprMetadata` but less god object. - - https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs - """ + """https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs""" # noqa: D415 __slots__ = ("flags",) flags: FunctionFlags diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index abb12f4ca8..821e7a338a 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -277,17 +277,7 @@ def __narwhals_namespace__(self) -> NamespaceT_co: ... class CompliantExpr(StoresVersion, Protocol[FrameT_contra, SeriesT_co]): - """Everything common to `Expr`/`Series` and `Scalar` literal values. 
- - Early notes: - - Separating series/scalar makes a lot of sense - - Handling the recursive case *without* intermediate (non-pyarrow) objects seems unachievable - - Everywhere would need to first check if it a scalar, which isn't ergonomic - - Broadcasting being separated is working - - A lot of `pyarrow.compute` (section 2) can work on either scalar or series (`FunctionExpr`) - - Aggregation can't, but that is already handled in `ExprIR` - - `polars` noops on aggregating a scalar, which we might be able to support this way - """ + """Everything common to `Expr`/`Series` and `Scalar` literal values.""" _evaluated: Any """Compliant or native value.""" @@ -551,20 +541,6 @@ def _concat_vertical( class CompliantNamespace(StoresVersion, Protocol[FrameT, ExprT_co, ScalarT_co]): - """Need to hold `Expr` and `Scalar` types outside of their defs. - - Likely, re-wrapping the output types will work like: - - - ns = DataFrame().__narwhals_namespace__() - if ns._expr.is_native(out): - ns._expr.from_native(out, ...) - elif ns._scalar.is_native(out): - ns._scalar.from_native(out, ...) - else: - assert_never(out) - """ - @property def _frame(self) -> type[FrameT]: ... @property diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py index 14500cadc4..4414afabf7 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/ranges.py @@ -20,21 +20,7 @@ def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: class IntRange(RangeFunction, options=FunctionOptions.row_separable): - """N-ary (start, end). - - Not implemented yet, but might push forward [#2722]. - - See [`rust` entrypoint], which is roughly: - - Expr::Function { [start, end], FunctionExpr::Range(RangeFunction::IntRange { step, dtype }) } - - `narwhals` equivalent: - - FunctionExpr(input=(start, end), function=IntRange(step=step, dtype=dtype)) - - [#2722]: https://github.com/narwhals-dev/narwhals/issues/2722 - [`rust` entrypoint]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/dsl/functions/range.rs#L14-L23 - """ + """N-ary (start, end).""" __slots__ = ("step", "dtype") # noqa: RUF023 step: int diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 08e012b9fe..8a1789b079 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -43,11 +43,6 @@ class Replace(StringFunction): class ReplaceAll(StringFunction): - """`polars` uses a single node for this and `Replace`. - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L65-L70 - """ - __slots__ = ("literal", "pattern", "value") pattern: str value: str @@ -55,11 +50,6 @@ class ReplaceAll(StringFunction): class Slice(StringFunction): - """We're using for `Head`, `Tail` as well. - - https://github.com/dangotbanned/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L87-L89 - """ - __slots__ = ("length", "offset") offset: int length: int | None @@ -81,13 +71,6 @@ class StripChars(StringFunction): class ToDatetime(StringFunction): - """`polars` uses `Strptime`. - - We've got a fairly different representation. 
- - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/function_expr/strings.rs#L112 - """ - __slots__ = ("format",) format: str | None diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index 85a71de87e..f6a74587f7 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -2,13 +2,14 @@ from typing import TYPE_CHECKING, Any, Literal +from narwhals._duration import Interval from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: from typing_extensions import TypeAlias, TypeIs - from narwhals._duration import Interval, IntervalUnit + from narwhals._duration import IntervalUnit from narwhals._plan.dummy import Expr from narwhals.typing import TimeUnit @@ -95,12 +96,7 @@ class Timestamp(TemporalFunction): @staticmethod def from_time_unit(time_unit: TimeUnit, /) -> Timestamp: if not _is_polars_time_unit(time_unit): - from typing import get_args - - msg = ( - "invalid `time_unit`" - f"\n\nExpected one of {get_args(PolarsTimeUnit)}, got {time_unit!r}." - ) + msg = f"invalid `time_unit` \n\nExpected one of ['ns', 'us', 'ms'], got {time_unit!r}." raise ValueError(msg) return Timestamp(time_unit=time_unit) @@ -115,8 +111,6 @@ class Truncate(TemporalFunction): @staticmethod def from_string(every: str, /) -> Truncate: - from narwhals._duration import Interval - return Truncate.from_interval(Interval.parse(every)) @staticmethod diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 429f8ed045..d264f39733 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -79,8 +79,6 @@ def then(self, statement: IntoExpr, /) -> ChainedThen: class ChainedThen(Immutable, Expr): - """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130.""" - __slots__ = ("conditions", "statements") conditions: Seq[ExprIR] statements: Seq[ExprIR] diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index 5c1eafa7a4..f575d9d303 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -18,10 +18,7 @@ class Window(Immutable): - """Renamed from `WindowType`. 
- - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/options/mod.rs#L139 - """ + """Renamed from `WindowType` https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/options/mod.rs#L139.""" class Over(Window): From 631d3a387e1cbfdd89856bd1ae8f3df6a8353207 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 31 Aug 2025 14:25:08 +0000 Subject: [PATCH 355/368] refactor(expr-ir): `copy.replace` most `with_*` methods (#3063) --- narwhals/_plan/aggregation.py | 7 +-- narwhals/_plan/common.py | 33 ++++++++++- narwhals/_plan/expr.py | 95 ++++++++------------------------ narwhals/_plan/expr_expansion.py | 9 +-- narwhals/_plan/name.py | 7 +-- 5 files changed, 59 insertions(+), 92 deletions(-) diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index dfa61a39d5..ea25e82ad1 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import ExprIR, _pascal_to_snake_case +from narwhals._plan.common import ExprIR, _pascal_to_snake_case, replace from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: @@ -40,10 +40,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) def with_expr(self, expr: ExprIR, /) -> Self: - if expr == self.expr: - return self - it = ((k, v) for k, v in self.__immutable_items__ if k != "expr") - return type(self)(expr=expr, **dict(it)) + return replace(self, expr=expr) def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: if expr.is_scalar: diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ce84eb05ef..3e77d493d0 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -2,6 +2,7 @@ import datetime as dt import re +import sys from collections.abc import Iterable from decimal import Decimal from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, overload @@ -72,6 +73,19 @@ def decorator(cls_or_fn: T) -> T: return decorator +if sys.version_info >= (3, 13): + from copy import replace as replace # noqa: PLC0414 +else: + + def replace(obj: T, /, **changes: Any) -> T: + cls = obj.__class__ + func = getattr(cls, "__replace__", None) + if func is None: + msg = f"replace() does not support {cls.__name__} objects" + raise TypeError(msg) + return func(obj, **changes) # type: ignore[no-any-return] + + T = TypeVar("T") _IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" @@ -111,6 +125,21 @@ def __setattr__(self, name: str, value: Never) -> Never: msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." 
raise AttributeError(msg) + def __replace__(self, **changes: Any) -> Self: + """https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415 + if len(changes) == 1: + k_new, v_new = next(iter(changes.items())) + # NOTE: Will trigger an attribute error if invalid name + if getattr(self, k_new) == v_new: + return self + changed = dict(self.__immutable_items__) + # Now we *don't* need to check the key is valid + changed[k_new] = v_new + else: + changed = dict(self.__immutable_items__) + changed |= changes + return type(self)(**changed) + def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: super().__init_subclass__(*args, **kwds) if cls.__slots__: @@ -342,9 +371,7 @@ def map_ir(self, function: MapIR, /) -> NamedIR[ExprIR]: return self.with_expr(function(self.expr.map_ir(function))) def with_expr(self, expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: - if expr == self.expr: - return cast("NamedIR[ExprIRT2]", self) - return NamedIR(expr=expr, name=self.name) + return cast("NamedIR[ExprIRT2]", replace(self, expr=expr)) def __repr__(self) -> str: return f"{self.name}={self.expr!r}" diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 279430405c..610a9e80a1 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -6,6 +6,7 @@ # - Literal import typing as t +from narwhals._plan import common from narwhals._plan.aggregation import AggExpr, OrderableAggExpr from narwhals._plan.common import ExprIR, SelectorIR, collect from narwhals._plan.exceptions import function_expr_invalid_operation_error @@ -111,7 +112,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) def with_expr(self, expr: ExprIR, /) -> Self: - return self if expr == self.expr else type(self)(expr=expr, name=self.name) + return common.replace(self, expr=expr) class Column(ExprIR): @@ -122,7 +123,7 @@ def __repr__(self) -> str: return f"col({self.name!r})" def with_name(self, name: str, /) -> Column: - return self if name == self.name else col(name) + return common.replace(self, name=name) def map_ir(self, function: MapIR, /) -> ExprIR: return function(self) @@ -191,7 +192,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) def with_expr(self, expr: ExprIR, /) -> Self: - return self if expr == self.expr else type(self)(expr=expr, names=self.names) + return common.replace(self, expr=expr) class Literal(ExprIR, t.Generic[LiteralT]): @@ -255,14 +256,12 @@ def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.left.iter_output_name() def with_left(self, left: LeftT2, /) -> BinaryExpr[LeftT2, OperatorT, RightT]: - if left == self.left: - return t.cast("BinaryExpr[LeftT2, OperatorT, RightT]", self) - return BinaryExpr(left=left, op=self.op, right=self.right) + changed = common.replace(self, left=left) + return t.cast("BinaryExpr[LeftT2, OperatorT, RightT]", changed) def with_right(self, right: RightT2, /) -> BinaryExpr[LeftT, OperatorT, RightT2]: - if right == self.right: - return t.cast("BinaryExpr[LeftT, OperatorT, RightT2]", self) - return BinaryExpr(left=self.left, op=self.op, right=right) + changed = common.replace(self, right=right) + return t.cast("BinaryExpr[LeftT, OperatorT, RightT2]", changed) def map_ir(self, function: MapIR, /) -> ExprIR: return function( @@ -299,7 +298,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) def with_expr(self, expr: ExprIR, /) -> Self: - return self if expr == self.expr else 
type(self)(expr=expr, dtype=self.dtype) + return common.replace(self, expr=expr) class Sort(ExprIR): @@ -330,7 +329,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) def with_expr(self, expr: ExprIR, /) -> Self: - return self if expr == self.expr else type(self)(expr=expr, options=self.options) + return common.replace(self, expr=expr) class SortBy(ExprIR): @@ -368,15 +367,10 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function)).with_by(by)) def with_expr(self, expr: ExprIR, /) -> Self: - if expr == self.expr: - return self - return type(self)(expr=expr, by=self.by, options=self.options) + return common.replace(self, expr=expr) def with_by(self, by: t.Iterable[ExprIR], /) -> Self: - by = collect(by) - if by == self.by: - return self - return type(self)(expr=self.expr, by=by, options=self.options) + return common.replace(self, by=collect(by)) class FunctionExpr(ExprIR, t.Generic[FunctionT]): @@ -399,14 +393,10 @@ def is_scalar(self) -> bool: return self.function.is_scalar def with_options(self, options: FunctionOptions, /) -> Self: - options = self.options.with_flags(options.flags) - return type(self)(input=self.input, function=self.function, options=options) + return common.replace(self, options=self.options.with_flags(options.flags)) def with_input(self, input: t.Iterable[ExprIR], /) -> Self: # noqa: A002 - input = collect(input) - if input == self.input: - return self - return type(self)(input=input, function=self.function, options=self.options) + return common.replace(self, input=collect(input)) def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_input(ir.map_ir(function) for ir in self.input)) @@ -520,11 +510,9 @@ def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() def map_ir(self, function: MapIR, /) -> ExprIR: - expr = self.expr.map_ir(function) - by = self.by.map_ir(function) - expr = self.expr if self.expr == expr else expr - by = self.by if self.by == by else by - return function(Filter(expr=expr, by=by)) + expr, by = self.expr, self.by + changed = common.replace(self, expr=expr.map_ir(function), by=by.map_ir(function)) + return function(changed) class WindowExpr(ExprIR): @@ -567,15 +555,10 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(over) def with_expr(self, expr: ExprIR, /) -> Self: - if expr == self.expr: - return self - return type(self)(expr=expr, partition_by=self.partition_by, options=self.options) + return common.replace(self, expr=expr) def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: - by = collect(partition_by) - if by == self.partition_by: - return self - return type(self)(expr=self.expr, partition_by=by, options=self.options) + return common.replace(self, partition_by=collect(partition_by)) class OrderedWindowExpr(WindowExpr): @@ -625,39 +608,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(over) def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: - by = collect(order_by) - if by == self.order_by: - return self - return type(self)( - expr=self.expr, - partition_by=self.partition_by, - order_by=by, - sort_options=self.sort_options, - options=self.options, - ) - - def with_expr(self, expr: ExprIR, /) -> Self: - if expr == self.expr: - return self - return type(self)( - expr=expr, - partition_by=self.partition_by, - order_by=self.order_by, - sort_options=self.sort_options, - options=self.options, - ) - - def 
with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: - by = collect(partition_by) - if by == self.partition_by: - return self - return type(self)( - expr=self.expr, - partition_by=by, - order_by=self.order_by, - sort_options=self.sort_options, - options=self.options, - ) + return common.replace(self, order_by=collect(order_by)) class Len(ExprIR): @@ -758,7 +709,5 @@ def map_ir(self, function: MapIR, /) -> ExprIR: predicate = self.predicate.map_ir(function) truthy = self.truthy.map_ir(function) falsy = self.falsy.map_ir(function) - predicate = self.predicate if self.predicate == predicate else predicate - truthy = self.truthy if self.truthy == truthy else truthy - falsy = self.falsy if self.falsy == falsy else falsy - return function(Ternary(predicate=predicate, truthy=truthy, falsy=falsy)) + changed = common.replace(self, predicate=predicate, truthy=truthy, falsy=falsy) + return function(changed) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index f62ae90d81..d7ab345f81 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -43,6 +43,7 @@ from itertools import chain from typing import TYPE_CHECKING +from narwhals._plan import common from narwhals._plan.common import ( ExprIR, Immutable, @@ -157,13 +158,7 @@ def from_expr(cls, expr: Expr, /) -> ExpansionFlags: return cls.from_ir(expr._ir) def with_multiple_columns(self) -> ExpansionFlags: - return ExpansionFlags( - multiple_columns=True, - has_nth=self.has_nth, - has_wildcard=self.has_wildcard, - has_selector=self.has_selector, - has_exclude=self.has_exclude, - ) + return common.replace(self, multiple_columns=True) def prepare_projection( diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index b7fb5cb189..7c695599bc 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from narwhals._plan import common from narwhals._plan.common import ExprIR, ExprNamespace, Immutable, IRNamespace if TYPE_CHECKING: @@ -37,7 +38,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) def with_expr(self, expr: ExprIR, /) -> Self: - return self if expr == self.expr else type(self)(expr=expr) + return common.replace(self, expr=expr) class RenameAlias(ExprIR): @@ -64,9 +65,7 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(self.with_expr(self.expr.map_ir(function))) def with_expr(self, expr: ExprIR, /) -> Self: - return ( - self if expr == self.expr else type(self)(expr=expr, function=self.function) - ) + return common.replace(self, expr=expr) class Prefix(Immutable): From 973824c28dde83e1220a54e380d27faa5e97d34c Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:15:39 +0000 Subject: [PATCH 356/368] refactor(expr-ir): Shrinking `ExprIR` (main) (#3066) Pretty happy with the progress so far, will carry over some of (https://github.com/narwhals-dev/narwhals/pull/3066#issuecomment-3242037939) after rolling back into #2572 --- narwhals/_plan/_guards.py | 103 ++++++++ narwhals/_plan/_immutable.py | 151 +++++++++++ narwhals/_plan/aggregation.py | 60 +---- narwhals/_plan/arrow/dataframe.py | 64 ++--- narwhals/_plan/arrow/expr.py | 213 +++++++-------- narwhals/_plan/arrow/namespace.py | 130 ++++----- narwhals/_plan/boolean.py | 59 ++--- narwhals/_plan/categorical.py | 14 +- narwhals/_plan/common.py | 423 ++++++++++-------------------- narwhals/_plan/demo.py | 17 +- 
narwhals/_plan/dummy.py | 150 +++++------ narwhals/_plan/expr.py | 297 ++++----------------- narwhals/_plan/expr_expansion.py | 161 +++--------- narwhals/_plan/expr_parsing.py | 12 +- narwhals/_plan/expr_rewrites.py | 10 +- narwhals/_plan/functions.py | 130 +++------ narwhals/_plan/lists.py | 11 +- narwhals/_plan/literal.py | 29 +- narwhals/_plan/meta.py | 46 ++-- narwhals/_plan/name.py | 48 +--- narwhals/_plan/operators.py | 122 +++------ narwhals/_plan/options.py | 62 ++++- narwhals/_plan/protocols.py | 213 +++------------ narwhals/_plan/ranges.py | 4 +- narwhals/_plan/schema.py | 6 +- narwhals/_plan/selectors.py | 29 +- narwhals/_plan/strings.py | 53 ++-- narwhals/_plan/struct.py | 15 +- narwhals/_plan/temporal.py | 133 ++-------- narwhals/_plan/typing.py | 3 +- narwhals/_plan/when_then.py | 25 +- narwhals/_plan/window.py | 3 +- tests/plan/compliant_test.py | 2 +- tests/plan/expr_rewrites_test.py | 3 +- tests/plan/immutable_test.py | 2 +- tests/plan/utils.py | 3 +- 36 files changed, 1061 insertions(+), 1745 deletions(-) create mode 100644 narwhals/_plan/_guards.py create mode 100644 narwhals/_plan/_immutable.py diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py new file mode 100644 index 0000000000..867d16d397 --- /dev/null +++ b/narwhals/_plan/_guards.py @@ -0,0 +1,103 @@ +"""Common type guards, mostly with inline imports.""" + +from __future__ import annotations + +import datetime as dt +from decimal import Decimal +from typing import TYPE_CHECKING, Any, TypeVar + +from narwhals._utils import _hasattr_static + +if TYPE_CHECKING: + from typing_extensions import TypeIs + + from narwhals._plan import expr + from narwhals._plan.dummy import Expr, Series + from narwhals._plan.protocols import CompliantSeries + from narwhals._plan.typing import NativeSeriesT, Seq + from narwhals.typing import NonNestedLiteral + + T = TypeVar("T") + +_NON_NESTED_LITERAL_TPS = ( + int, + float, + str, + dt.date, + dt.time, + dt.timedelta, + bytes, + Decimal, +) + + +def _dummy(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import dummy + + return dummy + + +def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import expr + + return expr + + +def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: + return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) + + +def is_expr(obj: Any) -> TypeIs[Expr]: + return isinstance(obj, _dummy().Expr) + + +def is_column(obj: Any) -> TypeIs[Expr]: + """Indicate if the given object is a basic/unaliased column.""" + return is_expr(obj) and obj.meta.is_column() + + +def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: + return isinstance(obj, _dummy().Series) + + +def is_compliant_series( + obj: CompliantSeries[NativeSeriesT] | Any, +) -> TypeIs[CompliantSeries[NativeSeriesT]]: + return _hasattr_static(obj, "__narwhals_series__") + + +def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]: + return isinstance(obj, (str, bytes, _dummy().Series)) or is_compliant_series(obj) + + +def is_window_expr(obj: Any) -> TypeIs[expr.WindowExpr]: + return isinstance(obj, _expr().WindowExpr) + + +def is_function_expr(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]: + return isinstance(obj, _expr().FunctionExpr) + + +def is_binary_expr(obj: Any) -> TypeIs[expr.BinaryExpr]: + return isinstance(obj, _expr().BinaryExpr) + + +def is_agg_expr(obj: Any) -> TypeIs[expr.AggExpr]: + return isinstance(obj, _expr().AggExpr) + + +def is_aggregation(obj: Any) 
-> TypeIs[expr.AggExpr | expr.FunctionExpr[Any]]: + """Superset of `ExprIR.is_scalar`, excludes literals & len.""" + return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar) + + +def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]: + return isinstance(obj, _expr().Literal) + + +def is_horizontal_reduction(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]: + return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() + + +def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: + return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) diff --git a/narwhals/_plan/_immutable.py b/narwhals/_plan/_immutable.py new file mode 100644 index 0000000000..0abe0739b6 --- /dev/null +++ b/narwhals/_plan/_immutable.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Any, Callable + + from typing_extensions import Never, Self, dataclass_transform + +else: + # https://docs.python.org/3/library/typing.html#typing.dataclass_transform + def dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + frozen_default: bool = False, + field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), + **kwargs: Any, + ) -> Callable[[T], T]: + def decorator(cls_or_fn: T) -> T: + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + + return decorator + + +T = TypeVar("T") +_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" + + +@dataclass_transform(kw_only_default=True, frozen_default=True) +class Immutable: + """A poor man's frozen dataclass. + + - Keyword-only constructor (IDE supported) + - Manual `__slots__` required + - Compatible with [`copy.replace`] + - No handling for default arguments + + [`copy.replace`]: https://docs.python.org/3.13/library/copy.html#copy.replace + """ + + __slots__ = (_IMMUTABLE_HASH_NAME,) + __immutable_hash_value__: int + + @property + def __immutable_keys__(self) -> Iterator[str]: + slots: tuple[str, ...] = self.__slots__ + for name in slots: + if name != _IMMUTABLE_HASH_NAME: + yield name + + @property + def __immutable_values__(self) -> Iterator[Any]: + for name in self.__immutable_keys__: + yield getattr(self, name) + + @property + def __immutable_items__(self) -> Iterator[tuple[str, Any]]: + for name in self.__immutable_keys__: + yield name, getattr(self, name) + + @property + def __immutable_hash__(self) -> int: + if hasattr(self, _IMMUTABLE_HASH_NAME): + return self.__immutable_hash_value__ + hash_value = hash((self.__class__, *self.__immutable_values__)) + object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) + return self.__immutable_hash_value__ + + def __setattr__(self, name: str, value: Never) -> Never: + msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." + raise AttributeError(msg) + + def __replace__(self, **changes: Any) -> Self: + """https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415 + if len(changes) == 1: + # The most common case is a single field replacement. + # Iff that field happens to be equal, we can noop, preserving the current object's hash. 
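+            # A minimal illustrative sketch (hypothetical subclass ``Point`` with
+            # ``__slots__ = ("x", "y")``, not part of this module):
+            #     p = Point(x=1, y=2)
+            #     p.__replace__(x=1) is p  # -> True: equal value, noop, hash preserved
+            #     p.__replace__(x=3) is p  # -> False: a new instance is constructed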
+ name, value_changed = next(iter(changes.items())) + if getattr(self, name) == value_changed: + return self + changes = dict(self.__immutable_items__, **changes) + else: + for name, value_current in self.__immutable_items__: + if name not in changes or value_current == changes[name]: + changes[name] = value_current + return type(self)(**changes) + + def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: + super().__init_subclass__(*args, **kwds) + if cls.__slots__: + ... + else: + cls.__slots__ = () + + def __hash__(self) -> int: + return self.__immutable_hash__ + + def __eq__(self, other: object) -> bool: + if self is other: + return True + if type(self) is not type(other): + return False + return all( + getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ + ) + + def __str__(self) -> str: + fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__) + return f"{type(self).__name__}({fields})" + + def __init__(self, **kwds: Any) -> None: + required: set[str] = set(self.__immutable_keys__) + if not required and not kwds: + # NOTE: Fastpath for empty slots + ... + elif required == set(kwds): + for name, value in kwds.items(): + object.__setattr__(self, name, value) + elif missing := required.difference(kwds): + msg = ( + f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n" + f"but missing values for {sorted(missing)!r}" + ) + raise TypeError(msg) + else: + extra = set(kwds).difference(required) + msg = ( + f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n" + f"but got unknown arguments {sorted(extra)!r}" + ) + raise TypeError(msg) + + +def _field_str(name: str, value: Any) -> str: + if isinstance(value, tuple): + inner = ", ".join(f"{v}" for v in value) + return f"{name}=[{inner}]" + if isinstance(value, str): + return f"{name}={value!r}" + return f"{name}={value}" diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index ea25e82ad1..b1f47ca1d7 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -2,19 +2,16 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import ExprIR, _pascal_to_snake_case, replace +from narwhals._plan.common import ExprIR, _pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: from collections.abc import Iterator - from typing_extensions import Self - - from narwhals._plan.typing import MapIR from narwhals.typing import RollingInterpolationMethod -class AggExpr(ExprIR): +class AggExpr(ExprIR, child=("expr",)): __slots__ = ("expr",) expr: ExprIR @@ -25,50 +22,31 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.{_pascal_to_snake_case(type(self).__name__)}()" - def iter_left(self) -> Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - def iter_output_name(self) -> Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return replace(self, expr=expr) - def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: if expr.is_scalar: raise agg_scalar_error(self, expr) super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue] +# fmt: off class Count(AggExpr): ... - - class Max(AggExpr): ... - - class Mean(AggExpr): ... - - class Median(AggExpr): ... 
- - class Min(AggExpr): ... - - class NUnique(AggExpr): ... - - +class Sum(AggExpr): ... +class OrderableAggExpr(AggExpr): ... +class First(OrderableAggExpr): ... +class Last(OrderableAggExpr): ... +class ArgMin(OrderableAggExpr): ... +class ArgMax(OrderableAggExpr): ... +# fmt: on class Quantile(AggExpr): __slots__ = (*AggExpr.__slots__, "interpolation", "quantile") - quantile: float interpolation: RollingInterpolationMethod @@ -78,24 +56,6 @@ class Std(AggExpr): ddof: int -class Sum(AggExpr): ... - - class Var(AggExpr): __slots__ = (*AggExpr.__slots__, "ddof") ddof: int - - -class OrderableAggExpr(AggExpr): ... - - -class First(OrderableAggExpr): ... - - -class Last(OrderableAggExpr): ... - - -class ArgMin(OrderableAggExpr): ... - - -class ArgMax(OrderableAggExpr): ... diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index cbc5f600b4..fc61e69acc 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -1,34 +1,34 @@ from __future__ import annotations -import typing as t +from typing import TYPE_CHECKING, Any, Literal, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn -from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.common import ExprIR -from narwhals._plan.protocols import EagerDataFrame +from narwhals._plan.arrow.series import ArrowSeries as Series +from narwhals._plan.protocols import EagerDataFrame, namespace from narwhals._utils import Version +from narwhals.schema import Schema -if t.TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping, Sequence from typing_extensions import Self from narwhals._arrow.typing import ChunkedArrayAny - from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar + from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.dummy import DataFrame + from narwhals._plan.dummy import DataFrame as NwDataFrame from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import Seq from narwhals.dtypes import DType - from narwhals.schema import Schema + from narwhals.typing import IntoSchema -class ArrowDataFrame(EagerDataFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): +class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]): def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace import ArrowNamespace @@ -49,47 +49,37 @@ def schema(self) -> dict[str, DType]: def __len__(self) -> int: return self.native.num_rows - def to_narwhals(self) -> DataFrame[pa.Table, ChunkedArrayAny]: + def to_narwhals(self) -> NwDataFrame[pa.Table, ChunkedArrayAny]: from narwhals._plan.dummy import DataFrame return DataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) @classmethod def from_dict( - cls, - data: t.Mapping[str, t.Any], - /, - *, - schema: t.Mapping[str, DType] | Schema | None = None, + cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: - from narwhals.schema import Schema - pa_schema = Schema(schema).to_arrow() if schema is not None else schema native = pa.Table.from_pydict(data, schema=pa_schema) return cls.from_native(native, version=Version.MAIN) - def 
iter_columns(self) -> t.Iterator[ArrowSeries]: + def iter_columns(self) -> Iterator[Series]: for name, series in zip(self.columns, self.native.itercolumns()): - yield ArrowSeries.from_native(series, name, version=self.version) - - @t.overload - def to_dict(self, *, as_series: t.Literal[True]) -> dict[str, ArrowSeries]: ... - @t.overload - def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... - @t.overload - def to_dict( - self, *, as_series: bool - ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: ... - def to_dict( - self, *, as_series: bool - ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: + yield Series.from_native(series, name, version=self.version) + + @overload + def to_dict(self, *, as_series: Literal[True]) -> dict[str, Series]: ... + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload + def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: ... + def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: it = self.iter_columns() if as_series: return {ser.name: ser for ser in it} return {ser.name: ser.to_list() for ser in it} - def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSeries]: - ns = self.__narwhals_namespace__() + def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[Series]: + ns = namespace(self) from_named_ir = ns._expr.from_named_ir yield from ns._expr.align(from_named_ir(e, self) for e in nodes) @@ -101,16 +91,16 @@ def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: def with_row_index(self, name: str) -> Self: return self._with_native(self.native.add_column(0, name, fn.int_range(len(self)))) - def get_column(self, name: str) -> ArrowSeries: + def get_column(self, name: str) -> Series: chunked = self.native.column(name) - return ArrowSeries.from_native(chunked, name, version=self.version) + return Series.from_native(chunked, name, version=self.version) def drop(self, columns: Sequence[str]) -> Self: to_drop = list(columns) return self._with_native(self.native.drop(to_drop)) # NOTE: Use instead of `with_columns` for trivial cases - def _with_columns(self, exprs: Iterable[ArrowExpr | ArrowScalar], /) -> Self: + def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: native = self.native columns = self.columns height = len(self) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index ef6e0164ed..d8a163d120 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -7,11 +7,10 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan.arrow import functions as fn -from narwhals._plan.arrow.functions import lit -from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.arrow.typing import NativeScalar, StoresNativeT_co -from narwhals._plan.common import ExprIR, NamedIR, into_dtype -from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch +from narwhals._plan.arrow.series import ArrowSeries as Series +from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co +from narwhals._plan.common import ExprIR, NamedIR +from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace from narwhals._utils import ( Implementation, Version, @@ -24,7 +23,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._arrow.typing 
import ChunkedArrayAny, Incomplete from narwhals._plan import boolean, expr @@ -44,28 +43,29 @@ Sum, Var, ) - from narwhals._plan.arrow.dataframe import ArrowDataFrame + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull + from narwhals._plan.boolean import All, IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.expr import ( AnonymousExpr, BinaryExpr, FunctionExpr, OrderedWindowExpr, RollingExpr, - Ternary, + TernaryExpr, WindowExpr, ) from narwhals._plan.functions import FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral + Expr: TypeAlias = "ArrowExpr" + Scalar: TypeAlias = "ArrowScalar" + BACKEND_VERSION = Implementation.PYARROW._backend_version() -class _ArrowDispatch( - ExprDispatch["ArrowDataFrame", StoresNativeT_co, "ArrowNamespace"], Protocol -): +class _ArrowDispatch(ExprDispatch["Frame", StoresNativeT_co, "ArrowNamespace"], Protocol): """Common to `Expr`, `Scalar` + their dependencies.""" def __narwhals_namespace__(self) -> ArrowNamespace: @@ -74,94 +74,84 @@ def __narwhals_namespace__(self) -> ArrowNamespace: return ArrowNamespace(self.version) def _with_native(self, native: Any, name: str, /) -> StoresNativeT_co: ... - def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> StoresNativeT_co: + def cast(self, node: expr.Cast, frame: Frame, name: str) -> StoresNativeT_co: data_type = narwhals_to_native_dtype(node.dtype, frame.version) - native = self._dispatch(node.expr, frame, name).native + native = node.expr.dispatch(self, frame, name).native return self._with_native(fn.cast(native, data_type), name) - def pow( - self, node: FunctionExpr[Pow], frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: + def pow(self, node: FunctionExpr[Pow], frame: Frame, name: str) -> StoresNativeT_co: base, exponent = node.function.unwrap_input(node) - base_ = self._dispatch(base, frame, "base").native - exponent_ = self._dispatch(exponent, frame, "exponent").native + base_ = base.dispatch(self, frame, "base").native + exponent_ = exponent.dispatch(self, frame, "exponent").native return self._with_native(pc.power(base_, exponent_), name) def fill_null( - self, node: FunctionExpr[FillNull], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[FillNull], frame: Frame, name: str ) -> StoresNativeT_co: expr, value = node.function.unwrap_input(node) - native = self._dispatch(expr, frame, name).native - value_ = self._dispatch(value, frame, "value").native + native = expr.dispatch(self, frame, name).native + value_ = value.dispatch(self, frame, "value").native return self._with_native(pc.fill_null(native, value_), name) def is_between( - self, node: FunctionExpr[IsBetween], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[IsBetween], frame: Frame, name: str ) -> StoresNativeT_co: expr, lower_bound, upper_bound = node.function.unwrap_input(node) - native = self._dispatch(expr, frame, name).native - lower = self._dispatch(lower_bound, frame, "lower").native - upper = self._dispatch(upper_bound, frame, "upper").native + native = expr.dispatch(self, frame, name).native + lower = lower_bound.dispatch(self, frame, "lower").native + upper = upper_bound.dispatch(self, frame, "upper").native result = fn.is_between(native, lower, upper, node.function.closed) return self._with_native(result, name) def _unary_function( self, fn_native: Callable[[Any], Any], / - ) -> Callable[[FunctionExpr[Any], 
ArrowDataFrame, str], StoresNativeT_co]: - def func( - node: FunctionExpr[Any], frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: - native = self._dispatch(node.input[0], frame, name).native + ) -> Callable[[FunctionExpr[Any], Frame, str], StoresNativeT_co]: + def func(node: FunctionExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: + native = node.input[0].dispatch(self, frame, name).native return self._with_native(fn_native(native), name) return func - def not_( - self, node: FunctionExpr[boolean.Not], frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: + def not_(self, node: FunctionExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(pc.invert)(node, frame, name) - def all( - self, node: FunctionExpr[boolean.All], frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: + def all(self, node: FunctionExpr[All], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.all_)(node, frame, name) def any( - self, node: FunctionExpr[boolean.Any], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[boolean.Any], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.any_)(node, frame, name) def is_finite( - self, node: FunctionExpr[IsFinite], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[IsFinite], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.is_finite)(node, frame, name) def is_nan( - self, node: FunctionExpr[IsNan], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[IsNan], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.is_nan)(node, frame, name) def is_null( - self, node: FunctionExpr[IsNull], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[IsNull], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.is_null)(node, frame, name) - def binary_expr( - self, node: BinaryExpr, frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: + def binary_expr(self, node: BinaryExpr, frame: Frame, name: str) -> StoresNativeT_co: lhs, rhs = ( - self._dispatch(node.left, frame, name), - self._dispatch(node.right, frame, name), + node.left.dispatch(self, frame, name), + node.right.dispatch(self, frame, name), ) result = fn.binary(lhs.native, node.op.__class__, rhs.native) return self._with_native(result, name) def ternary_expr( - self, node: Ternary, frame: ArrowDataFrame, name: str + self, node: TernaryExpr, frame: Frame, name: str ) -> StoresNativeT_co: - when = self._dispatch(node.predicate, frame, name) - then = self._dispatch(node.truthy, frame, name) - otherwise = self._dispatch(node.falsy, frame, name) + when = node.predicate.dispatch(self, frame, name) + then = node.truthy.dispatch(self, frame, name) + otherwise = node.falsy.dispatch(self, frame, name) result = pc.if_else(when.native, then.native, otherwise.native) return self._with_native(result, name) @@ -169,9 +159,9 @@ def ternary_expr( class ArrowExpr( # type: ignore[misc] _ArrowDispatch["ArrowExpr | ArrowScalar"], _StoresNative["ChunkedArrayAny"], - EagerExpr["ArrowDataFrame", ArrowSeries], + EagerExpr["Frame", Series], ): - _evaluated: ArrowSeries + _evaluated: Series _version: Version @property @@ -179,7 +169,7 @@ def name(self) -> str: return self._evaluated.name @classmethod - def from_series(cls, series: ArrowSeries, /) -> Self: + def from_series(cls, series: Series, /) -> Self: obj = cls.__new__(cls) obj._evaluated = series obj._version = series.version @@ -189,26 +179,20 @@ def from_series(cls, series: 
ArrowSeries, /) -> Self: def from_native( cls, native: ChunkedArrayAny, name: str = "", /, version: Version = Version.MAIN ) -> Self: - return cls.from_series(ArrowSeries.from_native(native, name, version=version)) + return cls.from_series(Series.from_native(native, name, version=version)) @overload def _with_native(self, result: ChunkedArrayAny, name: str, /) -> Self: ... @overload - def _with_native(self, result: NativeScalar, name: str, /) -> ArrowScalar: ... + def _with_native(self, result: NativeScalar, name: str, /) -> Scalar: ... @overload - def _with_native( - self, result: ChunkedArrayAny | NativeScalar, name: str, / - ) -> ArrowScalar | Self: ... - def _with_native( - self, result: ChunkedArrayAny | NativeScalar, name: str, / - ) -> ArrowScalar | Self: + def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Self: ... + def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Self: if isinstance(result, pa.Scalar): return ArrowScalar.from_native(result, name, version=self.version) return self.from_native(result, name or self.name, self.version) - def _dispatch_expr( - self, node: ExprIR, frame: ArrowDataFrame, name: str - ) -> ArrowSeries: + def _dispatch_expr(self, node: ExprIR, frame: Frame, name: str) -> Series: """Use instead of `_dispatch` *iff* an operation isn't natively supported on `ChunkedArray`. There is no need to broadcast, as they may have a cheaper impl elsewhere (`CompliantScalar` or `ArrowScalar`). @@ -216,16 +200,16 @@ def _dispatch_expr( Mainly for the benefit of a type checker, but the equivalent `ArrowScalar._dispatch_expr` will raise if the assumption fails. """ - return self._dispatch(node, frame, name).to_series() + return node.dispatch(self, frame, name).to_series() @property def native(self) -> ChunkedArrayAny: return self._evaluated.native - def to_series(self) -> ArrowSeries: + def to_series(self) -> Series: return self._evaluated - def broadcast(self, length: int, /) -> ArrowSeries: + def broadcast(self, length: int, /) -> Series: if (actual_len := len(self)) != length: msg = f"Expected object of length {length}, got {actual_len}." 
raise ShapeError(msg) @@ -234,25 +218,24 @@ def broadcast(self, length: int, /) -> ArrowSeries: def __len__(self) -> int: return len(self._evaluated) - def sort(self, node: expr.Sort, frame: ArrowDataFrame, name: str) -> ArrowExpr: + def sort(self, node: expr.Sort, frame: Frame, name: str) -> Expr: native = self._dispatch_expr(node.expr, frame, name).native sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) return self._with_native(native.take(sorted_indices), name) - def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowExpr: + def sort_by(self, node: expr.SortBy, frame: Frame, name: str) -> Expr: series = self._dispatch_expr(node.expr, frame, name) by = ( self._dispatch_expr(e, frame, f"_{idx}") for idx, e in enumerate(node.by) ) - ns = self.__narwhals_namespace__() - df = ns._concat_horizontal((series, *by)) + df = namespace(self)._concat_horizontal((series, *by)) names = df.columns[1:] indices = pc.sort_indices(df.native, options=node.options.to_arrow(names)) result: ChunkedArrayAny = df.native.column(0).take(indices) return self._with_native(result, name) - def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowExpr: + def filter(self, node: expr.Filter, frame: Frame, name: str) -> Expr: return self._with_native( self._dispatch_expr(node.expr, frame, name).native.filter( self._dispatch_expr(node.by, frame, name).native @@ -260,49 +243,49 @@ def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowEx name, ) - def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def first(self, node: First, frame: Frame, name: str) -> Scalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = native[0] if len(prev) else lit(None, native.type) + result = native[0] if len(prev) else fn.lit(None, native.type) return self._with_native(result, name) - def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def last(self, node: Last, frame: Frame, name: str) -> Scalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = native[height - 1] if (height := len(prev)) else lit(None, native.type) + result = native[len_ - 1] if (len_ := len(prev)) else fn.lit(None, native.type) return self._with_native(result, name) - def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.expr, frame, name).native result = pc.index(native, fn.min_(native)) return self._with_native(result, name) - def arg_max(self, node: ArgMax, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def arg_max(self, node: ArgMax, frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.expr, frame, name).native result: NativeScalar = pc.index(native, fn.max_(native)) return self._with_native(result, name) - def sum(self, node: Sum, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def sum(self, node: Sum, frame: Frame, name: str) -> Scalar: result = fn.sum_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def n_unique(self, node: NUnique, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def n_unique(self, node: NUnique, frame: Frame, name: str) -> Scalar: result = fn.n_unique(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def std(self, node: Std, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def 
std(self, node: Std, frame: Frame, name: str) -> Scalar: result = fn.std( self._dispatch_expr(node.expr, frame, name).native, ddof=node.ddof ) return self._with_native(result, name) - def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def var(self, node: Var, frame: Frame, name: str) -> Scalar: result = fn.var( self._dispatch_expr(node.expr, frame, name).native, ddof=node.ddof ) return self._with_native(result, name) - def quantile(self, node: Quantile, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def quantile(self, node: Quantile, frame: Frame, name: str) -> Scalar: result = fn.quantile( self._dispatch_expr(node.expr, frame, name).native, q=node.quantile, @@ -310,23 +293,23 @@ def quantile(self, node: Quantile, frame: ArrowDataFrame, name: str) -> ArrowSca )[0] return self._with_native(result, name) - def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def count(self, node: Count, frame: Frame, name: str) -> Scalar: result = fn.count(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def max(self, node: Max, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def max(self, node: Max, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.max_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def mean(self, node: Mean, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def mean(self, node: Mean, frame: Frame, name: str) -> Scalar: result = fn.mean(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def median(self, node: Median, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def median(self, node: Median, frame: Frame, name: str) -> Scalar: result = fn.median(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def min(self, node: Min, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.min_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) @@ -336,12 +319,12 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: # - [x] `map_batches` is defined in `EagerExpr`, might be simpler here than on main # - [ ] `rolling_expr` has 4 variants - def over(self, node: WindowExpr, frame: ArrowDataFrame, name: str) -> Self: + def over(self, node: WindowExpr, frame: Frame, name: str) -> Self: raise NotImplementedError def over_ordered( - self, node: OrderedWindowExpr, frame: ArrowDataFrame, name: str - ) -> Self | ArrowScalar: + self, node: OrderedWindowExpr, frame: Frame, name: str + ) -> Self | Scalar: if node.partition_by: msg = f"Need to implement `group_by`, `join` for:\n{node!r}" raise NotImplementedError(msg) @@ -351,7 +334,7 @@ def over_ordered( options = node.sort_options.to_multiple(len(node.order_by)) idx_name = generate_temporary_column_name(8, frame.columns) sorted_context = frame.with_row_index(idx_name).sort(sort_by, options) - evaluated = self._dispatch(node.expr, sorted_context.drop([idx_name]), name) + evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name) if isinstance(evaluated, ArrowScalar): # NOTE: We're already sorted, defer broadcasting to the outer context # Wouldn't be suitable for partitions, but will be fine here @@ -364,28 +347,28 @@ def over_ordered( return self._with_native(result, name) # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` - 
def map_batches(self, node: AnonymousExpr, frame: ArrowDataFrame, name: str) -> Self: + def map_batches(self, node: AnonymousExpr, frame: Frame, name: str) -> Self: if node.is_scalar: - # NOTE: Just trying to avoid redoing the whole API for `ArrowSeries` + # NOTE: Just trying to avoid redoing the whole API for `Series` msg = "Only elementwise is currently supported" raise NotImplementedError(msg) series = self._dispatch_expr(node.input[0], frame, name) udf = node.function.function - result: ArrowSeries | Into1DArray = udf(series) + result: Series | Into1DArray = udf(series) if not fn.is_series(result): - result = ArrowSeries.from_numpy(result, name, version=self.version) + result = Series.from_numpy(result, name, version=self.version) if dtype := node.function.return_dtype: result = result.cast(dtype) return self.from_series(result) - def rolling_expr(self, node: RollingExpr, frame: ArrowDataFrame, name: str) -> Self: + def rolling_expr(self, node: RollingExpr, frame: Frame, name: str) -> Self: raise NotImplementedError class ArrowScalar( _ArrowDispatch["ArrowScalar"], _StoresNative[NativeScalar], - EagerScalar["ArrowDataFrame", ArrowSeries], + EagerScalar["Frame", Series], ): _evaluated: NativeScalar _version: Version @@ -416,14 +399,12 @@ def from_python( version: Version = Version.MAIN, ) -> Self: dtype_pa: pa.DataType | None = None - if dtype: - dtype = into_dtype(dtype) - if not isinstance(dtype, version.dtypes.Unknown): - dtype_pa = narwhals_to_native_dtype(dtype, version) - return cls.from_native(lit(value, dtype_pa), name, version) + if dtype and dtype != version.dtypes.Unknown: + dtype_pa = narwhals_to_native_dtype(dtype, version) + return cls.from_native(fn.lit(value, dtype_pa), name, version) @classmethod - def from_series(cls, series: ArrowSeries) -> Self: + def from_series(cls, series: Series) -> Self: if len(series) == 1: return cls.from_native(series.native[0], series.name, series.version) if len(series) == 0: @@ -433,9 +414,7 @@ def from_series(cls, series: ArrowSeries) -> Self: msg = f"Too long {len(series)!r}" raise InvalidOperationError(msg) - def _dispatch_expr( - self, node: ExprIR, frame: ArrowDataFrame, name: str - ) -> ArrowSeries: + def _dispatch_expr(self, node: ExprIR, frame: Frame, name: str) -> Series: msg = f"Expected unreachable, but hit at: {node!r}" raise InvalidOperationError(msg) @@ -446,13 +425,13 @@ def _with_native(self, native: Any, name: str, /) -> Self: def native(self) -> NativeScalar: return self._evaluated - def to_series(self) -> ArrowSeries: + def to_series(self) -> Series: return self.broadcast(1) def to_python(self) -> PythonLiteral: return self.native.as_py() # type: ignore[no-any-return] - def broadcast(self, length: int) -> ArrowSeries: + def broadcast(self, length: int) -> Series: scalar = self.native if length == 1: chunked = fn.chunked_array(scalar) @@ -461,25 +440,25 @@ def broadcast(self, length: int) -> ArrowSeries: # https://github.com/zen-xu/pyarrow-stubs/pull/209 pa_repeat: Incomplete = pa.repeat chunked = fn.chunked_array(pa_repeat(scalar, length)) - return ArrowSeries.from_native(chunked, self.name, version=self.version) + return Series.from_native(chunked, self.name, version=self.version) - def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(0), name) - def arg_max(self, node: ArgMax, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def arg_max(self, node: ArgMax, frame: Frame, name: str) -> 
Scalar: return self._with_native(pa.scalar(0), name) - def n_unique(self, node: NUnique, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def n_unique(self, node: NUnique, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(1), name) - def std(self, node: Std, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def std(self, node: Std, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(None, pa.null()), name) - def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def var(self, node: Var, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(None, pa.null()), name) - def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: - native = self._dispatch(node.expr, frame, name).native + def count(self, node: Count, frame: Frame, name: str) -> Scalar: + native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) filter = not_implemented() diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index e941727d6b..f7bfaaa330 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -7,9 +7,9 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import narwhals_to_native_dtype +from narwhals._plan._guards import is_tuple_of from narwhals._plan.arrow import functions as fn -from narwhals._plan.arrow.functions import lit -from narwhals._plan.common import collect, is_tuple_of +from narwhals._plan.common import collect from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version @@ -20,79 +20,64 @@ from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan import expr, functions as F - from narwhals._plan.arrow.dataframe import ArrowDataFrame - from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar - from narwhals._plan.arrow.series import ArrowSeries + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame + from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar + from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.boolean import AllHorizontal, AnyHorizontal - from narwhals._plan.dummy import Series + from narwhals._plan.dummy import Series as NwSeries from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.ranges import IntRange - from narwhals._plan.strings import ConcatHorizontal + from narwhals._plan.strings import ConcatStr from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral -class ArrowNamespace( - EagerNamespace["ArrowDataFrame", "ArrowSeries", "ArrowExpr", "ArrowScalar"] -): +class ArrowNamespace(EagerNamespace["Frame", "Series", "Expr", "Scalar"]): def __init__(self, version: Version = Version.MAIN) -> None: self._version = version @property - def _expr(self) -> type[ArrowExpr]: + def _expr(self) -> type[Expr]: from narwhals._plan.arrow.expr import ArrowExpr return ArrowExpr @property - def _scalar(self) -> type[ArrowScalar]: + def _scalar(self) -> type[Scalar]: from narwhals._plan.arrow.expr import ArrowScalar return ArrowScalar @property - def _series(self) -> type[ArrowSeries]: + def _series(self) -> type[Series]: from narwhals._plan.arrow.series import ArrowSeries return ArrowSeries @property - def _dataframe(self) -> type[ArrowDataFrame]: + def _dataframe(self) -> type[Frame]: from narwhals._plan.arrow.dataframe import ArrowDataFrame 
return ArrowDataFrame - def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> ArrowExpr: + def col(self, node: expr.Column, frame: Frame, name: str) -> Expr: return self._expr.from_native( frame.native.column(node.name), name, version=frame.version ) @overload def lit( - self, node: expr.Literal[NonNestedLiteral], frame: ArrowDataFrame, name: str - ) -> ArrowScalar: ... - + self, node: expr.Literal[NonNestedLiteral], frame: Frame, name: str + ) -> Scalar: ... @overload def lit( - self, - node: expr.Literal[Series[ChunkedArrayAny]], - frame: ArrowDataFrame, - name: str, - ) -> ArrowExpr: ... - - @overload - def lit( - self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[ChunkedArrayAny]], - frame: ArrowDataFrame, - name: str, - ) -> ArrowExpr | ArrowScalar: ... - + self, node: expr.Literal[NwSeries[ChunkedArrayAny]], frame: Frame, name: str + ) -> Expr: ... def lit( self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[ChunkedArrayAny]], - frame: ArrowDataFrame, + node: expr.Literal[NonNestedLiteral] | expr.Literal[NwSeries[ChunkedArrayAny]], + frame: Frame, name: str, - ) -> ArrowExpr | ArrowScalar: + ) -> Expr | Scalar: if is_literal_scalar(node): return self._scalar.from_python( node.unwrap(), name, dtype=node.dtype, version=frame.version @@ -106,13 +91,11 @@ def lit( # https://github.com/narwhals-dev/narwhals/pull/2719 def _horizontal_function( self, fn_native: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None - ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], ArrowExpr | ArrowScalar]: - def func( - node: FunctionExpr[Any], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + ) -> Callable[[FunctionExpr[Any], Frame, str], Expr | Scalar]: + def func(node: FunctionExpr[Any], frame: Frame, name: str) -> Expr | Scalar: it = (self._expr.from_ir(e, frame, name).native for e in node.input) if fill is not None: - it = (pc.fill_null(native, lit(fill)) for native in it) + it = (pc.fill_null(native, fn.lit(fill)) for native in it) result = reduce(fn_native, it) if isinstance(result, pa.Scalar): return self._scalar.from_native(result, name, self.version) @@ -121,36 +104,36 @@ def func( return func def any_horizontal( - self, node: FunctionExpr[AnyHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[AnyHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.or_)(node, frame, name) def all_horizontal( - self, node: FunctionExpr[AllHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[AllHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.and_)(node, frame, name) def sum_horizontal( - self, node: FunctionExpr[F.SumHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[F.SumHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.add, fill=0)(node, frame, name) def min_horizontal( - self, node: FunctionExpr[F.MinHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[F.MinHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.min_horizontal)(node, frame, name) def max_horizontal( - self, node: FunctionExpr[F.MaxHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[F.MaxHorizontal], frame: Frame, name: str + ) -> Expr | 
Scalar: return self._horizontal_function(fn.max_horizontal)(node, frame, name) def mean_horizontal( - self, node: FunctionExpr[F.MeanHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[F.MeanHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: int64 = pa.int64() inputs = [self._expr.from_ir(e, frame, name).native for e in node.input] - filled = (pc.fill_null(native, lit(0)) for native in inputs) + filled = (pc.fill_null(native, fn.lit(0)) for native in inputs) # NOTE: `mypy` doesn't like that `add` is overloaded sum_not_null = reduce( fn.add, # type: ignore[arg-type] @@ -162,8 +145,8 @@ def mean_horizontal( return self._expr.from_native(result, name, self.version) def concat_str( - self, node: FunctionExpr[ConcatHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[ConcatStr], frame: Frame, name: str + ) -> Expr | Scalar: exprs = (self._expr.from_ir(e, frame, name) for e in node.input) aligned = (ser.native for ser in self._expr.align(exprs)) separator = node.function.separator @@ -173,9 +156,7 @@ def concat_str( return self._scalar.from_native(result, name, self.version) return self._expr.from_native(result, name, self.version) - def int_range( - self, node: RangeExpr[IntRange], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: + def int_range(self, node: RangeExpr[IntRange], frame: Frame, name: str) -> Expr: start_: PythonLiteral end_: PythonLiteral start, end = node.function.unwrap_input(node) @@ -209,21 +190,12 @@ def int_range( raise InvalidOperationError(msg) @overload - def concat( - self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod - ) -> ArrowDataFrame: ... - + def concat(self, items: Iterable[Frame], *, how: ConcatMethod) -> Frame: ... @overload + def concat(self, items: Iterable[Series], *, how: Literal["vertical"]) -> Series: ... def concat( - self, items: Iterable[ArrowSeries], *, how: Literal["vertical"] - ) -> ArrowSeries: ... 
- - def concat( - self, - items: Iterable[ArrowDataFrame] | Iterable[ArrowSeries], - *, - how: ConcatMethod, - ) -> ArrowDataFrame | ArrowSeries: + self, items: Iterable[Frame | Series], *, how: ConcatMethod + ) -> Frame | Series: if how == "vertical": return self._concat_vertical(items) if how == "horizontal": @@ -232,20 +204,16 @@ def concat( first = next(it) if self._is_series(first): raise TypeError(first) - dfs = cast("Sequence[ArrowDataFrame]", (first, *it)) + dfs = cast("Sequence[Frame]", (first, *it)) return self._concat_diagonal(dfs) - def _concat_diagonal(self, items: Iterable[ArrowDataFrame]) -> ArrowDataFrame: + def _concat_diagonal(self, items: Iterable[Frame]) -> Frame: return self._dataframe.from_native( fn.concat_vertical_table(df.native for df in items), self.version ) - def _concat_horizontal( - self, items: Iterable[ArrowDataFrame | ArrowSeries] - ) -> ArrowDataFrame: - def gen( - objs: Iterable[ArrowDataFrame | ArrowSeries], - ) -> Iterator[tuple[ChunkedArrayAny, str]]: + def _concat_horizontal(self, items: Iterable[Frame | Series]) -> Frame: + def gen(objs: Iterable[Frame | Series]) -> Iterator[tuple[ChunkedArrayAny, str]]: for item in objs: if self._is_series(item): yield item.native, item.name @@ -256,9 +224,7 @@ def gen( native = pa.Table.from_arrays(arrays, list(names)) return self._dataframe.from_native(native, self.version) - def _concat_vertical( - self, items: Iterable[ArrowDataFrame] | Iterable[ArrowSeries] - ) -> ArrowDataFrame | ArrowSeries: + def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series: collected = collect(items) if is_tuple_of(collected, self._series): sers = collected diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 7a3902cb6b..23f7d27dd3 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -4,8 +4,8 @@ # - Any import typing as t -from narwhals._plan.common import Function -from narwhals._plan.options import FunctionOptions +from narwhals._plan.common import Function, HorizontalFunction +from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: @@ -21,22 +21,22 @@ ExprT = TypeVar("ExprT", bound="ExprIR", default="ExprIR") -class BooleanFunction(Function): ... - - +# fmt: off +class BooleanFunction(Function, options=FunctionOptions.elementwise): ... class All(BooleanFunction, options=FunctionOptions.aggregation): ... - - -class AllHorizontal(BooleanFunction, options=FunctionOptions.horizontal): ... - - +class AllHorizontal(HorizontalFunction, BooleanFunction): ... class Any(BooleanFunction, options=FunctionOptions.aggregation): ... - - -class AnyHorizontal(BooleanFunction, options=FunctionOptions.horizontal): ... - - -class IsBetween(BooleanFunction, options=FunctionOptions.elementwise): +class AnyHorizontal(HorizontalFunction, BooleanFunction): ... +class IsDuplicated(BooleanFunction, options=FunctionOptions.length_preserving): ... +class IsFinite(BooleanFunction): ... +class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... +class IsLastDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... +class IsNan(BooleanFunction): ... +class IsNull(BooleanFunction): ... +class IsUnique(BooleanFunction, options=FunctionOptions.length_preserving): ... +class Not(BooleanFunction, config=FEOptions.renamed("not_")): ... 
+# fmt: on +class IsBetween(BooleanFunction): """N-ary (expr, lower_bound, upper_bound).""" __slots__ = ("closed",) @@ -47,16 +47,7 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, Exp return expr, lower_bound, upper_bound -class IsDuplicated(BooleanFunction, options=FunctionOptions.length_preserving): ... - - -class IsFinite(BooleanFunction, options=FunctionOptions.elementwise): ... - - -class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... - - -class IsIn(BooleanFunction, t.Generic[OtherT], options=FunctionOptions.elementwise): +class IsIn(BooleanFunction, t.Generic[OtherT]): __slots__ = ("other",) other: OtherT @@ -90,19 +81,3 @@ def __init__(self, *, other: ExprT) -> None: "You should provide an iterable instead." ) raise NotImplementedError(msg) - - -class IsLastDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... - - -class IsNan(BooleanFunction, options=FunctionOptions.elementwise): ... - - -class IsNull(BooleanFunction, options=FunctionOptions.elementwise): ... - - -class IsUnique(BooleanFunction, options=FunctionOptions.length_preserving): ... - - -class Not(BooleanFunction, options=FunctionOptions.elementwise): - """`__invert__`.""" diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 7fb58367f9..13791bed16 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -1,23 +1,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from narwhals._plan.common import ExprNamespace, Function, IRNamespace -from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr +# fmt: off class CategoricalFunction(Function, accessor="cat"): ... - - -class GetCategories(CategoricalFunction, options=FunctionOptions.groupwise): ... - - +class GetCategories(CategoricalFunction): ... 
+# fmt: on class IRCatNamespace(IRNamespace): - def get_categories(self) -> GetCategories: - return GetCategories() + get_categories: ClassVar = GetCategories class ExprCatNamespace(ExprNamespace[IRCatNamespace]): diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 3e77d493d0..f73f21b26b 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -5,73 +5,40 @@ import sys from collections.abc import Iterable from decimal import Decimal +from operator import attrgetter from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, overload +from narwhals._plan._guards import is_function_expr, is_iterable_reject, is_literal +from narwhals._plan._immutable import Immutable +from narwhals._plan.options import ExprIROptions, FEOptions, FunctionOptions from narwhals._plan.typing import ( Accessor, DTypeT, ExprIRT, ExprIRT2, + FunctionT, IRNamespaceT, MapIR, NamedOrExprIRT, - NativeSeriesT, NonNestedDTypeT, + OneOrIterable, Seq, ) -from narwhals._utils import _hasattr_static from narwhals.dtypes import DType from narwhals.utils import Version if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any, Callable, Literal - - from typing_extensions import Never, Self, TypeIs, dataclass_transform - - from narwhals._plan import expr - from narwhals._plan.dummy import Expr, Selector, Series - from narwhals._plan.expr import ( - AggExpr, - Alias, - BinaryExpr, - Cast, - Column, - FunctionExpr, - WindowExpr, - ) + from typing import Any, Callable + + from typing_extensions import Self, TypeAlias + + from narwhals._plan.dummy import Expr, Selector + from narwhals._plan.expr import Alias, Cast, Column, FunctionExpr from narwhals._plan.meta import IRMetaNamespace - from narwhals._plan.options import FunctionOptions - from narwhals._plan.protocols import CompliantSeries + from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals.typing import NonNestedDType, NonNestedLiteral -else: - # NOTE: This isn't important to the proposal, just wanted IDE support - # for the **temporary** constructors. - # It is interesting how much boilerplate this avoids though 🤔 - # https://docs.python.org/3/library/typing.html#typing.dataclass_transform - def dataclass_transform( - *, - eq_default: bool = True, - order_default: bool = False, - kw_only_default: bool = False, - frozen_default: bool = False, - field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), - **kwargs: Any, - ) -> Callable[[T], T]: - def decorator(cls_or_fn: T) -> T: - cls_or_fn.__dataclass_transform__ = { - "eq_default": eq_default, - "order_default": order_default, - "kw_only_default": kw_only_default, - "frozen_default": frozen_default, - "field_specifiers": field_specifiers, - "kwargs": kwargs, - } - return cls_or_fn - - return decorator - if sys.version_info >= (3, 13): from copy import replace as replace # noqa: PLC0414 @@ -87,127 +54,109 @@ def replace(obj: T, /, **changes: Any) -> T: T = TypeVar("T") +Incomplete: TypeAlias = "Any" -_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" +def _pascal_to_snake_case(s: str) -> str: + """Convert a PascalCase, camelCase string to snake_case. 
-@dataclass_transform(kw_only_default=True, frozen_default=True) -class Immutable: - __slots__ = (_IMMUTABLE_HASH_NAME,) - __immutable_hash_value__: int + Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 + """ + # Handle the sequence of uppercase letters followed by a lowercase letter + snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) + # Insert an underscore between a lowercase letter and an uppercase letter + return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - @property - def __immutable_keys__(self) -> Iterator[str]: - slots: tuple[str, ...] = self.__slots__ - for name in slots: - if name != _IMMUTABLE_HASH_NAME: - yield name - @property - def __immutable_values__(self) -> Iterator[Any]: - for name in self.__immutable_keys__: - yield getattr(self, name) +_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") +_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - @property - def __immutable_items__(self) -> Iterator[tuple[str, Any]]: - for name in self.__immutable_keys__: - yield name, getattr(self, name) - @property - def __immutable_hash__(self) -> int: - if hasattr(self, _IMMUTABLE_HASH_NAME): - return self.__immutable_hash_value__ - hash_value = hash((self.__class__, *self.__immutable_values__)) - object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) - return self.__immutable_hash_value__ - - def __setattr__(self, name: str, value: Never) -> Never: - msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." - raise AttributeError(msg) - - def __replace__(self, **changes: Any) -> Self: - """https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415 - if len(changes) == 1: - k_new, v_new = next(iter(changes.items())) - # NOTE: Will trigger an attribute error if invalid name - if getattr(self, k_new) == v_new: - return self - changed = dict(self.__immutable_items__) - # Now we *don't* need to check the key is valid - changed[k_new] = v_new - else: - changed = dict(self.__immutable_items__) - changed |= changes - return type(self)(**changed) +def _re_repl_snake(match: re.Match[str], /) -> str: + return f"{match.group(1)}_{match.group(2)}" - def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: - super().__init_subclass__(*args, **kwds) - if cls.__slots__: - ... - else: - cls.__slots__ = () - - def __hash__(self) -> int: - return self.__immutable_hash__ - - def __eq__(self, other: object) -> bool: - if self is other: - return True - if type(self) is not type(other): - return False - return all( - getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ - ) - - def __str__(self) -> str: - # NOTE: Debug repr, closer to constructor - fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__) - return f"{type(self).__name__}({fields})" - - def __init__(self, **kwds: Any) -> None: - # NOTE: DUMMY CONSTRUCTOR - don't use beyond prototyping! - # Just need a quick way to demonstrate `ExprIR` and interactions - required: set[str] = set(self.__immutable_keys__) - if not required and not kwds: - # NOTE: Fastpath for empty slots - ... 
- elif required == set(kwds): - # NOTE: Everything is as expected - for name, value in kwds.items(): - object.__setattr__(self, name, value) - elif missing := required.difference(kwds): - msg = ( - f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n" - f"but missing values for {sorted(missing)!r}" - ) - raise TypeError(msg) - else: - extra = set(kwds).difference(required) + +def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: + config = tp.__expr_ir_config__ + name = config.override_name or _pascal_to_snake_case(tp.__name__) + return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name + + +def _dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: + getter = attrgetter(_dispatch_method_name(tp)) + if tp.__expr_ir_config__.origin == "expr": + return getter + return lambda ctx: getter(ctx.__narwhals_namespace__()) + + +def _dispatch_generate( + tp: type[ExprIRT], / +) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: + if not tp.__expr_ir_config__.allow_dispatch: + + def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: msg = ( - f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n" - f"but got unknown arguments {sorted(extra)!r}" + f"{tp.__name__!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" ) raise TypeError(msg) + return _ + getter = _dispatch_getter(tp) + + def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ + -def _field_str(name: str, value: Any) -> str: - if isinstance(value, tuple): - inner = ", ".join(f"{v}" for v in value) - return f"{name}=[{inner}]" - if isinstance(value, str): - return f"{name}={value!r}" - return f"{name}={value}" +def _dispatch_generate_function( + tp: type[FunctionT], / +) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: + getter = _dispatch_getter(tp) + + def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" + _child: ClassVar[Seq[str]] = () + """Nested node names, in iteration order.""" + + __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] + ] + + def __init_subclass__( + cls: type[Self], + *args: Any, + child: Seq[str] = (), + config: ExprIROptions | None = None, + **kwds: Any, + ) -> None: + super().__init_subclass__(*args, **kwds) + if child: + cls._child = child + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) + + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / + ) -> R_co: + """Evaluate expression in `frame`, using `ctx` for implementation(s).""" + return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] + def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import dummy - if version is Version.MAIN: - return dummy.Expr._from_ir(self) - return dummy.ExprV1._from_ir(self) + tp = dummy.Expr if version is Version.MAIN else dummy.ExprV1 + return tp._from_ir(self) @property def is_scalar(self) -> bool: @@ -221,8 +170,11 @@ def map_ir(self, function: MapIR, /) -> ExprIR: 
[`polars_plan::plans::iterator::Expr.map_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/iterator.rs#L152-L159 [`polars_plan::plans::visitor::visitors`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/visitor/visitors.rs """ - msg = f"Need to handle recursive visiting first for {type(self).__qualname__!r}!\n\n{self!r}" - raise NotImplementedError(msg) + if not self._child: + return function(self) + children = ((name, getattr(self, name)) for name in self._child) + changed = {name: _map_ir_child(child, function) for name, child in children} + return function(replace(self, **changed)) def iter_left(self) -> Iterator[ExprIR]: """Yield nodes root->leaf. @@ -247,6 +199,13 @@ def iter_left(self) -> Iterator[ExprIR]: >>> list(d._ir.iter_left()) [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c'), col('e'), col('f'), col('a').alias('b').min().alias('c').over([col('e'), col('f')])] """ + for name in self._child: + child: ExprIR | Seq[ExprIR] = getattr(self, name) + if isinstance(child, ExprIR): + yield from child.iter_left() + else: + for node in child: + yield from node.iter_left() yield self def iter_right(self) -> Iterator[ExprIR]: @@ -276,6 +235,13 @@ def iter_right(self) -> Iterator[ExprIR]: [col('a').alias('b').min().alias('c').over([col('e'), col('f')]), col('f'), col('e'), col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] """ yield self + for name in reversed(self._child): + child: ExprIR | Seq[ExprIR] = getattr(self, name) + if isinstance(child, ExprIR): + yield from child.iter_right() + else: + for node in reversed(child): + yield from node.iter_right() def iter_root_names(self) -> Iterator[ExprIR]: """Override for different iteration behavior in `ExprIR.meta.root_names`. @@ -313,7 +279,7 @@ def _repr_html_(self) -> str: return self.__repr__() -class SelectorIR(ExprIR): +class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): def to_narwhals(self, version: Version = Version.MAIN) -> Selector: from narwhals._plan import dummy @@ -366,12 +332,9 @@ def from_ir(expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: """ return NamedIR(expr=expr, name=expr.meta.output_name(raise_if_undetermined=True)) - def map_ir(self, function: MapIR, /) -> NamedIR[ExprIR]: + def map_ir(self, function: MapIR, /) -> Self: """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" - return self.with_expr(function(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: - return cast("NamedIR[ExprIRT2]", replace(self, expr=expr)) + return replace(self, expr=function(self.expr.map_ir(function))) def __repr__(self) -> str: return f"{self.name}={self.expr!r}" @@ -397,7 +360,7 @@ def is_elementwise_top_level(self) -> bool: return ir.options.is_elementwise() if is_literal(ir): return ir.is_scalar - return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.Ternary, expr.Cast)) + return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) class IRNamespace(Immutable): @@ -428,26 +391,19 @@ def _with_unary(self, function: Function, /) -> Expr: return self._expr._with_unary(function) -def _function_options_default() -> FunctionOptions: - from narwhals._plan.options import FunctionOptions - - return FunctionOptions.default() - - class Function(Immutable): """Shared by expr functions and namespace functions. 
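The `_child`-driven traversal and rewriting above can be mirrored outside Narwhals with plain frozen dataclasses. This sketch only handles single-node children (the real `_map_ir_child` also accepts sequences) and uses hypothetical `Col`/`Alias` stand-ins:

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Node:
    _child = ()  # names of nested-node fields, in iteration order

    def map_ir(self, fn):
        if not self._child:
            return fn(self)
        changed = {n: getattr(self, n).map_ir(fn) for n in self._child}
        return fn(replace(self, **changed))

    def iter_left(self):
        for n in self._child:
            yield from getattr(self, n).iter_left()
        yield self


@dataclass(frozen=True)
class Col(Node):
    name: str


@dataclass(frozen=True)
class Alias(Node):
    expr: Node
    name: str
    _child = ("expr",)


tree = Alias(Col("a"), "b")
assert [type(n).__name__ for n in tree.iter_left()] == ["Col", "Alias"]
assert tree.map_ir(lambda n: replace(n, name=n.name.upper())) == Alias(Col("A"), "B")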
- Only valid in `FunctionExpr.function` - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 """ - _accessor: ClassVar[Accessor | None] = None - """Namespace accessor name, if any.""" - _function_options: ClassVar[staticmethod[[], FunctionOptions]] = staticmethod( - _function_options_default + FunctionOptions.default ) + __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] + ] @property def function_options(self) -> FunctionOptions: @@ -463,136 +419,29 @@ def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: return FunctionExpr(input=inputs, function=self, options=self.function_options) def __init_subclass__( - cls, + cls: type[Self], *args: Any, accessor: Accessor | None = None, options: Callable[[], FunctionOptions] | None = None, + config: FEOptions | None = None, **kwds: Any, ) -> None: super().__init_subclass__(*args, **kwds) if accessor: - cls._accessor = accessor + config = replace(config or FEOptions.default(), accessor_name=accessor) if options: cls._function_options = staticmethod(options) + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) def __repr__(self) -> str: - return _function_repr(type(self)) - - -# TODO @dangotbanned: Add caching strategy? -def _function_repr(tp: type[Function], /) -> str: - name = _pascal_to_snake_case(tp.__name__) - return f"{ns_name}.{name}" if (ns_name := tp._accessor) else name - - -def _pascal_to_snake_case(s: str) -> str: - """Convert a PascalCase, camelCase string to snake_case. - - Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 - """ - # Handle the sequence of uppercase letters followed by a lowercase letter - snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) - # Insert an underscore between a lowercase letter and an uppercase letter - return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - - -_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") -_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - - -def _re_repl_snake(match: re.Match[str], /) -> str: - return f"{match.group(1)}_{match.group(2)}" - - -_NON_NESTED_LITERAL_TPS = ( - int, - float, - str, - dt.date, - dt.time, - dt.timedelta, - bytes, - Decimal, -) - - -def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: - return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) - - -def is_expr(obj: Any) -> TypeIs[Expr]: - from narwhals._plan.dummy import Expr - - return isinstance(obj, Expr) - - -def is_column(obj: Any) -> TypeIs[Expr]: - """Indicate if the given object is a basic/unaliased column. - - https://github.com/pola-rs/polars/blob/a3d6a3a7863b4d42e720a05df69ff6b6f5fc551f/py-polars/polars/_utils/various.py#L164-L168. 
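The class-keyword configuration consumed by `Function.__init_subclass__` below is plain Python. A minimal sketch, with strings standing in for the real `FunctionOptions`/`FEOptions` values:

class Function:
    _accessor = None
    _options = "default"

    def __init_subclass__(cls, *, accessor=None, options=None, **kwds):
        super().__init_subclass__(**kwds)
        if accessor is not None:
            cls._accessor = accessor
        if options is not None:
            cls._options = options


class LenChars(Function, accessor="str", options="elementwise"): ...


assert (LenChars._accessor, LenChars._options) == ("str", "elementwise")
assert Function._accessor is None  # base-class defaults stay untouched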
- """ - return is_expr(obj) and obj.meta.is_column() + return _dispatch_method_name(type(self)) -def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: - from narwhals._plan.dummy import Series - - return isinstance(obj, Series) - - -def is_compliant_series( - obj: CompliantSeries[NativeSeriesT] | Any, -) -> TypeIs[CompliantSeries[NativeSeriesT]]: - return _hasattr_static(obj, "__narwhals_series__") - - -def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]: - from narwhals._plan.dummy import Series - - return isinstance(obj, (str, bytes, Series)) or is_compliant_series(obj) - - -def is_window_expr(obj: Any) -> TypeIs[WindowExpr]: - from narwhals._plan.expr import WindowExpr - - return isinstance(obj, WindowExpr) - - -def is_function_expr(obj: Any) -> TypeIs[FunctionExpr[Any]]: - from narwhals._plan.expr import FunctionExpr - - return isinstance(obj, FunctionExpr) - - -def is_binary_expr(obj: Any) -> TypeIs[BinaryExpr]: - from narwhals._plan.expr import BinaryExpr - - return isinstance(obj, BinaryExpr) - - -def is_agg_expr(obj: Any) -> TypeIs[AggExpr]: - from narwhals._plan.expr import AggExpr - - return isinstance(obj, AggExpr) - - -def is_aggregation(obj: Any) -> TypeIs[AggExpr | FunctionExpr[Any]]: - """Superset of `ExprIR.is_scalar`, excludes literals & len.""" - return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar) - - -def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]: - from narwhals._plan import expr - - return isinstance(obj, expr.Literal) - - -def is_horizontal_reduction(obj: FunctionExpr[Any] | Any) -> TypeIs[FunctionExpr[Any]]: - return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() - - -def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: - return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) +class HorizontalFunction( + Function, options=FunctionOptions.horizontal, config=FEOptions.namespaced() +): ... def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: @@ -641,10 +490,14 @@ def map_ir( return origin.map_ir(function) +def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: + return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) + + # TODO @dangotbanned: Review again and try to work around (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021) # The issue is `T` possibly being `Iterable` # Ignoring here still leaks the issue to the caller, where you need to annotate the base case -def flatten_hash_safe(iterable: Iterable[T | Iterable[T]], /) -> Iterator[T]: +def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: """Fully unwrap all levels of nesting. Aiming to reduce the chances of passing an unhashable argument. 
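The guards removed here live on in `narwhals._plan._guards` (see the import changes in the next hunks); the pattern itself is just `TypeIs` plus `isinstance`. A sketch with a stand-in class rather than the real `FunctionExpr`:

from typing import Any

from typing_extensions import TypeIs


class FunctionExpr:  # hypothetical stand-in for narwhals._plan.expr.FunctionExpr
    is_scalar = False


def is_function_expr(obj: Any) -> TypeIs[FunctionExpr]:
    return isinstance(obj, FunctionExpr)


def is_aggregation(obj: Any) -> bool:
    # After the guard, type checkers narrow `obj` to `FunctionExpr`,
    # so the attribute access below needs no cast.
    return is_function_expr(obj) and obj.is_scalar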
diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 85b8ac2ad7..ab89b97c96 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,17 +3,12 @@ import builtins import typing as t -from narwhals._plan import boolean, expr, expr_parsing as parse, functions as F -from narwhals._plan.common import ( - into_dtype, - is_non_nested_literal, - is_series, - py_to_narwhals_dtype, -) +from narwhals._plan import _guards, boolean, expr, expr_parsing as parse, functions as F +from narwhals._plan.common import into_dtype, py_to_narwhals_dtype from narwhals._plan.expr import All, Len from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.ranges import IntRange -from narwhals._plan.strings import ConcatHorizontal +from narwhals._plan.strings import ConcatStr from narwhals._plan.when_then import When from narwhals._utils import Version, flatten @@ -39,9 +34,9 @@ def nth(*indices: int | t.Sequence[int]) -> Expr: def lit( value: NonNestedLiteral | Series[NativeSeriesT], dtype: IntoDType | None = None ) -> Expr: - if is_series(value): + if _guards.is_series(value): return SeriesLiteral(value=value).to_literal().to_narwhals() - if not is_non_nested_literal(value): + if not _guards.is_non_nested_literal(value): msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." raise TypeError(msg) if dtype is None: @@ -121,7 +116,7 @@ def concat_str( ) -> Expr: it = parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) return ( - ConcatHorizontal(separator=separator, ignore_nulls=ignore_nulls) + ConcatStr(separator=separator, ignore_nulls=ignore_nulls) .to_function_expr(*it) .to_narwhals() ) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 6ef174ec06..0a1e469917 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -3,8 +3,8 @@ from __future__ import annotations import math -import typing as t -from typing import TYPE_CHECKING, Generic +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload from narwhals._plan import ( aggregation as agg, @@ -15,7 +15,8 @@ functions as F, operators as ops, ) -from narwhals._plan.common import NamedIR, into_dtype, is_column, is_expr, is_series +from narwhals._plan._guards import is_column, is_expr, is_series +from narwhals._plan.common import into_dtype from narwhals._plan.contexts import ExprContext from narwhals._plan.options import ( EWMOptions, @@ -33,13 +34,11 @@ from narwhals.schema import Schema if TYPE_CHECKING: - from collections.abc import Iterable, Sequence - import pyarrow as pa from typing_extensions import Never, Self from narwhals._plan.categorical import ExprCatNamespace - from narwhals._plan.common import ExprIR, Function + from narwhals._plan.common import ExprIR, Function, NamedIR from narwhals._plan.lists import ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace @@ -52,7 +51,7 @@ from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace from narwhals._plan.temporal import ExprDateTimeNamespace - from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq, Udf + from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf from narwhals.dtypes import DType from narwhals.typing import ( ClosedInterval, @@ -69,10 +68,10 @@ # NOTE: Trying to keep consistent logic between `DataFrame.sort` and `Expr.sort_by` def _parse_sort_by( - by: 
IntoExpr | Iterable[IntoExpr] = (), + by: OneOrIterable[IntoExpr] = (), *more_by: IntoExpr, - descending: bool | t.Iterable[bool] = False, - nulls_last: bool | t.Iterable[bool] = False, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, ) -> tuple[Seq[ExprIR], SortMultipleOptions]: sort_by = parse.parse_into_seq_of_expr_ir(by, *more_by) if length_changing := next((e for e in sort_by if e.is_scalar), None): @@ -86,7 +85,7 @@ def _parse_sort_by( # Entirely ignoring namespace + function binding class Expr: _ir: ExprIR - _version: t.ClassVar[Version] = Version.MAIN + _version: ClassVar[Version] = Version.MAIN def __repr__(self) -> str: return f"nw._plan.Expr({self.version.name.lower()}):\n{self._ir!r}" @@ -114,7 +113,7 @@ def alias(self, name: str) -> Self: def cast(self, dtype: IntoDType) -> Self: return self._from_ir(self._ir.cast(into_dtype(dtype))) - def exclude(self, *names: str | t.Iterable[str]) -> Self: + def exclude(self, *names: OneOrIterable[str]) -> Self: return self._from_ir(expr.Exclude.from_names(self._ir, *names)) def count(self) -> Self: @@ -165,8 +164,8 @@ def quantile( def over( self, - *partition_by: IntoExpr | t.Iterable[IntoExpr], - order_by: IntoExpr | t.Iterable[IntoExpr] = None, + *partition_by: OneOrIterable[IntoExpr], + order_by: OneOrIterable[IntoExpr] = None, descending: bool = False, nulls_last: bool = False, ) -> Self: @@ -191,10 +190,10 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: def sort_by( self, - by: IntoExpr | t.Iterable[IntoExpr], + by: OneOrIterable[IntoExpr], *more_by: IntoExpr, - descending: bool | t.Iterable[bool] = False, - nulls_last: bool | t.Iterable[bool] = False, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, ) -> Self: keys, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last @@ -202,9 +201,7 @@ def sort_by( return self._from_ir(expr.SortBy(expr=self._ir, by=keys, options=opts)) def filter( - self, - *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], - **constraints: t.Any, + self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any ) -> Self: by = parse.parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return self._from_ir(expr.Filter(expr=self._ir, by=by)) @@ -217,7 +214,7 @@ def abs(self) -> Self: def hist( self, - bins: t.Sequence[float] | None = None, + bins: Sequence[float] | None = None, *, bin_count: int | None = None, include_breakpoint: bool = True, @@ -371,20 +368,20 @@ def ewm_mean( def replace_strict( self, - old: t.Sequence[t.Any] | t.Mapping[t.Any, t.Any], - new: t.Sequence[t.Any] | None = None, + old: Sequence[Any] | Mapping[Any, Any], + new: Sequence[Any] | None = None, *, return_dtype: IntoDType | None = None, ) -> Self: - before: Seq[t.Any] - after: Seq[t.Any] + before: Seq[Any] + after: Seq[Any] if new is None: - if not isinstance(old, t.Mapping): + if not isinstance(old, Mapping): msg = "`new` argument is required if `old` argument is not a Mapping type" raise TypeError(msg) before = tuple(old) after = tuple(old.values()) - elif isinstance(old, t.Mapping): + elif isinstance(old, Mapping): msg = "`new` argument cannot be used if `old` argument is a Mapping type" raise TypeError(msg) else: @@ -455,10 +452,10 @@ def is_between( boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) ) - def is_in(self, other: t.Iterable[t.Any]) -> Self: + def is_in(self, other: Iterable[Any]) -> Self: if is_series(other): return 
self._with_unary(boolean.IsInSeries.from_series(other)) - if isinstance(other, t.Iterable): + if isinstance(other, Iterable): return self._with_unary(boolean.IsInSeq.from_iterable(other)) if is_expr(other): return self._with_unary(boolean.IsInExpr(other=other._ir)) @@ -627,9 +624,9 @@ def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] def _to_expr(self) -> Expr: return self._ir.to_narwhals(self.version) - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __or__(self, other: Self) -> Self: ... - @t.overload + @overload def __or__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __or__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if isinstance(other, type(self)): @@ -637,9 +634,9 @@ def __or__(self, other: IntoExprColumn | int | bool) -> Self | Expr: return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() | other - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __and__(self, other: Self) -> Self: ... - @t.overload + @overload def __and__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __and__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): @@ -649,9 +646,9 @@ def __and__(self, other: IntoExprColumn | int | bool) -> Self | Expr: return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() & other - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __sub__(self, other: Self) -> Self: ... - @t.overload + @overload def __sub__(self, other: IntoExpr) -> Expr: ... def __sub__(self, other: IntoExpr) -> Self | Expr: if isinstance(other, type(self)): @@ -659,9 +656,9 @@ def __sub__(self, other: IntoExpr) -> Self | Expr: return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() - other - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __xor__(self, other: Self) -> Self: ... - @t.overload + @overload def __xor__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if isinstance(other, type(self)): @@ -672,41 +669,41 @@ def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: def __invert__(self) -> Self: return self._from_ir(expr.InvertSelector(selector=self._ir)) - def __add__(self, other: t.Any) -> Expr: # type: ignore[override] + def __add__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, type(self)): msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" raise TypeError(msg) return self._to_expr() + other # type: ignore[no-any-return] - def __radd__(self, other: t.Any) -> Never: + def __radd__(self, other: Any) -> Never: msg = "unsupported operand type(s) for op: ('Expr' + 'Selector')" raise TypeError(msg) - def __rsub__(self, other: t.Any) -> Never: + def __rsub__(self, other: Any) -> Never: msg = "unsupported operand type(s) for op: ('Expr' - 'Selector')" raise TypeError(msg) - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __rand__(self, other: Self) -> Self: ... - @t.overload + @overload def __rand__(self, other: IntoExprColumn | int | bool) -> Expr: ... 
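Several signatures in this hunk move from `X | Iterable[X]` spellings to the `OneOrIterable` alias. The broadcasting shown below (a lone flag applying to every sort key) is an assumption about how `descending`/`nulls_last` end up being consumed, not code from this PR:

from typing import Iterable, TypeVar, Union

T = TypeVar("T")
OneOrIterable = Union[T, Iterable[T]]  # assumed shape of the alias in `_plan.typing`


def broadcast_bools(value: OneOrIterable[bool], n_keys: int) -> tuple[bool, ...]:
    # A lone bool applies to every sort key; an iterable must already match.
    if isinstance(value, bool):
        return (value,) * n_keys
    out = tuple(value)
    if len(out) != n_keys:
        raise ValueError(f"expected {n_keys} values, got {len(out)}")
    return out


assert broadcast_bools(True, 3) == (True, True, True)
assert broadcast_bools([True, False], 2) == (True, False)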
def __rand__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) & self return self._to_expr().__rand__(other) - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __ror__(self, other: Self) -> Self: ... - @t.overload + @overload def __ror__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __ror__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) | self return self._to_expr().__ror__(other) - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __rxor__(self, other: Self) -> Self: ... - @t.overload + @overload def __rxor__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): @@ -715,16 +712,16 @@ def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: class ExprV1(Expr): - _version: t.ClassVar[Version] = Version.V1 + _version: ClassVar[Version] = Version.V1 class SelectorV1(Selector): - _version: t.ClassVar[Version] = Version.V1 + _version: ClassVar[Version] = Version.V1 class BaseFrame(Generic[NativeFrameT]): - _compliant: CompliantBaseFrame[t.Any, NativeFrameT] - _version: t.ClassVar[Version] = Version.MAIN + _compliant: CompliantBaseFrame[Any, NativeFrameT] + _version: ClassVar[Version] = Version.MAIN @property def version(self) -> Version: @@ -742,13 +739,11 @@ def __repr__(self) -> str: # pragma: no cover return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) @classmethod - def from_native(cls, native: t.Any, /) -> Self: + def from_native(cls, native: Any, /) -> Self: raise NotImplementedError @classmethod - def _from_compliant( - cls, compliant: CompliantBaseFrame[t.Any, NativeFrameT], / - ) -> Self: + def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> Self: obj = cls.__new__(cls) obj._compliant = compliant return obj @@ -758,8 +753,8 @@ def to_native(self) -> NativeFrameT: def _project( self, - exprs: tuple[IntoExpr | Iterable[IntoExpr], ...], - named_exprs: dict[str, t.Any], + exprs: tuple[OneOrIterable[IntoExpr], ...], + named_exprs: dict[str, Any], context: ExprContext, /, ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: @@ -770,15 +765,13 @@ def _project( named_irs = expr_expansion.into_named_irs(irs, output_names) return schema_frozen.project(named_irs, context) - def select(self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any) -> Self: + def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema_projected = self._project( exprs, named_exprs, ExprContext.SELECT ) return self._from_compliant(self._compliant.select(named_irs)) - def with_columns( - self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any - ) -> Self: + def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema_projected = self._project( exprs, named_exprs, ExprContext.WITH_COLUMNS ) @@ -786,10 +779,10 @@ def with_columns( def sort( self, - by: str | Iterable[str], + by: OneOrIterable[str], *more_by: str, - descending: bool | Sequence[bool] = False, - nulls_last: bool | Sequence[bool] = False, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, ) -> Self: sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last @@ 
-802,7 +795,7 @@ def sort( class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): - _compliant: CompliantDataFrame[t.Any, NativeDataFrameT, NativeSeriesT] + _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] @property def _series(self) -> type[Series[NativeSeriesT]]: @@ -812,7 +805,7 @@ def _series(self) -> type[Series[NativeSeriesT]]: @classmethod def from_native( # type: ignore[override] cls, native: NativeFrame, / - ) -> DataFrame[pa.Table, pa.ChunkedArray[t.Any]]: + ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: if is_pyarrow_table(native): from narwhals._plan.arrow.dataframe import ArrowDataFrame @@ -820,22 +813,19 @@ def from_native( # type: ignore[override] raise NotImplementedError(type(native)) - @t.overload + @overload def to_dict( - self, *, as_series: t.Literal[True] = ... + self, *, as_series: Literal[True] = ... ) -> dict[str, Series[NativeSeriesT]]: ... - - @t.overload - def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... - - @t.overload + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload def to_dict( self, *, as_series: bool - ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[t.Any]]: ... - + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: ... def to_dict( self, *, as_series: bool = True - ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[t.Any]]: + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: if as_series: return { key: self._series._from_compliant(value) @@ -849,7 +839,7 @@ def __len__(self) -> int: class Series(Generic[NativeSeriesT]): _compliant: CompliantSeries[NativeSeriesT] - _version: t.ClassVar[Version] = Version.MAIN + _version: ClassVar[Version] = Version.MAIN @property def version(self) -> Version: @@ -867,7 +857,7 @@ def name(self) -> str: @classmethod def from_native( cls, native: NativeSeries, name: str = "", / - ) -> Series[pa.ChunkedArray[t.Any]]: + ) -> Series[pa.ChunkedArray[Any]]: if is_pyarrow_chunked_array(native): from narwhals._plan.arrow.series import ArrowSeries @@ -886,12 +876,12 @@ def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self: def to_native(self) -> NativeSeriesT: return self._compliant.native - def to_list(self) -> list[t.Any]: + def to_list(self) -> list[Any]: return self._compliant.to_list() - def __iter__(self) -> t.Iterator[t.Any]: + def __iter__(self) -> Iterator[Any]: yield from self.to_native() class SeriesV1(Series[NativeSeriesT]): - _version: t.ClassVar[Version] = Version.V1 + _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 610a9e80a1..7d35237f16 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -6,23 +6,20 @@ # - Literal import typing as t -from narwhals._plan import common from narwhals._plan.aggregation import AggExpr, OrderableAggExpr from narwhals._plan.common import ExprIR, SelectorIR, collect from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.name import KeepName, RenameAlias +from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( FunctionT, LeftSelectorT, LeftT, - LeftT2, LiteralT, - MapIR, OperatorT, RangeT, RightSelectorT, RightT, - RightT2, RollingT, SelectorOperatorT, SelectorT, @@ -37,6 +34,7 @@ from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, 
SortMultipleOptions, SortOptions + from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals._plan.selectors import Selector from narwhals._plan.window import Window from narwhals.dtypes import DType @@ -66,7 +64,7 @@ "SelectorIR", "Sort", "SortBy", - "Ternary", + "TernaryExpr", "WindowExpr", "col", ] @@ -88,7 +86,7 @@ def index_columns(*indices: int) -> IndexColumns: return IndexColumns(indices=indices) -class Alias(ExprIR): +class Alias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "name") expr: ExprIR name: str @@ -100,41 +98,18 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.alias({self.name!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class Column(ExprIR): +class Column(ExprIR, config=ExprIROptions.namespaced("col")): __slots__ = ("name",) name: str def __repr__(self) -> str: return f"col({self.name!r})" - def with_name(self, name: str, /) -> Column: - return common.replace(self, name=name) - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - - -class _ColumnSelection(ExprIR): +class _ColumnSelection(ExprIR, config=ExprIROptions.no_dispatch()): """Nodes which can resolve to `Column`(s) with a `Schema`.""" - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class Columns(_ColumnSelection): __slots__ = ("names",) @@ -165,7 +140,7 @@ def __repr__(self) -> str: return "all()" -class Exclude(_ColumnSelection): +class Exclude(_ColumnSelection, child=("expr",)): __slots__ = ("expr", "names") expr: ExprIR """Default is `all()`.""" @@ -180,22 +155,8 @@ def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: def __repr__(self) -> str: return f"{self.expr!r}.exclude({list(self.names)!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class Literal(ExprIR, t.Generic[LiteralT]): +class Literal(ExprIR, t.Generic[LiteralT], config=ExprIROptions.namespaced("lit")): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" __slots__ = ("value",) @@ -219,9 +180,6 @@ def __repr__(self) -> str: def unwrap(self) -> LiteralT: return self.value.unwrap() - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): __slots__ = ("left", "op", "right") @@ -238,40 +196,17 @@ def __repr__(self) -> str: class BinaryExpr( - _BinaryOp[LeftT, OperatorT, RightT], t.Generic[LeftT, OperatorT, RightT] + _BinaryOp[LeftT, OperatorT, RightT], + t.Generic[LeftT, OperatorT, RightT], + child=("left", "right"), ): """Application of two exprs via an `Operator`.""" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.left.iter_left() - yield from self.right.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - 
yield self - yield from self.right.iter_right() - yield from self.left.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.left.iter_output_name() - def with_left(self, left: LeftT2, /) -> BinaryExpr[LeftT2, OperatorT, RightT]: - changed = common.replace(self, left=left) - return t.cast("BinaryExpr[LeftT2, OperatorT, RightT]", changed) - - def with_right(self, right: RightT2, /) -> BinaryExpr[LeftT, OperatorT, RightT2]: - changed = common.replace(self, right=right) - return t.cast("BinaryExpr[LeftT, OperatorT, RightT2]", changed) - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function( - self.with_left(self.left.map_ir(function)).with_right( - self.right.map_ir(function) - ) - ) - -class Cast(ExprIR): +class Cast(ExprIR, child=("expr",)): __slots__ = ("expr", "dtype") # noqa: RUF023 expr: ExprIR dtype: DType @@ -283,25 +218,11 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.cast({self.dtype!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class Sort(ExprIR): +class Sort(ExprIR, child=("expr",)): __slots__ = ("expr", "options") expr: ExprIR options: SortOptions @@ -314,25 +235,11 @@ def __repr__(self) -> str: direction = "desc" if self.options.descending else "asc" return f"{self.expr!r}.sort({direction})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class SortBy(ExprIR): +class SortBy(ExprIR, child=("expr", "by")): """https://github.com/narwhals-dev/narwhals/issues/2534.""" __slots__ = ("expr", "by", "options") # noqa: RUF023 @@ -347,33 +254,11 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.sort_by(by={self.by!r}, options={self.options!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - for e in self.by: - yield from e.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - for e in reversed(self.by): - yield from e.iter_right() - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - by = (ir.map_ir(function) for ir in self.by) - return function(self.with_expr(self.expr.map_ir(function)).with_by(by)) - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - - def with_by(self, by: t.Iterable[ExprIR], /) -> Self: - return common.replace(self, by=collect(by)) - - -class FunctionExpr(ExprIR, t.Generic[FunctionT]): +class FunctionExpr(ExprIR, t.Generic[FunctionT], child=("input",)): """**Representing `Expr::Function`**. 
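One convention worth calling out in these nodes: a binary expression takes its output name from the left operand (`iter_output_name` above only recurses into `self.left`). A toy stand-in, not the real node types:

class Col:
    def __init__(self, name):
        self.name = name

    def output_name(self):
        return self.name


class BinaryExpr:
    # Like the real node, the output name is the left operand's.
    def __init__(self, left, op, right):
        self.left, self.op, self.right = left, op, right

    def output_name(self):
        return self.left.output_name()


assert BinaryExpr(Col("a"), "+", Col("b")).output_name() == "a"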
https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 @@ -392,15 +277,6 @@ class FunctionExpr(ExprIR, t.Generic[FunctionT]): def is_scalar(self) -> bool: return self.function.is_scalar - def with_options(self, options: FunctionOptions, /) -> Self: - return common.replace(self, options=self.options.with_flags(options.flags)) - - def with_input(self, input: t.Iterable[ExprIR], /) -> Self: # noqa: A002 - return common.replace(self, input=collect(input)) - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_input(ir.map_ir(function) for ir in self.input)) - def __repr__(self) -> str: if self.input: first = self.input[0] @@ -409,16 +285,6 @@ def __repr__(self) -> str: return f"{first!r}.{self.function!r}()" return f"{self.function!r}()" - def iter_left(self) -> t.Iterator[ExprIR]: - for e in self.input: - yield from e.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - for e in reversed(self.input): - yield from e.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: """When we have multiple inputs, we want the name of the left-most expression. @@ -447,13 +313,25 @@ def __init__( raise function_expr_invalid_operation_error(function, parent) super().__init__(**dict(input=input, function=function, options=options, **kwds)) + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + class RollingExpr(FunctionExpr[RollingT]): ... -class AnonymousExpr(FunctionExpr["MapBatches"]): +class AnonymousExpr( + FunctionExpr["MapBatches"], config=ExprIROptions.renamed("map_batches") +): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + class RangeExpr(FunctionExpr[RangeT]): """E.g. `int_range(...)`. @@ -484,7 +362,7 @@ def __repr__(self) -> str: return f"{self.function!r}({list(self.input)!r})" -class Filter(ExprIR): +class Filter(ExprIR, child=("expr", "by")): __slots__ = ("expr", "by") # noqa: RUF023 expr: ExprIR by: ExprIR @@ -496,26 +374,13 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.filter({self.by!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield from self.by.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.by.iter_right() - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - expr, by = self.expr, self.by - changed = common.replace(self, expr=expr.map_ir(function), by=by.map_ir(function)) - return function(changed) - -class WindowExpr(ExprIR): +class WindowExpr( + ExprIR, child=("expr", "partition_by"), config=ExprIROptions.renamed("over") +): """A fully specified `.over()`, that occurred after another expression. 
Related: @@ -533,35 +398,15 @@ class WindowExpr(ExprIR): def __repr__(self) -> str: return f"{self.expr!r}.over({list(self.partition_by)!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - for e in self.partition_by: - yield from e.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - for e in reversed(self.partition_by): - yield from e.iter_right() - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - over = self.with_expr(self.expr.map_ir(function)).with_partition_by( - ir.map_ir(function) for ir in self.partition_by - ) - return function(over) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - - def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: - return common.replace(self, partition_by=collect(partition_by)) - -class OrderedWindowExpr(WindowExpr): +class OrderedWindowExpr( + WindowExpr, + child=("expr", "partition_by", "order_by"), + config=ExprIROptions.renamed("over_ordered"), +): __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 expr: ExprIR partition_by: Seq[ExprIR] @@ -577,41 +422,18 @@ def __repr__(self) -> str: args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" return f"{self.expr!r}.over({args})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - for e in self.partition_by: - yield from e.iter_left() - for e in self.order_by: - yield from e.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - for e in reversed(self.order_by): - yield from e.iter_right() - for e in reversed(self.partition_by): - yield from e.iter_right() - yield from self.expr.iter_right() - def iter_root_names(self) -> t.Iterator[ExprIR]: # NOTE: `order_by` is never considered in `polars` # To match that behavior for `root_names` - but still expand in all other cases # - this little escape hatch exists # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 - yield from super().iter_left() - - def map_ir(self, function: MapIR, /) -> ExprIR: - over = self.with_expr(self.expr.map_ir(function)).with_partition_by( - ir.map_ir(function) for ir in self.partition_by - ) - over = over.with_order_by(ir.map_ir(function) for ir in self.order_by) - return function(over) - - def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: - return common.replace(self, order_by=collect(order_by)) + yield from self.expr.iter_left() + for e in self.partition_by: + yield from e.iter_left() + yield self -class Len(ExprIR): +class Len(ExprIR, config=ExprIROptions.namespaced()): @property def is_scalar(self) -> bool: return True @@ -623,9 +445,6 @@ def name(self) -> str: def __repr__(self) -> str: return "len()" - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class RootSelector(SelectorIR): """A single selector expression.""" @@ -639,9 +458,6 @@ def __repr__(self) -> str: def matches_column(self, name: str, dtype: DType) -> bool: return self.selector.matches_column(name, dtype) - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class BinarySelector( _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], @@ -655,9 +471,6 @@ def matches_column(self, name: str, dtype: DType) -> bool: right = 
self.right.matches_column(name, dtype) return bool(self.op(left, right)) - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class InvertSelector(SelectorIR, t.Generic[SelectorT]): __slots__ = ("selector",) @@ -669,14 +482,11 @@ def __repr__(self) -> str: def matches_column(self, name: str, dtype: DType) -> bool: return not self.selector.matches_column(name, dtype) - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - -class Ternary(ExprIR): +class TernaryExpr(ExprIR, child=("truthy", "falsy", "predicate")): """When-Then-Otherwise.""" - __slots__ = ("predicate", "truthy", "falsy") # noqa: RUF023 + __slots__ = ("truthy", "falsy", "predicate") # noqa: RUF023 predicate: ExprIR truthy: ExprIR falsy: ExprIR @@ -690,24 +500,5 @@ def __repr__(self) -> str: f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" ) - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.truthy.iter_left() - yield from self.falsy.iter_left() - yield from self.predicate.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.predicate.iter_right() - yield from self.falsy.iter_right() - yield from self.truthy.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.truthy.iter_output_name() - - def map_ir(self, function: MapIR, /) -> ExprIR: - predicate = self.predicate.map_ir(function) - truthy = self.truthy.map_ir(function) - falsy = self.falsy.map_ir(function) - changed = common.replace(self, predicate=predicate, truthy=truthy, falsy=falsy) - return function(changed) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index d7ab345f81..2fdf92ef7d 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -40,17 +40,12 @@ from collections import deque from functools import lru_cache -from itertools import chain from typing import TYPE_CHECKING -from narwhals._plan import common -from narwhals._plan.common import ( - ExprIR, - Immutable, - NamedIR, - SelectorIR, - is_horizontal_reduction, -) +from narwhals._plan import common, meta +from narwhals._plan._guards import is_horizontal_reduction +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import ExprIR, NamedIR, SelectorIR from narwhals._plan.exceptions import ( column_index_error, column_not_found_error, @@ -79,11 +74,10 @@ from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Sequence + from collections.abc import Iterable, Iterator, Sequence from typing_extensions import TypeAlias - from narwhals._plan.dummy import Expr from narwhals._plan.typing import Seq from narwhals.dtypes import DType @@ -153,10 +147,6 @@ def from_ir(ir: ExprIR, /) -> ExpansionFlags: has_exclude=has_exclude, ) - @classmethod - def from_expr(cls, expr: Expr, /) -> ExpansionFlags: - return cls.from_ir(expr._ir) - def with_multiple_columns(self) -> ExpansionFlags: return common.replace(self, multiple_columns=True) @@ -194,7 +184,7 @@ def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]: def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames: """Raise an appropriate error if we can't materialize.""" output_names = _ensure_output_names_unique(exprs) - root_names = _root_names_unique(exprs) + root_names = meta.root_names_unique(exprs) if not (set(schema.names).issuperset(root_names)): raise column_not_found_error(root_names, schema) return output_names 
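Selector combination stays purely predicate-based: `BinarySelector.matches_column` applies the operator to the two boolean matches, and expansion later keeps whichever schema entries pass. A self-contained sketch with hypothetical `Numeric`/`StartsWith` selectors:

import operator


class Numeric:
    # Hypothetical dtype selector.
    def matches_column(self, name, dtype):
        return dtype in {"Int64", "Float64"}


class StartsWith:
    # Hypothetical name selector.
    def __init__(self, prefix):
        self.prefix = prefix

    def matches_column(self, name, dtype):
        return name.startswith(self.prefix)


class BinarySelector:
    # Combine two selectors with a boolean operator, per column.
    def __init__(self, left, op, right):
        self.left, self.op, self.right = left, op, right

    def matches_column(self, name, dtype):
        left = self.left.matches_column(name, dtype)
        right = self.right.matches_column(name, dtype)
        return bool(self.op(left, right))


sel = BinarySelector(Numeric(), operator.and_, StartsWith("a"))
schema = {"a1": "Int64", "a2": "String", "b1": "Float64"}
assert [k for k, v in schema.items() if sel.matches_column(k, v)] == ["a1"]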
@@ -207,19 +197,11 @@ def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: return names -def _root_names_unique(exprs: Seq[ExprIR]) -> set[str]: - from narwhals._plan.meta import _expr_to_leaf_column_names_iter - - it = chain.from_iterable(_expr_to_leaf_column_names_iter(expr) for expr in exprs) - return set(it) - - def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if is_horizontal_reduction(child): - return child.with_input( - rewrite_projections(child.input, keys=(), schema=schema) - ) + rewrites = rewrite_projections(child.input, keys=(), schema=schema) + return common.replace(child, input=rewrites) return child return origin.map_ir(fn) @@ -245,18 +227,7 @@ def is_index_in_range(index: int, n_fields: int) -> bool: def remove_alias(origin: ExprIR, /) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, Alias): - return child.expr - return child - - return origin.map_ir(fn) - - -def remove_exclude(origin: ExprIR, /) -> ExprIR: - def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, Exclude): - return child.expr - return child + return child.expr if isinstance(child, Alias) else child return origin.map_ir(fn) @@ -269,20 +240,14 @@ def replace_with_column( def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, tp): return col(name) - if isinstance(child, Exclude): - return child.expr - return child + return child.expr if isinstance(child, Exclude) else child return origin.map_ir(fn) def replace_selector(ir: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: - """Fully diverging from `polars`, we'll see how that goes.""" - def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, SelectorIR): - return expand_selector(child, schema=schema) - return child + return expand_selector(child, schema) if isinstance(child, SelectorIR) else child return ir.map_ir(fn) @@ -299,7 +264,7 @@ def selector_matches_column(selector: SelectorIR, name: str, dtype: DType, /) -> @lru_cache(maxsize=100) -def expand_selector(selector: SelectorIR, *, schema: FrozenSchema) -> Columns: +def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns: """Expand `selector` into `Columns`, within the context of `schema`.""" matches = selector_matches_column return cols(*(k for k, v in schema.items() if matches(selector, k, v))) @@ -319,64 +284,46 @@ def rewrite_projections( if flags.has_selector: expanded = replace_selector(expanded, schema=schema) flags = flags.with_multiple_columns() - result.extend( - replace_and_add_to_results( - expanded, keys=keys, col_names=schema.names, flags=flags - ) - ) + result.extend(iter_replace(expanded, keys, col_names=schema.names, flags=flags)) return tuple(result) -def replace_and_add_to_results( +def iter_replace( origin: ExprIR, /, keys: GroupByKeys, *, col_names: FrozenColumns, flags: ExpansionFlags, -) -> Seq[ExprIR]: - result: deque[ExprIR] = deque() +) -> Iterator[ExprIR]: if flags.has_nth: origin = replace_nth(origin, col_names) if flags.expands: it = (e for e in origin.iter_left() if isinstance(e, (Columns, IndexColumns))) if e := next(it, None): if isinstance(e, Columns): - exclude = prepare_excluded( - origin, keys=keys, has_exclude=flags.has_exclude - ) - result.extend(expand_columns(origin, e, exclude=exclude)) + if not _all_columns_match(origin, e): + msg = "expanding more than one `col` is not allowed" + raise ComputeError(msg) + names: Iterable[str] = e.names else: - exclude = prepare_excluded( - origin, keys=keys, has_exclude=flags.has_exclude - ) - 
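The expansion step itself reduces to: one rewritten expression per selected column, minus anything excluded (explicit `.exclude(...)` names or group-by keys). A stripped-down sketch of that loop, mirroring the `expand_column_selection` helper in the next hunk but using strings in place of `ExprIR` nodes:

def expand_column_selection(names, exclude, rewrite):
    # Emit one rewritten expression per name, skipping excluded ones.
    for name in names:
        if name not in exclude:
            yield rewrite(name)


out = list(expand_column_selection(["a", "b", "c"], {"b"}, lambda n: f"col('{n}').min()"))
assert out == ["col('a').min()", "col('c').min()"]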
result.extend( - expand_indices(origin, e, col_names=col_names, exclude=exclude) - ) + names = _iter_index_names(e, col_names) + exclude = prepare_excluded(origin, keys, flags) + yield from expand_column_selection(origin, type(e), names, exclude) elif flags.has_wildcard: - exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - result.extend(replace_wildcard(origin, col_names=col_names, exclude=exclude)) + exclude = prepare_excluded(origin, keys, flags) + yield from expand_column_selection(origin, All, col_names, exclude) else: - exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - expanded = rewrite_special_aliases(origin) - result.append(expanded) - return tuple(result) - - -def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]: - """Yield all excluded names in `origin`.""" - for e in origin.iter_left(): - if isinstance(e, Exclude): - yield from e.names + yield rewrite_special_aliases(origin) def prepare_excluded( - origin: ExprIR, /, keys: GroupByKeys, *, has_exclude: bool + origin: ExprIR, keys: GroupByKeys, flags: ExpansionFlags, / ) -> Excluded: """Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555.""" exclude: set[str] = set() - if has_exclude: - exclude.update(_iter_exclude_names(origin)) + if flags.has_exclude: + exclude.update(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) for group_by_key in keys: if name := group_by_key.meta.output_name(raise_if_undetermined=False): exclude.add(name) @@ -388,52 +335,20 @@ def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool: return all(it) -def expand_columns( - origin: ExprIR, /, columns: Columns, *, exclude: Excluded -) -> Seq[ExprIR]: - if not _all_columns_match(origin, columns): - msg = "expanding more than one `col` is not allowed" - raise ComputeError(msg) - result: deque[ExprIR] = deque() - for name in columns.names: - if name not in exclude: - expanded = replace_with_column(origin, Columns, name) - expanded = rewrite_special_aliases(expanded) - result.append(expanded) - return tuple(result) - - -def expand_indices( - origin: ExprIR, - /, - indices: IndexColumns, - *, - col_names: FrozenColumns, - exclude: Excluded, -) -> Seq[ExprIR]: - result: deque[ExprIR] = deque() - n_fields = len(col_names) +def _iter_index_names(indices: IndexColumns, names: FrozenColumns, /) -> Iterator[str]: + n_fields = len(names) for index in indices.indices: if not is_index_in_range(index, n_fields): - raise column_index_error(index, col_names) - name = col_names[index] - if name not in exclude: - expanded = replace_with_column(origin, IndexColumns, name) - expanded = rewrite_special_aliases(expanded) - result.append(expanded) - return tuple(result) + raise column_index_error(index, names) + yield names[index] -def replace_wildcard( - origin: ExprIR, /, *, col_names: FrozenColumns, exclude: Excluded -) -> Seq[ExprIR]: - result: deque[ExprIR] = deque() - for name in col_names: +def expand_column_selection( + origin: ExprIR, tp: type[_ColumnSelection], /, names: Iterable[str], exclude: Excluded +) -> Iterator[ExprIR]: + for name in names: if name not in exclude: - expanded = replace_with_column(origin, All, name) - expanded = rewrite_special_aliases(expanded) - result.append(expanded) - return tuple(result) + yield rewrite_special_aliases(replace_with_column(origin, tp, name)) def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: @@ -444,8 +359,6 @@ def 
rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: - Expanding all selections into `Column` - Dealing with `FunctionExpr.input` """ - from narwhals._plan import meta - if meta.has_expr_ir(origin, KeepName, RenameAlias): if isinstance(origin, KeepName): parent = origin.expr diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index b1ef22fbc9..1e450f2307 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -6,7 +6,7 @@ from itertools import chain from typing import TYPE_CHECKING, TypeVar -from narwhals._plan.common import is_expr, is_iterable_reject +from narwhals._plan._guards import is_expr, is_iterable_reject from narwhals._plan.exceptions import ( invalid_into_expr_error, is_iterable_pandas_error, @@ -22,7 +22,7 @@ from typing_extensions import TypeAlias, TypeIs from narwhals._plan.common import ExprIR - from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq + from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq from narwhals.typing import IntoDType T = TypeVar("T") @@ -100,7 +100,7 @@ def parse_into_expr_ir( def parse_into_seq_of_expr_ir( - first_input: IntoExpr | Iterable[IntoExpr] = (), + first_input: OneOrIterable[IntoExpr] = (), *more_inputs: IntoExpr | _RaisesInvalidIntoExprError, **named_inputs: IntoExpr, ) -> Seq[ExprIR]: @@ -109,7 +109,7 @@ def parse_into_seq_of_expr_ir( def parse_predicates_constraints_into_expr_ir( - first_predicate: IntoExprColumn | Iterable[IntoExprColumn] = (), + first_predicate: OneOrIterable[IntoExprColumn] = (), *more_predicates: IntoExprColumn | _RaisesInvalidIntoExprError, **constraints: IntoExpr, ) -> ExprIR: @@ -125,9 +125,7 @@ def parse_predicates_constraints_into_expr_ir( def _parse_into_iter_expr_ir( - first_input: IntoExpr | Iterable[IntoExpr], - *more_inputs: IntoExpr, - **named_inputs: IntoExpr, + first_input: OneOrIterable[IntoExpr], *more_inputs: IntoExpr, **named_inputs: IntoExpr ) -> Iterator[ExprIR]: if not _is_empty_sequence(first_input): # NOTE: These need to be separated to introduce an intersection type diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index 705dec3371..597e8afc21 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -5,14 +5,13 @@ from typing import TYPE_CHECKING from narwhals._plan import expr_parsing as parse -from narwhals._plan.common import ( - NamedIR, +from narwhals._plan._guards import ( is_aggregation, is_binary_expr, is_function_expr, is_window_expr, - map_ir, ) +from narwhals._plan.common import NamedIR, map_ir, replace from narwhals._plan.expr_expansion import into_named_irs, prepare_projection if TYPE_CHECKING: @@ -59,7 +58,7 @@ def rewrite_elementwise_over(window: ExprIR, /) -> ExprIR: ): func = window.expr parent, *args = func.input - return func.with_input((window.with_expr(parent), *args)) + return replace(func, input=(replace(window, expr=parent), *args)) return window @@ -84,6 +83,5 @@ def rewrite_binary_agg_over(window: ExprIR, /) -> ExprIR: and (is_aggregation(window.expr.right)) ): binary_expr = window.expr - rhs = window.expr.right - return binary_expr.with_right(window.with_expr(rhs)) + return replace(binary_expr, right=replace(window, expr=binary_expr.right)) return window diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 4c80849a89..570d75b4d0 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Function +from 
narwhals._plan.common import Function, HorizontalFunction from narwhals._plan.exceptions import hist_bins_monotonic_error from narwhals._plan.options import FunctionFlags, FunctionOptions @@ -21,10 +21,48 @@ from narwhals.typing import FillNullStrategy -class Abs(Function, options=FunctionOptions.elementwise): ... +class CumAgg(Function, options=FunctionOptions.length_preserving): + __slots__ = ("reverse",) + reverse: bool + + +class RollingWindow(Function, options=FunctionOptions.length_preserving): + __slots__ = ("options",) + options: RollingOptionsFixedWindow + def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: + from narwhals._plan.expr import RollingExpr + + options = self.function_options + return RollingExpr(input=inputs, function=self, options=options) -class Hist(Function, options=FunctionOptions.groupwise): + +# fmt: off +class Abs(Function, options=FunctionOptions.elementwise): ... +class NullCount(Function, options=FunctionOptions.aggregation): ... +class Exp(Function, options=FunctionOptions.elementwise): ... +class Sqrt(Function, options=FunctionOptions.elementwise): ... +class DropNulls(Function, options=FunctionOptions.row_separable): ... +class Mode(Function): ... +class Skew(Function, options=FunctionOptions.aggregation): ... +class Clip(Function, options=FunctionOptions.elementwise): ... +class CumCount(CumAgg): ... +class CumMin(CumAgg): ... +class CumMax(CumAgg): ... +class CumProd(CumAgg): ... +class CumSum(CumAgg): ... +class RollingSum(RollingWindow): ... +class RollingMean(RollingWindow): ... +class RollingVar(RollingWindow): ... +class RollingStd(RollingWindow): ... +class Diff(Function, options=FunctionOptions.length_preserving): ... +class Unique(Function): ... +class SumHorizontal(HorizontalFunction): ... +class MinHorizontal(HorizontalFunction): ... +class MaxHorizontal(HorizontalFunction): ... +class MeanHorizontal(HorizontalFunction): ... +# fmt: on +class Hist(Function): """Only supported for `Series` so far.""" __slots__ = ("include_breakpoint",) @@ -55,17 +93,11 @@ def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> N object.__setattr__(self, "include_breakpoint", include_breakpoint) -class NullCount(Function, options=FunctionOptions.aggregation): ... - - class Log(Function, options=FunctionOptions.elementwise): __slots__ = ("base",) base: float -class Exp(Function, options=FunctionOptions.elementwise): ... - - class Pow(Function, options=FunctionOptions.elementwise): """N-ary (base, exponent).""" @@ -74,9 +106,6 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: return base, exponent -class Sqrt(Function, options=FunctionOptions.elementwise): ... - - class Kurtosis(Function, options=FunctionOptions.aggregation): __slots__ = ("bias", "fisher") fisher: bool @@ -102,89 +131,16 @@ class Shift(Function, options=FunctionOptions.length_preserving): n: int -class DropNulls(Function, options=FunctionOptions.row_separable): ... - - -class Mode(Function, options=FunctionOptions.groupwise): ... - - -class Skew(Function, options=FunctionOptions.aggregation): ... - - -class Rank(Function, options=FunctionOptions.groupwise): +class Rank(Function): __slots__ = ("options",) options: RankOptions -class Clip(Function, options=FunctionOptions.elementwise): ... 
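The `rewrite_elementwise_over` pass in the `expr_rewrites.py` hunk above pushes a window inside an elementwise function, i.e. `f(x).over(k)` becomes `f(x.over(k))`. A single-input sketch with toy dataclasses; the real version also carries the remaining `func.input` arguments along:

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Col:
    name: str


@dataclass(frozen=True)
class Abs:  # stand-in for an elementwise FunctionExpr
    input: object


@dataclass(frozen=True)
class Over:  # stand-in for WindowExpr
    expr: object
    key: str


def rewrite_elementwise_over(window: Over) -> object:
    # Move the window from the function's output onto its input.
    if isinstance(window.expr, Abs):
        return replace(window.expr, input=replace(window, expr=window.expr.input))
    return window


assert rewrite_elementwise_over(Over(Abs(Col("a")), "g")) == Abs(Over(Col("a"), "g"))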
- - -class CumAgg(Function, options=FunctionOptions.length_preserving): - __slots__ = ("reverse",) - reverse: bool - - -class RollingWindow(Function, options=FunctionOptions.length_preserving): - __slots__ = ("options",) - options: RollingOptionsFixedWindow - - def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: - from narwhals._plan.expr import RollingExpr - - options = self.function_options - return RollingExpr(input=inputs, function=self, options=options) - - -class CumCount(CumAgg): ... - - -class CumMin(CumAgg): ... - - -class CumMax(CumAgg): ... - - -class CumProd(CumAgg): ... - - -class CumSum(CumAgg): ... - - -class RollingSum(RollingWindow): ... - - -class RollingMean(RollingWindow): ... - - -class RollingVar(RollingWindow): ... - - -class RollingStd(RollingWindow): ... - - -class Diff(Function, options=FunctionOptions.length_preserving): ... - - -class Unique(Function, options=FunctionOptions.groupwise): ... - - class Round(Function, options=FunctionOptions.elementwise): __slots__ = ("decimals",) decimals: int -class SumHorizontal(Function, options=FunctionOptions.horizontal): ... - - -class MinHorizontal(Function, options=FunctionOptions.horizontal): ... - - -class MaxHorizontal(Function, options=FunctionOptions.horizontal): ... - - -class MeanHorizontal(Function, options=FunctionOptions.horizontal): ... - - class EwmMean(Function, options=FunctionOptions.length_preserving): __slots__ = ("options",) options: EWMOptions @@ -197,7 +153,7 @@ class ReplaceStrict(Function, options=FunctionOptions.elementwise): return_dtype: DType | None -class GatherEvery(Function, options=FunctionOptions.groupwise): +class GatherEvery(Function): __slots__ = ("n", "offset") n: int offset: int diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index 046db5615d..f4a45f217f 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionOptions @@ -9,15 +9,12 @@ from narwhals._plan.dummy import Expr +# fmt: off class ListFunction(Function, accessor="list"): ... - - class Len(ListFunction, options=FunctionOptions.elementwise): ... 
- - +# fmt: on class IRListNamespace(IRNamespace): - def len(self) -> Len: - return Len() + len: ClassVar = Len class ExprListNamespace(ExprNamespace[IRListNamespace]): diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index e0dba305fa..94f5a9a5b4 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any, Generic -from narwhals._plan.common import Immutable +from narwhals._plan._guards import is_literal +from narwhals._plan._immutable import Immutable from narwhals._plan.typing import LiteralT, NativeSeriesT, NonNestedLiteralT if TYPE_CHECKING: @@ -74,31 +75,7 @@ def unwrap(self) -> Series[NativeSeriesT]: return self.value -def _is_scalar( - obj: ScalarLiteral[NonNestedLiteralT] | Any, -) -> TypeIs[ScalarLiteral[NonNestedLiteralT]]: - return isinstance(obj, ScalarLiteral) - - -def _is_series( - obj: SeriesLiteral[NativeSeriesT] | Any, -) -> TypeIs[SeriesLiteral[NativeSeriesT]]: - return isinstance(obj, SeriesLiteral) - - -def is_literal(obj: Literal[LiteralT] | Any) -> TypeIs[Literal[LiteralT]]: - from narwhals._plan.expr import Literal - - return isinstance(obj, Literal) - - def is_literal_scalar( obj: Literal[NonNestedLiteralT] | Any, ) -> TypeIs[Literal[NonNestedLiteralT]]: - return is_literal(obj) and _is_scalar(obj.value) - - -def is_literal_series( - obj: Literal[Series[NativeSeriesT]] | Any, -) -> TypeIs[Literal[Series[NativeSeriesT]]]: - return is_literal(obj) and _is_series(obj.value) + return is_literal(obj) and isinstance(obj.value, ScalarLiteral) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 1a5cbb8eac..ce78165c00 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -7,6 +7,7 @@ from __future__ import annotations from functools import lru_cache +from itertools import chain from typing import TYPE_CHECKING, Literal, overload from narwhals._plan.common import IRNamespace @@ -14,7 +15,7 @@ from narwhals.utils import Version if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterable, Iterator from typing_extensions import TypeIs @@ -74,16 +75,7 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: def root_names(self) -> list[str]: """Get the root column names.""" - return _expr_to_leaf_column_names(self._ir) - - -def _expr_to_leaf_column_names(ir: ExprIR) -> list[str]: - """After a lot of indirection, [root_names] resolves [here]. 
- - [root_names]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L27-L30 - [here]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/utils.rs#L171-L195 - """ - return list(_expr_to_leaf_column_names_iter(ir)) + return list(_expr_to_leaf_column_names_iter(self._ir)) def _expr_to_leaf_column_names_iter(ir: ExprIR) -> Iterator[str]: @@ -121,6 +113,10 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: return ComputeError(msg) +def root_names_unique(irs: Iterable[ExprIR], /) -> set[str]: + return set(chain.from_iterable(_expr_to_leaf_column_names_iter(e) for e in irs)) + + @lru_cache(maxsize=32) def _expr_output_name(ir: ExprIR) -> str | ComputeError: from narwhals._plan import expr @@ -186,26 +182,22 @@ def is_column(ir: ExprIR) -> TypeIs[Column]: def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan import expr - from narwhals._plan.literal import ScalarLiteral - - if isinstance(ir, expr.Literal): - return True - if isinstance(ir, expr.Alias): - return allow_aliasing - if isinstance(ir, expr.Cast): - return ( - isinstance(ir.expr, expr.Literal) - and isinstance(ir.expr, ScalarLiteral) + from narwhals._plan.literal import is_literal_scalar + + return ( + isinstance(ir, expr.Literal) + or (allow_aliasing and isinstance(ir, expr.Alias)) + or ( + isinstance(ir, expr.Cast) + and is_literal_scalar(ir.expr) and isinstance(ir.expr.dtype, Version.MAIN.dtypes.Datetime) ) - return False + ) def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan import expr - if isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)): - return True - if isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)): - return allow_aliasing - return False + return isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)) or ( + allow_aliasing and isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)) + ) diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 7c695599bc..4147f20450 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -3,21 +3,17 @@ from typing import TYPE_CHECKING from narwhals._plan import common -from narwhals._plan.common import ExprIR, ExprNamespace, Immutable, IRNamespace +from narwhals._plan._immutable import Immutable +from narwhals._plan.options import ExprIROptions if TYPE_CHECKING: - from collections.abc import Iterator - - from typing_extensions import Self - from narwhals._compliant.typing import AliasName from narwhals._plan.dummy import Expr - from narwhals._plan.typing import MapIR -class KeepName(ExprIR): +class KeepName(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr",) - expr: ExprIR + expr: common.ExprIR @property def is_scalar(self) -> bool: @@ -26,24 +22,10 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.name.keep()" - def iter_left(self) -> Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class RenameAlias(ExprIR): +class RenameAlias(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "function") - expr: 
ExprIR + expr: common.ExprIR function: AliasName @property @@ -53,20 +35,6 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f".rename_alias({self.expr!r})" - def iter_left(self) -> Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - class Prefix(Immutable): __slots__ = ("prefix",) @@ -84,7 +52,7 @@ def __call__(self, name: str, /) -> str: return f"{name}{self.suffix}" -class IRNameNamespace(IRNamespace): +class IRNameNamespace(common.IRNamespace): def keep(self) -> KeepName: return KeepName(expr=self._ir) @@ -104,7 +72,7 @@ def to_uppercase(self) -> RenameAlias: return self.map(str.upper) -class ExprNameNamespace(ExprNamespace[IRNameNamespace]): +class ExprNameNamespace(common.ExprNamespace[IRNameNamespace]): @property def _ir_namespace(self) -> type[IRNameNamespace]: return IRNameNamespace diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 09d072e7bd..78b33b042f 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -1,15 +1,15 @@ from __future__ import annotations -import operator +import operator as op from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, is_function_expr +from narwhals._plan._guards import is_function_expr +from narwhals._plan._immutable import Immutable from narwhals._plan.exceptions import ( binary_expr_length_changing_error, binary_expr_multi_output_error, binary_expr_shape_error, ) -from narwhals._plan.expr import BinarySelector if TYPE_CHECKING: from typing import Any, ClassVar @@ -28,30 +28,19 @@ class Operator(Immutable): - _op: ClassVar[OperatorFn] + _func: ClassVar[OperatorFn] + _symbol: ClassVar[str] def __repr__(self) -> str: - tp = type(self) - if tp in {Operator, SelectorOperator}: - return tp.__name__ - m = { - Eq: "==", - NotEq: "!=", - Lt: "<", - LtEq: "<=", - Gt: ">", - GtEq: ">=", - Add: "+", - Sub: "-", - Multiply: "*", - TrueDivide: "/", - FloorDivide: "//", - Modulus: "%", - And: "&", - Or: "|", - ExclusiveOr: "^", - } - return m[tp] + return self._symbol + + def __init_subclass__( + cls, *args: Any, func: OperatorFn | None, symbol: str = "", **kwds: Any + ) -> None: + super().__init_subclass__(*args, **kwds) + if func: + cls._func = func + cls._symbol = symbol or cls.__name__ def to_binary_expr( self, left: LeftT, right: RightT, / @@ -72,16 +61,14 @@ def to_binary_expr( def __call__(self, lhs: Any, rhs: Any) -> Any: """Apply binary operator to `left`, `right` operands.""" - return self.__class__._op(lhs, rhs) + return self.__class__._func(lhs, rhs) def _is_filtration(ir: ExprIR) -> bool: - if not ir.is_scalar and is_function_expr(ir): - return not ir.options.is_elementwise() - return False + return not ir.is_scalar and is_function_expr(ir) and not ir.options.is_elementwise() -class SelectorOperator(Operator): +class SelectorOperator(Operator, func=None): """Operators that can *also* be used in selectors.""" def to_binary_selector( @@ -92,61 +79,20 @@ def to_binary_selector( return BinarySelector(left=left, op=self, right=right) -class Eq(Operator): - _op = operator.eq - - -class NotEq(Operator): - _op = operator.ne - - -class Lt(Operator): - _op = operator.le - - -class LtEq(Operator): - _op = operator.lt - - -class Gt(Operator): - _op = 
operator.gt - - -class GtEq(Operator): - _op = operator.ge - - -class Add(Operator): - _op = operator.add - - -class Sub(SelectorOperator): - _op = operator.sub - - -class Multiply(Operator): - _op = operator.mul - - -class TrueDivide(Operator): - _op = operator.truediv - - -class FloorDivide(Operator): - _op = operator.floordiv - - -class Modulus(Operator): - _op = operator.mod - - -class And(SelectorOperator): - _op = operator.and_ - - -class Or(SelectorOperator): - _op = operator.or_ - - -class ExclusiveOr(SelectorOperator): - _op = operator.xor +# fmt: off +class Eq(Operator, func=op.eq, symbol="=="): ... +class NotEq(Operator, func=op.ne, symbol="!="): ... +class Lt(Operator, func=op.le, symbol="<"): ... +class LtEq(Operator, func=op.lt, symbol="<="): ... +class Gt(Operator, func=op.gt, symbol=">"): ... +class GtEq(Operator, func=op.ge, symbol=">="): ... +class Add(Operator, func=op.add, symbol="+"): ... +class Sub(SelectorOperator, func=op.sub, symbol="-"): ... +class Multiply(Operator, func=op.mul, symbol="*"): ... +class TrueDivide(Operator, func=op.truediv, symbol="/"): ... +class FloorDivide(Operator, func=op.floordiv, symbol="//"): ... +class Modulus(Operator, func=op.mod, symbol="%"): ... +class And(SelectorOperator, func=op.and_, symbol="&"): ... +class Or(SelectorOperator, func=op.or_, symbol="|"): ... +class ExclusiveOr(SelectorOperator, func=op.xor, symbol="^"): ... +# fmt: on diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index ca6cf91a04..6f77674dff 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -4,16 +4,19 @@ from itertools import repeat from typing import TYPE_CHECKING, Literal -from narwhals._plan.common import Immutable +from narwhals._plan._immutable import Immutable if TYPE_CHECKING: from collections.abc import Iterable, Sequence import pyarrow.compute as pc + from typing_extensions import Self, TypeAlias - from narwhals._plan.typing import Seq + from narwhals._plan.typing import Accessor, OneOrIterable, Seq from narwhals.typing import RankMethod +DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] + class FunctionFlags(enum.Flag): ALLOW_GROUP_AWARE = 1 << 0 @@ -184,7 +187,7 @@ def __repr__(self) -> str: @staticmethod def parse( - *, descending: bool | Iterable[bool], nulls_last: bool | Iterable[bool] + *, descending: OneOrIterable[bool], nulls_last: OneOrIterable[bool] ) -> SortMultipleOptions: desc = (descending,) if isinstance(descending, bool) else tuple(descending) nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) @@ -263,3 +266,56 @@ def rolling_options( center=center, fn_params=ddof if ddof is None else RollingVarParams(ddof=ddof), ) + + +class _BaseIROptions(Immutable): + __slots__ = ("origin", "override_name") + origin: DispatchOrigin + override_name: str + + def __repr__(self) -> str: + return self.__str__() + + @classmethod + def default(cls) -> Self: + return cls(origin="expr", override_name="") + + @classmethod + def renamed(cls, name: str, /) -> Self: + from narwhals._plan.common import replace + + return replace(cls.default(), override_name=name) + + @classmethod + def namespaced(cls, override_name: str = "", /) -> Self: + from narwhals._plan.common import replace + + return replace( + cls.default(), origin="__narwhals_namespace__", override_name=override_name + ) + + +class ExprIROptions(_BaseIROptions): + __slots__ = (*_BaseIROptions.__slots__, "allow_dispatch") + allow_dispatch: bool + + @classmethod + def default(cls) -> Self: + return cls(origin="expr", 
override_name="", allow_dispatch=True) + + @staticmethod + def no_dispatch() -> ExprIROptions: + return ExprIROptions(origin="expr", override_name="", allow_dispatch=False) + + +class FunctionExprOptions(_BaseIROptions): + __slots__ = (*_BaseIROptions.__slots__, "accessor_name") + accessor_name: Accessor | None + """Namespace accessor name, if any.""" + + @classmethod + def default(cls) -> Self: + return cls(origin="expr", override_name="", accessor_name=None) + + +FEOptions = FunctionExprOptions diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 821e7a338a..951dfa850e 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,27 +1,30 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload +from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan import aggregation as agg, boolean, expr, functions as F, strings -from narwhals._plan.common import ExprIR, Function, NamedIR, flatten_hash_safe +from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar -from narwhals._utils import Version, _hasattr_static +from narwhals._utils import Version if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs + from narwhals._plan import aggregation as agg, boolean, expr, functions as F + from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.dummy import BaseFrame, DataFrame, Series - from narwhals._plan.expr import FunctionExpr, RangeExpr + from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange + from narwhals._plan.strings import ConcatStr + from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType - from narwhals.schema import Schema from narwhals.typing import ( ConcatMethod, Into1DArray, IntoDType, + IntoSchema, NonNestedLiteral, PythonLiteral, _1DArray, @@ -29,7 +32,6 @@ T = TypeVar("T") R_co = TypeVar("R_co", covariant=True) -OneOrIterable: TypeAlias = "T | Iterable[T]" LengthT = TypeVar("LengthT") NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) @@ -69,6 +71,22 @@ LazyExprT_co = TypeVar("LazyExprT_co", bound=LazyExprAny, covariant=True) LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True) +Ctx: TypeAlias = "ExprDispatch[FrameT_contra, R_co, NamespaceAny]" +"""Type of an unknown expression dispatch context. + +- `FrameT_contra`: Compliant data/lazyframe +- `R_co`: Upper bound return type of the context +""" + + +class SupportsNarwhalsNamespace(Protocol[NamespaceT_co]): + def __narwhals_namespace__(self) -> NamespaceT_co: ... 
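# --- Illustrative sketch (not part of the patch): the `_BaseIROptions.default()` /
# `.renamed()` / `.namespaced()` constructors above follow a "frozen options plus
# functional update" pattern. Here it is re-created with a plain frozen dataclass and
# `dataclasses.replace`; `IROptions` and its fields are invented names standing in for
# narwhals' `Immutable`/`replace` helpers.
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Literal

DispatchOrigin = Literal["expr", "__narwhals_namespace__"]


@dataclass(frozen=True)
class IROptions:
    origin: DispatchOrigin = "expr"
    override_name: str = ""

    @classmethod
    def default(cls) -> IROptions:
        return cls()

    @classmethod
    def renamed(cls, name: str, /) -> IROptions:
        # same options, dispatched under a different method name
        return replace(cls.default(), override_name=name)

    @classmethod
    def namespaced(cls, override_name: str = "", /) -> IROptions:
        # dispatched via the namespace object instead of the expression
        return replace(cls.default(), origin="__narwhals_namespace__", override_name=override_name)


assert IROptions.renamed("field").override_name == "field"
assert IROptions.namespaced().origin == "__narwhals_namespace__"
# --- end sketch ---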
+ + +def namespace(obj: SupportsNarwhalsNamespace[NamespaceT_co], /) -> NamespaceT_co: + """Return the compliant namespace.""" + return obj.__narwhals_namespace__() + # NOTE: Unlike the version in `nw._utils`, here `.version` it is public class StoresVersion(Protocol): @@ -143,130 +161,11 @@ def _length_required( class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): - _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = { - expr.Column: lambda self, node, frame, name: self.__narwhals_namespace__().col( - node, frame, name - ), - expr.Literal: lambda self, node, frame, name: self.__narwhals_namespace__().lit( - node, frame, name - ), - expr.Len: lambda self, node, frame, name: self.__narwhals_namespace__().len( - node, frame, name - ), - expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name), - expr.Sort: lambda self, node, frame, name: self.sort(node, frame, name), - expr.SortBy: lambda self, node, frame, name: self.sort_by(node, frame, name), - expr.Filter: lambda self, node, frame, name: self.filter(node, frame, name), - agg.First: lambda self, node, frame, name: self.first(node, frame, name), - agg.Last: lambda self, node, frame, name: self.last(node, frame, name), - agg.ArgMin: lambda self, node, frame, name: self.arg_min(node, frame, name), - agg.ArgMax: lambda self, node, frame, name: self.arg_max(node, frame, name), - agg.Sum: lambda self, node, frame, name: self.sum(node, frame, name), - agg.NUnique: lambda self, node, frame, name: self.n_unique(node, frame, name), - agg.Std: lambda self, node, frame, name: self.std(node, frame, name), - agg.Var: lambda self, node, frame, name: self.var(node, frame, name), - agg.Quantile: lambda self, node, frame, name: self.quantile(node, frame, name), - agg.Count: lambda self, node, frame, name: self.count(node, frame, name), - agg.Max: lambda self, node, frame, name: self.max(node, frame, name), - agg.Mean: lambda self, node, frame, name: self.mean(node, frame, name), - agg.Median: lambda self, node, frame, name: self.median(node, frame, name), - agg.Min: lambda self, node, frame, name: self.min(node, frame, name), - expr.BinaryExpr: lambda self, node, frame, name: self.binary_expr( - node, frame, name - ), - expr.RollingExpr: lambda self, node, frame, name: self.rolling_expr( - node, frame, name - ), - expr.AnonymousExpr: lambda self, node, frame, name: self.map_batches( - node, frame, name - ), - expr.FunctionExpr: lambda self, node, frame, name: self._dispatch_function( - node, frame, name - ), - # NOTE: Keeping it simple for now - # When adding other `*_range` functions, this should instead map to `range_expr` - expr.RangeExpr: lambda self, - node, - frame, - name: self.__narwhals_namespace__().int_range(node, frame, name), - expr.OrderedWindowExpr: lambda self, node, frame, name: self.over_ordered( - node, frame, name - ), - expr.WindowExpr: lambda self, node, frame, name: self.over(node, frame, name), - expr.Ternary: lambda self, node, frame, name: self.ternary_expr( - node, frame, name - ), - } - _DISPATCH_FUNCTION: ClassVar[ - Mapping[type[Function], Callable[[Any, FunctionExpr, Any, str], Any]] - ] = { - boolean.AnyHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().any_horizontal(node, frame, name), - boolean.AllHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().all_horizontal(node, frame, name), - F.SumHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().sum_horizontal(node, 
frame, name), - F.MinHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().min_horizontal(node, frame, name), - F.MaxHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().max_horizontal(node, frame, name), - F.MeanHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().mean_horizontal(node, frame, name), - strings.ConcatHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().concat_str(node, frame, name), - F.Pow: lambda self, node, frame, name: self.pow(node, frame, name), - F.FillNull: lambda self, node, frame, name: self.fill_null(node, frame, name), - boolean.IsBetween: lambda self, node, frame, name: self.is_between( - node, frame, name - ), - boolean.IsFinite: lambda self, node, frame, name: self.is_finite( - node, frame, name - ), - boolean.IsNan: lambda self, node, frame, name: self.is_nan(node, frame, name), - boolean.IsNull: lambda self, node, frame, name: self.is_null(node, frame, name), - boolean.Not: lambda self, node, frame, name: self.not_(node, frame, name), - boolean.Any: lambda self, node, frame, name: self.any(node, frame, name), - boolean.All: lambda self, node, frame, name: self.all(node, frame, name), - } - - def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: - if (method := self._DISPATCH.get(node.__class__)) and ( - result := method(self, node, frame, name) - ): - return result # type: ignore[no-any-return] - msg = f"Support for {node.__class__.__name__!r} is not yet implemented, got:\n{node!r}" - raise NotImplementedError(msg) - - def _dispatch_function( - self, node: FunctionExpr, frame: FrameT_contra, name: str - ) -> R_co: - fn = node.function - if (method := self._DISPATCH_FUNCTION.get(fn.__class__)) and ( - result := method(self, node, frame, name) - ): - return result # type: ignore[no-any-return] - msg = f"Support for {fn.__class__.__name__!r} is not yet implemented, got:\n{node!r}" - raise NotImplementedError(msg) - @classmethod def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: obj = cls.__new__(cls) obj._version = frame.version - return obj._dispatch(node, frame, name) + return node.dispatch(obj, frame, name) @classmethod def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: @@ -284,41 +183,35 @@ class CompliantExpr(StoresVersion, Protocol[FrameT_contra, SeriesT_co]): @property def name(self) -> str: ... - @classmethod def from_native( cls, native: Any, name: str = "", /, version: Version = Version.MAIN ) -> Self: ... - def _with_native(self, native: Any, name: str, /) -> Self: return self.from_native(native, name or self.name, self.version) # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... - def not_( - self, node: FunctionExpr[boolean.Not], frame: FrameT_contra, name: str - ) -> Self: ... + def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... def fill_null( self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str ) -> Self: ... def is_between( - self, node: FunctionExpr[boolean.IsBetween], frame: FrameT_contra, name: str + self, node: FunctionExpr[IsBetween], frame: FrameT_contra, name: str ) -> Self: ... def is_finite( - self, node: FunctionExpr[boolean.IsFinite], frame: FrameT_contra, name: str + self, node: FunctionExpr[IsFinite], frame: FrameT_contra, name: str ) -> Self: ... 
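# --- Illustrative sketch (not part of the patch): this commit deletes the
# `_DISPATCH`/`_DISPATCH_FUNCTION` dicts of lambdas and instead lets each node locate
# its own handler (`node.dispatch(obj, frame, name)` in `from_ir`). A minimal version
# of that inversion, with invented names (`Node`, `Cast`, `Sort`, `Visitor`), looks
# like this:
from __future__ import annotations

from typing import Any, Callable, ClassVar


class Node:
    _handler: ClassVar[str]

    def __init_subclass__(cls, **kwds: Any) -> None:
        super().__init_subclass__(**kwds)
        # resolved once at class-definition time, e.g. `Cast` -> `visitor.cast`
        cls._handler = cls.__name__.lower()

    def dispatch(self, visitor: Any, frame: Any, name: str) -> Any:
        method: Callable[[Node, Any, str], Any] = getattr(visitor, self._handler)
        return method(self, frame, name)


class Cast(Node): ...
class Sort(Node): ...


class Visitor:
    def cast(self, node: Cast, frame: Any, name: str) -> str:
        return f"cast({name})"

    def sort(self, node: Sort, frame: Any, name: str) -> str:
        return f"sort({name})"


assert Cast().dispatch(Visitor(), frame=None, name="a") == "cast(a)"
assert Sort().dispatch(Visitor(), frame=None, name="a") == "sort(a)"
# --- end sketch ---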
def is_nan( - self, node: FunctionExpr[boolean.IsNan], frame: FrameT_contra, name: str + self, node: FunctionExpr[IsNan], frame: FrameT_contra, name: str ) -> Self: ... def is_null( - self, node: FunctionExpr[boolean.IsNull], frame: FrameT_contra, name: str - ) -> Self: ... - def binary_expr( - self, node: expr.BinaryExpr, frame: FrameT_contra, name: str + self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str ) -> Self: ... + def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... def ternary_expr( - self, node: expr.Ternary, frame: FrameT_contra, name: str + self, node: expr.TernaryExpr, frame: FrameT_contra, name: str ) -> Self: ... def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` @@ -406,7 +299,6 @@ def from_python( dtype: IntoDType | None, version: Version, ) -> Self: ... - def _with_evaluated(self, evaluated: Any, name: str) -> Self: """Expr is based on a series having these via accessors, but a scalar needs to keep passing through.""" cls = type(self) @@ -526,7 +418,7 @@ def concat( self, items: Iterable[ConcatT2], *, how: Literal["vertical"] ) -> ConcatT2: ... def concat( - self, items: Iterable[ConcatT1] | Iterable[ConcatT2], *, how: ConcatMethod + self, items: Iterable[ConcatT1 | ConcatT2], *, how: ConcatMethod ) -> ConcatT1 | ConcatT2: ... @@ -536,7 +428,7 @@ def _concat_diagonal(self, items: Iterable[ConcatT1], /) -> ConcatT1: ... # but that is only available privately def _concat_horizontal(self, items: Iterable[ConcatT1 | ConcatT2], /) -> ConcatT1: ... def _concat_vertical( - self, items: Iterable[ConcatT1] | Iterable[ConcatT2], / + self, items: Iterable[ConcatT1 | ConcatT2], / ) -> ConcatT1 | ConcatT2: ... @@ -571,7 +463,7 @@ def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def concat_str( - self, node: FunctionExpr[strings.ConcatHorizontal], frame: FrameT, name: str + self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def int_range( self, node: RangeExpr[IntRange], frame: FrameT, name: str @@ -605,17 +497,9 @@ def lit( def lit( self, node: expr.Literal[Series[Any]], frame: EagerDataFrameT, name: str ) -> EagerExprT_co: ... - @overload - def lit( - self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[Any]], - frame: EagerDataFrameT, - name: str, - ) -> EagerExprT_co | EagerScalarT_co: ... def lit( self, node: expr.Literal[Any], frame: EagerDataFrameT, name: str ) -> EagerExprT_co | EagerScalarT_co: ... - def len(self, node: expr.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: return self._scalar.from_python( len(frame), name or node.name, dtype=None, version=frame.version @@ -645,7 +529,6 @@ def native(self) -> NativeFrameT: @property def columns(self) -> list[str]: ... def to_narwhals(self) -> BaseFrame[NativeFrameT]: ... - @classmethod def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: obj = cls.__new__(cls) @@ -672,15 +555,9 @@ class CompliantDataFrame( ): @classmethod def from_dict( - cls, - data: Mapping[str, Any], - /, - *, - schema: Mapping[str, DType] | Schema | None = None, + cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... - def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ... - @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... 
@overload @@ -689,11 +566,9 @@ def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... def to_dict( self, *, as_series: bool ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... - def to_dict( self, *, as_series: bool ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... - def __len__(self) -> int: ... def with_row_index(self, name: str) -> Self: ... @@ -704,12 +579,10 @@ class EagerDataFrame( ): def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... def select(self, irs: Seq[NamedIR]) -> Self: - ns = self.__narwhals_namespace__() - return ns._concat_horizontal(self._evaluate_irs(irs)) + return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) def with_columns(self, irs: Seq[NamedIR]) -> Self: - ns = self.__narwhals_namespace__() - return ns._concat_horizontal(self._evaluate_irs(irs)) + return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) class CompliantSeries(StoresVersion, Protocol[NativeSeriesT]): @@ -725,7 +598,6 @@ def native(self) -> NativeSeriesT: @property def dtype(self) -> DType: ... - @property def name(self) -> str: return self._name @@ -739,9 +611,6 @@ def to_narwhals(self) -> Series[NativeSeriesT]: def from_native( cls, native: NativeSeriesT, name: str = "", /, *, version: Version = Version.MAIN ) -> Self: - name = name or ( - getattr(native, "name", name) if _hasattr_static(native, "name") else name - ) obj = cls.__new__(cls) obj._native = native obj._name = name @@ -752,7 +621,6 @@ def from_native( def from_numpy( cls, data: Into1DArray, name: str = "", /, *, version: Version = Version.MAIN ) -> Self: ... - @classmethod def from_iterable( cls, @@ -762,7 +630,6 @@ def from_iterable( name: str = "", dtype: IntoDType | None = None, ) -> Self: ... 
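# --- Illustrative sketch (not part of the patch): the `to_dict` overloads kept above
# use `Literal[True]` / `Literal[False]` so a type checker can narrow the return type
# from the value of `as_series`. The toy `Frame`/`Series` classes below are invented
# for the example.
from __future__ import annotations

from typing import Any, Literal, overload


class Series:
    def __init__(self, name: str, values: list[Any]) -> None:
        self.name, self.values = name, values


class Frame:
    def __init__(self, data: dict[str, list[Any]]) -> None:
        self._data = data

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, Series]: ...
    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
    def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]:
        if as_series:
            return {k: Series(k, v) for k, v in self._data.items()}
        return dict(self._data)


frame = Frame({"a": [1, 2]})
assert isinstance(frame.to_dict(as_series=True)["a"], Series)
assert frame.to_dict(as_series=False) == {"a": [1, 2]}
# --- end sketch ---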
- def _with_native(self, native: NativeSeriesT) -> Self: return self.from_native(native, self.name, version=self.version) diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py index 4414afabf7..4f8e49b531 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/ranges.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from narwhals._plan.common import ExprIR, Function -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from typing_extensions import Self @@ -12,7 +12,7 @@ from narwhals.dtypes import IntegerType -class RangeFunction(Function): +class RangeFunction(Function, config=FEOptions.namespaced()): def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: from narwhals._plan.expr import RangeExpr diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 17b8416285..69c1b5a2b3 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -7,7 +7,8 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, TypeVar, overload -from narwhals._plan.common import _IMMUTABLE_HASH_NAME, Immutable, NamedIR +from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable +from narwhals._plan.common import NamedIR from narwhals.dtypes import Unknown if TYPE_CHECKING: @@ -95,8 +96,7 @@ def _from_mapping(mapping: MappingProxyType[str, DType], /) -> FrozenSchema: @staticmethod def _from_hash_safe(items: _FrozenSchemaHash, /) -> FrozenSchema: - clone = MappingProxyType(dict(items)) - return FrozenSchema._from_mapping(clone) + return FrozenSchema._from_mapping(MappingProxyType(dict(items))) def items(self) -> ItemsView[str, DType]: return self._mapping.items() diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 3cd7666ddc..4aa5f58a3d 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -9,16 +9,18 @@ import re from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, flatten_hash_safe +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import flatten_hash_safe from narwhals._utils import Version, _parse_time_unit_and_time_zone if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Iterator from datetime import timezone from typing import TypeVar from narwhals._plan import dummy from narwhals._plan.expr import RootSelector + from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType from narwhals.typing import TimeUnit @@ -50,9 +52,7 @@ class ByDType(Selector): dtypes: frozenset[DType | type[DType]] @staticmethod - def from_dtypes( - *dtypes: DType | type[DType] | Iterable[DType | type[DType]], - ) -> ByDType: + def from_dtypes(*dtypes: OneOrIterable[DType | type[DType]]) -> ByDType: return ByDType(dtypes=frozenset(flatten_hash_safe(dtypes))) def __repr__(self) -> str: @@ -95,8 +95,8 @@ class Datetime(Selector): @staticmethod def from_time_unit_and_time_zone( - time_unit: TimeUnit | Iterable[TimeUnit] | None, - time_zone: str | timezone | Iterable[str | timezone | None] | None, + time_unit: OneOrIterable[TimeUnit] | None, + time_zone: OneOrIterable[str | timezone | None], /, ) -> Datetime: units, zones = _parse_time_unit_and_time_zone(time_unit, time_zone) @@ -125,11 +125,10 @@ def from_string(pattern: str, /) -> Matches: return Matches(pattern=re.compile(pattern)) @staticmethod - def from_names(*names: str | Iterable[str]) -> Matches: + def from_names(*names: OneOrIterable[str]) -> Matches: 
"""Implements `cs.by_name` to support `__r__` with column selections.""" it: Iterator[str] = flatten_hash_safe(names) - pattern = f"^({'|'.join(re.escape(name) for name in it)})$" - return Matches.from_string(pattern) + return Matches.from_string(f"^({'|'.join(re.escape(name) for name in it)})$") def __repr__(self) -> str: return f"ncs.matches(pattern={self.pattern.pattern!r})" @@ -158,13 +157,11 @@ def all() -> dummy.Selector: return All().to_selector().to_narwhals() -def by_dtype( - *dtypes: DType | type[DType] | Iterable[DType | type[DType]], -) -> dummy.Selector: +def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> dummy.Selector: return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals() -def by_name(*names: str | Iterable[str]) -> dummy.Selector: +def by_name(*names: OneOrIterable[str]) -> dummy.Selector: return Matches.from_names(*names).to_selector().to_narwhals() @@ -177,8 +174,8 @@ def categorical() -> dummy.Selector: def datetime( - time_unit: TimeUnit | Iterable[TimeUnit] | None = None, - time_zone: str | timezone | Iterable[str | timezone | None] | None = ("*", None), + time_unit: OneOrIterable[TimeUnit] | None = None, + time_zone: OneOrIterable[str | timezone | None] = ("*", None), ) -> dummy.Selector: return ( Datetime.from_time_unit_and_time_zone(time_unit, time_zone) diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 8a1789b079..4c1f4af303 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -1,20 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan.common import ExprNamespace, Function, HorizontalFunction, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr +# fmt: off class StringFunction(Function, accessor="str", options=FunctionOptions.elementwise): ... - - -class ConcatHorizontal(StringFunction, options=FunctionOptions.horizontal): - """`nw.functions.concat_str`.""" - +class LenChars(StringFunction): ... +class ToLowercase(StringFunction): ... +class ToUppercase(StringFunction): ... +# fmt: on +class ConcatStr(HorizontalFunction, StringFunction): __slots__ = ("ignore_nulls", "separator") separator: str ignore_nulls: bool @@ -31,9 +32,6 @@ class EndsWith(StringFunction): suffix: str -class LenChars(StringFunction): ... - - class Replace(StringFunction): __slots__ = ("literal", "n", "pattern", "value") pattern: str @@ -75,15 +73,13 @@ class ToDatetime(StringFunction): format: str | None -class ToLowercase(StringFunction): ... - - -class ToUppercase(StringFunction): ... 
- - class IRStringNamespace(IRNamespace): - def len_chars(self) -> LenChars: - return LenChars() + len_chars: ClassVar = LenChars + to_lowercase: ClassVar = ToUppercase + to_uppercase: ClassVar = ToLowercase + split: ClassVar = Split + starts_with: ClassVar = StartsWith + ends_with: ClassVar = EndsWith def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 @@ -98,12 +94,6 @@ def replace_all( def strip_chars(self, characters: str | None = None) -> StripChars: return StripChars(characters=characters) - def starts_with(self, prefix: str) -> StartsWith: - return StartsWith(prefix=prefix) - - def ends_with(self, suffix: str) -> EndsWith: - return EndsWith(suffix=suffix) - def contains(self, pattern: str, *, literal: bool = False) -> Contains: return Contains(pattern=pattern, literal=literal) @@ -116,18 +106,9 @@ def head(self, n: int = 5) -> Slice: def tail(self, n: int = 5) -> Slice: return self.slice(-n) - def split(self, by: str) -> Split: - return Split(by=by) - def to_datetime(self, format: str | None = None) -> ToDatetime: return ToDatetime(format=format) - def to_lowercase(self) -> ToUppercase: - return ToUppercase() - - def to_uppercase(self) -> ToLowercase: - return ToLowercase() - class ExprStringNamespace(ExprNamespace[IRStringNamespace]): @property @@ -149,10 +130,10 @@ def strip_chars(self, characters: str | None = None) -> Expr: return self._with_unary(self._ir.strip_chars(characters)) def starts_with(self, prefix: str) -> Expr: - return self._with_unary(self._ir.starts_with(prefix)) + return self._with_unary(self._ir.starts_with(prefix=prefix)) def ends_with(self, suffix: str) -> Expr: - return self._with_unary(self._ir.ends_with(suffix)) + return self._with_unary(self._ir.ends_with(suffix=suffix)) def contains(self, pattern: str, *, literal: bool = False) -> Expr: return self._with_unary(self._ir.contains(pattern, literal=literal)) @@ -167,7 +148,7 @@ def tail(self, n: int = 5) -> Expr: return self._with_unary(self._ir.tail(n)) def split(self, by: str) -> Expr: - return self._with_unary(self._ir.split(by)) + return self._with_unary(self._ir.split(by=by)) def to_datetime(self, format: str | None = None) -> Expr: return self._with_unary(self._ir.to_datetime(format)) diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index d91fef6458..2a3eca0b27 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from narwhals._plan.common import ExprNamespace, Function, IRNamespace -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr @@ -12,9 +12,9 @@ class StructFunction(Function, accessor="struct"): ... 
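# --- Illustrative sketch (not part of the patch): the namespaces above swap one-line
# factory methods (`def len_chars(self): return LenChars()`) for `ClassVar` attributes
# holding the node class itself. A class accessed through an instance is not bound the
# way a function is, so `ns.len_chars()` still constructs the node; callers of nodes
# with fields just switch to keyword arguments (hence `starts_with(prefix=...)` at the
# call sites). The classes below are re-created for the example, not imported from
# narwhals.
from __future__ import annotations

from typing import ClassVar


class LenChars:
    def __repr__(self) -> str:
        return "str.len_chars()"


class StartsWith:
    def __init__(self, *, prefix: str) -> None:
        self.prefix = prefix


class StringNamespace:
    len_chars: ClassVar = LenChars      # zero-argument node: the class is its own factory
    starts_with: ClassVar = StartsWith  # keyword-only node: works the same way


ns = StringNamespace()
assert isinstance(ns.len_chars(), LenChars)
assert ns.starts_with(prefix="a").prefix == "a"
# --- end sketch ---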
-class FieldByName(StructFunction, options=FunctionOptions.elementwise): - """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/struct_.rs#L11.""" - +class FieldByName( + StructFunction, options=FunctionOptions.elementwise, config=FEOptions.renamed("field") +): __slots__ = ("name",) name: str @@ -23,8 +23,7 @@ def __repr__(self) -> str: class IRStructNamespace(IRNamespace): - def field(self, name: str) -> FieldByName: - return FieldByName(name=name) + field: ClassVar = FieldByName class ExprStructNamespace(ExprNamespace[IRStructNamespace]): @@ -33,4 +32,4 @@ def _ir_namespace(self) -> type[IRStructNamespace]: return IRStructNamespace def field(self, name: str) -> Expr: - return self._with_unary(self._ir.field(name)) + return self._with_unary(self._ir.field(name=name)) diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index f6a74587f7..bd21388728 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Literal from narwhals._duration import Interval from narwhals._plan.common import ExprNamespace, Function, IRNamespace @@ -20,60 +20,26 @@ def _is_polars_time_unit(obj: Any) -> TypeIs[PolarsTimeUnit]: return obj in {"ns", "us", "ms"} +# fmt: off class TemporalFunction(Function, accessor="dt", options=FunctionOptions.elementwise): ... - - class Date(TemporalFunction): ... - - class Year(TemporalFunction): ... - - class Month(TemporalFunction): ... - - class Day(TemporalFunction): ... - - class Hour(TemporalFunction): ... - - class Minute(TemporalFunction): ... - - class Second(TemporalFunction): ... - - class Millisecond(TemporalFunction): ... - - class Microsecond(TemporalFunction): ... - - class Nanosecond(TemporalFunction): ... - - class OrdinalDay(TemporalFunction): ... - - class WeekDay(TemporalFunction): ... - - class TotalMinutes(TemporalFunction): ... - - class TotalSeconds(TemporalFunction): ... - - class TotalMilliseconds(TemporalFunction): ... - - class TotalMicroseconds(TemporalFunction): ... - - class TotalNanoseconds(TemporalFunction): ... - - +# fmt: on class ToString(TemporalFunction): __slots__ = ("format",) format: str @@ -94,7 +60,7 @@ class Timestamp(TemporalFunction): time_unit: PolarsTimeUnit @staticmethod - def from_time_unit(time_unit: TimeUnit, /) -> Timestamp: + def from_time_unit(time_unit: TimeUnit = "us", /) -> Timestamp: if not _is_polars_time_unit(time_unit): msg = f"invalid `time_unit` \n\nExpected one of ['ns', 'us', 'ms'], got {time_unit!r}." 
raise ValueError(msg) @@ -119,71 +85,28 @@ def from_interval(every: Interval, /) -> Truncate: class IRDateTimeNamespace(IRNamespace): - def date(self) -> Date: - return Date() - - def year(self) -> Year: - return Year() - - def month(self) -> Month: - return Month() - - def day(self) -> Day: - return Day() - - def hour(self) -> Hour: - return Hour() - - def minute(self) -> Minute: - return Minute() - - def second(self) -> Second: - return Second() - - def millisecond(self) -> Millisecond: - return Millisecond() - - def microsecond(self) -> Microsecond: - return Microsecond() - - def nanosecond(self) -> Nanosecond: - return Nanosecond() - - def ordinal_day(self) -> OrdinalDay: - return OrdinalDay() - - def weekday(self) -> WeekDay: - return WeekDay() - - def total_minutes(self) -> TotalMinutes: - return TotalMinutes() - - def total_seconds(self) -> TotalSeconds: - return TotalSeconds() - - def total_milliseconds(self) -> TotalMilliseconds: - return TotalMilliseconds() - - def total_microseconds(self) -> TotalMicroseconds: - return TotalMicroseconds() - - def total_nanoseconds(self) -> TotalNanoseconds: - return TotalNanoseconds() - - def to_string(self, format: str) -> ToString: - return ToString(format=format) - - def replace_time_zone(self, time_zone: str | None) -> ReplaceTimeZone: - return ReplaceTimeZone(time_zone=time_zone) - - def convert_time_zone(self, time_zone: str) -> ConvertTimeZone: - return ConvertTimeZone(time_zone=time_zone) - - def timestamp(self, time_unit: TimeUnit = "us") -> Timestamp: - return Timestamp.from_time_unit(time_unit) - - def truncate(self, every: str) -> Truncate: - return Truncate.from_string(every) + date: ClassVar = Date + year: ClassVar = Year + month: ClassVar = Month + day: ClassVar = Day + hour: ClassVar = Hour + minute: ClassVar = Minute + second: ClassVar = Second + millisecond: ClassVar = Millisecond + microsecond: ClassVar = Microsecond + nanosecond: ClassVar = Nanosecond + ordinal_day: ClassVar = OrdinalDay + weekday: ClassVar = WeekDay + total_minutes: ClassVar = TotalMinutes + total_seconds: ClassVar = TotalSeconds + total_milliseconds: ClassVar = TotalMilliseconds + total_microseconds: ClassVar = TotalMicroseconds + total_nanoseconds: ClassVar = TotalNanoseconds + to_string: ClassVar = ToString + replace_time_zone: ClassVar = ReplaceTimeZone + convert_time_zone: ClassVar = ConvertTimeZone + truncate: ClassVar = staticmethod(Truncate.from_string) + timestamp: ClassVar = staticmethod(Timestamp.from_time_unit) class ExprDateTimeNamespace(ExprNamespace[IRDateTimeNamespace]): @@ -252,7 +175,7 @@ def convert_time_zone(self, time_zone: str) -> Expr: return self._with_unary(self._ir.convert_time_zone(time_zone=time_zone)) def timestamp(self, time_unit: TimeUnit = "us") -> Expr: - return self._with_unary(self._ir.timestamp(time_unit=time_unit)) + return self._with_unary(self._ir.timestamp(time_unit)) def truncate(self, every: str) -> Expr: - return self._with_unary(self._ir.truncate(every=every)) + return self._with_unary(self._ir.truncate(every)) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 0b8b884b5f..251489c68d 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -47,10 +47,8 @@ RollingT = TypeVar("RollingT", bound="RollingWindow", default="RollingWindow") RangeT = TypeVar("RangeT", bound="RangeFunction", default="RangeFunction") LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") -LeftT2 = TypeVar("LeftT2", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", 
default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") -RightT2 = TypeVar("RightT2", bound="ExprIR", default="ExprIR") OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" ExprIRT = TypeVar("ExprIRT", bound="ExprIR", default="ExprIR") ExprIRT2 = TypeVar("ExprIRT2", bound="ExprIR", default="ExprIR") @@ -96,3 +94,4 @@ IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" +OneOrIterable: TypeAlias = "T | t.Iterable[T]" diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index d264f39733..62e0da3d2a 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import Immutable, is_expr +from narwhals._plan._guards import is_expr +from narwhals._plan._immutable import Immutable from narwhals._plan.dummy import Expr from narwhals._plan.expr_parsing import ( parse_into_expr_ir, @@ -10,11 +11,9 @@ ) if TYPE_CHECKING: - from collections.abc import Iterable - from narwhals._plan.common import ExprIR - from narwhals._plan.expr import Ternary - from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq + from narwhals._plan.expr import TernaryExpr + from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq class When(Immutable): @@ -39,7 +38,7 @@ class Then(Immutable, Expr): statement: ExprIR def when( - self, *predicates: IntoExprColumn | Iterable[IntoExprColumn], **constraints: Any + self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any ) -> ChainedWhen: condition = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return ChainedWhen( @@ -84,7 +83,7 @@ class ChainedThen(Immutable, Expr): statements: Seq[ExprIR] def when( - self, *predicates: IntoExprColumn | Iterable[IntoExprColumn], **constraints: Any + self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any ) -> ChainedWhen: condition = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return ChainedWhen( @@ -96,10 +95,8 @@ def otherwise(self, statement: IntoExpr, /) -> Expr: def _otherwise(self, statement: IntoExpr = None, /) -> ExprIR: otherwise = parse_into_expr_ir(statement) - it_conditions = reversed(self.conditions) - it_statements = reversed(self.statements) - for e in it_conditions: - otherwise = ternary_expr(e, next(it_statements), otherwise) + for cond, stmt in zip(reversed(self.conditions), reversed(self.statements)): + otherwise = ternary_expr(cond, stmt, otherwise) return otherwise @property @@ -116,7 +113,7 @@ def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] return super().__eq__(value) -def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> Ternary: - from narwhals._plan.expr import Ternary +def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: + from narwhals._plan.expr import TernaryExpr - return Ternary(predicate=predicate, truthy=truthy, falsy=falsy) + return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index f575d9d303..fd27743948 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, is_function_expr, is_window_expr +from narwhals._plan._guards import is_function_expr, is_window_expr +from narwhals._plan._immutable import Immutable from narwhals._plan.exceptions 
import ( over_elementwise_error, over_nested_error, diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 7d7f1f6248..dc548968a4 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -11,7 +11,7 @@ import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs -from narwhals._plan.common import is_expr +from narwhals._plan._guards import is_expr from narwhals._plan.dummy import DataFrame from narwhals._utils import Version from narwhals.exceptions import ComputeError diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 8e5dd0f29c..740d966818 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -6,7 +6,8 @@ import narwhals as nw from narwhals._plan import demo as nwd, expr_parsing as parse, selectors as ndcs -from narwhals._plan.common import ExprIR, NamedIR, is_expr +from narwhals._plan._guards import is_expr +from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.expr import WindowExpr from narwhals._plan.expr_rewrites import ( rewrite_all, diff --git a/tests/plan/immutable_test.py b/tests/plan/immutable_test.py index 3c5e97439e..6f9d0450ad 100644 --- a/tests/plan/immutable_test.py +++ b/tests/plan/immutable_test.py @@ -6,7 +6,7 @@ import pytest -from narwhals._plan.common import Immutable +from narwhals._plan._immutable import Immutable class Empty(Immutable): ... diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 6b818f82df..4eaf98db9f 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIR, NamedIR, is_expr +from narwhals._plan._guards import is_expr +from narwhals._plan.common import ExprIR, NamedIR if TYPE_CHECKING: from typing_extensions import LiteralString From 6bc5ff5c5949df164512670278d7f6d229b1a38a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:25:08 +0000 Subject: [PATCH 357/368] fix(typing): Make pyright mostly happy Why is mypy like this? --- narwhals/_plan/expr.py | 25 ++++++++++++++++--------- narwhals/_plan/typing.py | 9 +++++++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 7d35237f16..ecf4efe136 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -12,15 +12,15 @@ from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( - FunctionT, + FunctionT_co, LeftSelectorT, LeftT, LiteralT, OperatorT, - RangeT, + RangeT_co, RightSelectorT, RightT, - RollingT, + RollingT_co, SelectorOperatorT, SelectorT, Seq, @@ -258,7 +258,8 @@ def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() -class FunctionExpr(ExprIR, t.Generic[FunctionT], child=("input",)): +# mypy: disable-error-code="misc" +class FunctionExpr(ExprIR, t.Generic[FunctionT_co], child=("input",)): """**Representing `Expr::Function`**. 
https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 @@ -267,7 +268,13 @@ class FunctionExpr(ExprIR, t.Generic[FunctionT], child=("input",)): __slots__ = ("function", "input", "options") input: Seq[ExprIR] - function: FunctionT + # NOTE: mypy being mypy - the top error can't be silenced 🤦‍♂️ + # narwhals/_plan/expr.py: error: Cannot use a covariant type variable as a parameter [misc] + # narwhals/_plan/expr.py:272:15: error: Cannot use a covariant type variable as a parameter [misc] + # function: FunctionT_co # noqa: ERA001 + # ^ + # Found 2 errors in 1 file (checked 476 source files) + function: FunctionT_co """Operation applied to each element of `input`.""" options: FunctionOptions @@ -304,7 +311,7 @@ def __init__( self, *, input: Seq[ExprIR], # noqa: A002 - function: FunctionT, + function: FunctionT_co, options: FunctionOptions, **kwds: t.Any, ) -> None: @@ -319,7 +326,7 @@ def dispatch( return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] -class RollingExpr(FunctionExpr[RollingT]): ... +class RollingExpr(FunctionExpr[RollingT_co]): ... class AnonymousExpr( @@ -333,7 +340,7 @@ def dispatch( return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] -class RangeExpr(FunctionExpr[RangeT]): +class RangeExpr(FunctionExpr[RangeT_co]): """E.g. `int_range(...)`. Special-cased as it is only allowed scalar inputs, and is row_separable. @@ -343,7 +350,7 @@ def __init__( self, *, input: Seq[ExprIR], # noqa: A002 - function: RangeT, + function: RangeT_co, options: FunctionOptions, **kwds: t.Any, ) -> None: diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 251489c68d..b7e0736e15 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -46,6 +46,15 @@ FunctionT = TypeVar("FunctionT", bound="Function", default="Function") RollingT = TypeVar("RollingT", bound="RollingWindow", default="RollingWindow") RangeT = TypeVar("RangeT", bound="RangeFunction", default="RangeFunction") +FunctionT_co = TypeVar( + "FunctionT_co", bound="Function", default="Function", covariant=True +) +RollingT_co = TypeVar( + "RollingT_co", bound="RollingWindow", default="RollingWindow", covariant=True +) +RangeT_co = TypeVar( + "RangeT_co", bound="RangeFunction", default="RangeFunction", covariant=True +) LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") From 2aa8a2e24cb77d1ebb99ff860a55b839422ab98d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:28:58 +0000 Subject: [PATCH 358/368] chore(typing): Ignore the other one for now It is a valid complaint - but the interaction with `@dataclass_transform` is an issue --- narwhals/_plan/literal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 94f5a9a5b4..22c170508c 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -39,7 +39,7 @@ def unwrap(self) -> LiteralT: class ScalarLiteral(LiteralValue[NonNestedLiteralT]): - __slots__ = ("dtype", "value") + __slots__ = ("dtype", "value") # pyright: ignore[reportIncompatibleMethodOverride] value: NonNestedLiteralT dtype: DType From 967a1507234ca4636e37951916b5c6b38ad7db81 Mon Sep 17 00:00:00 2001 From: Dan Redding 
<125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Sep 2025 20:36:03 +0000 Subject: [PATCH 359/368] refactor(expr-ir): Organize `_plan` package (#3122) --- .pre-commit-config.yaml | 4 +- narwhals/_plan/__init__.py | 55 + .../{expr_expansion.py => _expansion.py} | 6 +- narwhals/_plan/_expr_ir.py | 292 ++++++ narwhals/_plan/_function.py | 83 ++ narwhals/_plan/_guards.py | 47 +- narwhals/_plan/{expr_parsing.py => _parse.py} | 20 +- .../_plan/{expr_rewrites.py => _rewrites.py} | 28 +- narwhals/_plan/arrow/dataframe.py | 6 +- narwhals/_plan/arrow/expr.py | 53 +- narwhals/_plan/arrow/functions.py | 2 +- narwhals/_plan/arrow/namespace.py | 17 +- narwhals/_plan/common.py | 432 +------- narwhals/_plan/dataframe.py | 140 +++ narwhals/_plan/demo.py | 161 --- narwhals/_plan/dummy.py | 887 ---------------- narwhals/_plan/exceptions.py | 41 +- narwhals/_plan/expr.py | 974 +++++++++++------- narwhals/_plan/expressions/__init__.py | 94 ++ .../_plan/{ => expressions}/aggregation.py | 5 +- narwhals/_plan/{ => expressions}/boolean.py | 10 +- .../_plan/{ => expressions}/categorical.py | 5 +- narwhals/_plan/expressions/expr.py | 505 +++++++++ narwhals/_plan/expressions/functions.py | 182 ++++ narwhals/_plan/{ => expressions}/lists.py | 5 +- narwhals/_plan/{ => expressions}/literal.py | 6 +- narwhals/_plan/{ => expressions}/name.py | 17 +- narwhals/_plan/expressions/namespace.py | 41 + narwhals/_plan/{ => expressions}/operators.py | 7 +- narwhals/_plan/{ => expressions}/ranges.py | 6 +- narwhals/_plan/{ => expressions}/selectors.py | 24 +- narwhals/_plan/{ => expressions}/strings.py | 5 +- narwhals/_plan/{ => expressions}/struct.py | 5 +- narwhals/_plan/{ => expressions}/temporal.py | 5 +- narwhals/_plan/expressions/window.py | 67 ++ narwhals/_plan/functions.py | 251 +++-- narwhals/_plan/meta.py | 113 +- narwhals/_plan/protocols.py | 70 +- narwhals/_plan/schema.py | 2 +- narwhals/_plan/series.py | 67 ++ narwhals/_plan/typing.py | 13 +- narwhals/_plan/when_then.py | 21 +- narwhals/_plan/window.py | 70 -- pyproject.toml | 2 +- tests/plan/compliant_test.py | 253 ++--- tests/plan/expr_expansion_test.py | 314 +++--- tests/plan/expr_parsing_test.py | 267 +++-- tests/plan/expr_rewrites_test.py | 84 +- tests/plan/meta_test.py | 73 +- tests/plan/utils.py | 22 +- 50 files changed, 3024 insertions(+), 2835 deletions(-) rename narwhals/_plan/{expr_expansion.py => _expansion.py} (99%) create mode 100644 narwhals/_plan/_expr_ir.py create mode 100644 narwhals/_plan/_function.py rename narwhals/_plan/{expr_parsing.py => _parse.py} (93%) rename narwhals/_plan/{expr_rewrites.py => _rewrites.py} (72%) create mode 100644 narwhals/_plan/dataframe.py delete mode 100644 narwhals/_plan/demo.py delete mode 100644 narwhals/_plan/dummy.py create mode 100644 narwhals/_plan/expressions/__init__.py rename narwhals/_plan/{ => expressions}/aggregation.py (89%) rename narwhals/_plan/{ => expressions}/boolean.py (90%) rename narwhals/_plan/{ => expressions}/categorical.py (77%) create mode 100644 narwhals/_plan/expressions/expr.py create mode 100644 narwhals/_plan/expressions/functions.py rename narwhals/_plan/{ => expressions}/lists.py (78%) rename narwhals/_plan/{ => expressions}/literal.py (92%) rename narwhals/_plan/{ => expressions}/name.py (85%) create mode 100644 narwhals/_plan/expressions/namespace.py rename narwhals/_plan/{ => expressions}/operators.py (93%) rename narwhals/_plan/{ => expressions}/ranges.py (82%) rename narwhals/_plan/{ => expressions}/selectors.py (90%) rename narwhals/_plan/{ => expressions}/strings.py 
(96%) rename narwhals/_plan/{ => expressions}/struct.py (83%) rename narwhals/_plan/{ => expressions}/temporal.py (97%) create mode 100644 narwhals/_plan/expressions/window.py create mode 100644 narwhals/_plan/series.py delete mode 100644 narwhals/_plan/window.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f9f105658..2ea270264b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,8 +77,8 @@ repos: narwhals/stable/v./_?dtypes.py| narwhals/.*__init__.py| narwhals/.*typing\.py| - narwhals/_plan/demo\.py| - narwhals/_plan/ranges\.py| + narwhals/_plan/functions\.py| + narwhals/_plan/expressions/ranges\.py| narwhals/_plan/schema\.py ) - id: pull-request-target diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index 9d48db4f9f..afeff442c0 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -1 +1,56 @@ from __future__ import annotations + +from narwhals._plan.dataframe import DataFrame +from narwhals._plan.expr import Expr, Selector +from narwhals._plan.expressions import selectors +from narwhals._plan.functions import ( + all, + all_horizontal, + any_horizontal, + col, + concat_str, + exclude, + int_range, + len, + lit, + max, + max_horizontal, + mean, + mean_horizontal, + median, + min, + min_horizontal, + nth, + sum, + sum_horizontal, + when, +) +from narwhals._plan.series import Series + +__all__ = [ + "DataFrame", + "Expr", + "Selector", + "Series", + "all", + "all_horizontal", + "any_horizontal", + "col", + "concat_str", + "exclude", + "int_range", + "len", + "lit", + "max", + "max_horizontal", + "mean", + "mean_horizontal", + "median", + "min", + "min_horizontal", + "nth", + "selectors", + "sum", + "sum_horizontal", + "when", +] diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/_expansion.py similarity index 99% rename from narwhals/_plan/expr_expansion.py rename to narwhals/_plan/_expansion.py index 2fdf92ef7d..fb2dd390a8 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -45,21 +45,23 @@ from narwhals._plan import common, meta from narwhals._plan._guards import is_horizontal_reduction from narwhals._plan._immutable import Immutable -from narwhals._plan.common import ExprIR, NamedIR, SelectorIR from narwhals._plan.exceptions import ( column_index_error, column_not_found_error, duplicate_error, ) -from narwhals._plan.expr import ( +from narwhals._plan.expressions import ( Alias, All, Columns, Exclude, + ExprIR, IndexColumns, KeepName, + NamedIR, Nth, RenameAlias, + SelectorIR, _ColumnSelection, col, cols, diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py new file mode 100644 index 0000000000..0646520102 --- /dev/null +++ b/narwhals/_plan/_expr_ir.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, cast + +from narwhals._plan._guards import is_function_expr, is_literal +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import dispatch_getter, replace +from narwhals._plan.options import ExprIROptions +from narwhals._plan.typing import ExprIRT +from narwhals.utils import Version + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing import Any, ClassVar + + from typing_extensions import Self, TypeAlias + + from narwhals._plan.expr import Expr, Selector + from narwhals._plan.expressions.expr import Alias, Cast, Column + from narwhals._plan.meta import MetaNamespace + from narwhals._plan.protocols import Ctx, FrameT_contra, R_co + from 
narwhals._plan.typing import ExprIRT2, MapIR, Seq + from narwhals.dtypes import DType + + Incomplete: TypeAlias = "Any" + + +def _dispatch_generate( + tp: type[ExprIRT], / +) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: + if not tp.__expr_ir_config__.allow_dispatch: + + def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: + msg = ( + f"{tp.__name__!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" + ) + raise TypeError(msg) + + return _ + getter = dispatch_getter(tp) + + def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ + + +class ExprIR(Immutable): + """Anything that can be a node on a graph of expressions.""" + + _child: ClassVar[Seq[str]] = () + """Nested node names, in iteration order.""" + + __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] + ] + + def __init_subclass__( + cls: type[Self], + *args: Any, + child: Seq[str] = (), + config: ExprIROptions | None = None, + **kwds: Any, + ) -> None: + super().__init_subclass__(*args, **kwds) + if child: + cls._child = child + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) + + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / + ) -> R_co: + """Evaluate expression in `frame`, using `ctx` for implementation(s).""" + return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] + + def to_narwhals(self, version: Version = Version.MAIN) -> Expr: + from narwhals._plan import expr + + tp = expr.Expr if version is Version.MAIN else expr.ExprV1 + return tp._from_ir(self) + + @property + def is_scalar(self) -> bool: + return False + + def map_ir(self, function: MapIR, /) -> ExprIR: + """Apply `function` to each child node, returning a new `ExprIR`. + + See [`polars_plan::plans::iterator::Expr.map_expr`] and [`polars_plan::plans::visitor::visitors`]. + + [`polars_plan::plans::iterator::Expr.map_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/iterator.rs#L152-L159 + [`polars_plan::plans::visitor::visitors`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/visitor/visitors.rs + """ + if not self._child: + return function(self) + children = ((name, getattr(self, name)) for name in self._child) + changed = {name: _map_ir_child(child, function) for name, child in children} + return function(replace(self, **changed)) + + def iter_left(self) -> Iterator[ExprIR]: + """Yield nodes root->leaf. 
+ + Examples: + >>> from narwhals import _plan as nw + >>> + >>> a = nw.col("a") + >>> b = a.alias("b") + >>> c = b.min().alias("c") + >>> d = c.over(nw.col("e"), nw.col("f")) + >>> + >>> list(a._ir.iter_left()) + [col('a')] + >>> + >>> list(b._ir.iter_left()) + [col('a'), col('a').alias('b')] + >>> + >>> list(c._ir.iter_left()) + [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c')] + >>> + >>> list(d._ir.iter_left()) + [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c'), col('e'), col('f'), col('a').alias('b').min().alias('c').over([col('e'), col('f')])] + """ + for name in self._child: + child: ExprIR | Seq[ExprIR] = getattr(self, name) + if isinstance(child, ExprIR): + yield from child.iter_left() + else: + for node in child: + yield from node.iter_left() + yield self + + def iter_right(self) -> Iterator[ExprIR]: + """Yield nodes leaf->root. + + Note: + Identical to `iter_left` for root nodes. + + Examples: + >>> from narwhals import _plan as nw + >>> + >>> a = nw.col("a") + >>> b = a.alias("b") + >>> c = b.min().alias("c") + >>> d = c.over(nw.col("e"), nw.col("f")) + >>> + >>> list(a._ir.iter_right()) + [col('a')] + >>> + >>> list(b._ir.iter_right()) + [col('a').alias('b'), col('a')] + >>> + >>> list(c._ir.iter_right()) + [col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] + >>> + >>> list(d._ir.iter_right()) + [col('a').alias('b').min().alias('c').over([col('e'), col('f')]), col('f'), col('e'), col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] + """ + yield self + for name in reversed(self._child): + child: ExprIR | Seq[ExprIR] = getattr(self, name) + if isinstance(child, ExprIR): + yield from child.iter_right() + else: + for node in reversed(child): + yield from node.iter_right() + + def iter_root_names(self) -> Iterator[ExprIR]: + """Override for different iteration behavior in `ExprIR.meta.root_names`. + + Note: + Identical to `iter_left` by default. + """ + yield from self.iter_left() + + def iter_output_name(self) -> Iterator[ExprIR]: + """Override for different iteration behavior in `ExprIR.meta.output_name`. + + Note: + Identical to `iter_right` by default. + """ + yield from self.iter_right() + + @property + def meta(self) -> MetaNamespace: + from narwhals._plan.meta import MetaNamespace + + return MetaNamespace(_ir=self) + + def cast(self, dtype: DType) -> Cast: + from narwhals._plan.expressions.expr import Cast + + return Cast(expr=self, dtype=dtype) + + def alias(self, name: str) -> Alias: + from narwhals._plan.expressions.expr import Alias + + return Alias(expr=self, name=name) + + def _repr_html_(self) -> str: + return self.__repr__() + + +def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: + return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) + + +class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): + def to_narwhals(self, version: Version = Version.MAIN) -> Selector: + from narwhals._plan import expr + + if version is Version.MAIN: + return expr.Selector._from_ir(self) + return expr.SelectorV1._from_ir(self) + + def matches_column(self, name: str, dtype: DType) -> bool: + """Return True if we can select this column. + + - Thinking that we could get more cache hits on an individual column basis. 
+ - May also be more efficient to not iterate over the schema for every selector + - Instead do one pass, evaluating every selector against a single column at a time + """ + raise NotImplementedError(type(self)) + + +class NamedIR(Immutable, Generic[ExprIRT]): + """Post-projection expansion wrapper for `ExprIR`. + + - Somewhat similar to [`polars_plan::plans::expr_ir::ExprIR`]. + - The [`polars_plan::plans::aexpr::AExpr`] stage has been skipped (*for now*) + - Parts of that will probably be in here too + - `AExpr` seems like too much duplication when we won't get the memory allocation benefits in python + + [`polars_plan::plans::expr_ir::ExprIR`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/expr_ir.rs#L63-L74 + [`polars_plan::plans::aexpr::AExpr`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/mod.rs#L145-L231 + """ + + __slots__ = ("expr", "name") + expr: ExprIRT + name: str + + @staticmethod + def from_name(name: str, /) -> NamedIR[Column]: + """Construct as a simple, unaliased `col(name)` expression. + + Intended to be used in `with_columns` from a `FrozenSchema`'s keys. + """ + from narwhals._plan.expressions.expr import col + + return NamedIR(expr=col(name), name=name) + + @staticmethod + def from_ir(expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: + """Construct from an already expanded `ExprIR`. + + Should be cheap to get the output name from cache, but will raise if used + without care. + """ + return NamedIR(expr=expr, name=expr.meta.output_name(raise_if_undetermined=True)) + + def map_ir(self, function: MapIR, /) -> Self: + """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" + return replace(self, expr=function(self.expr.map_ir(function))) + + def __repr__(self) -> str: + return f"{self.name}={self.expr!r}" + + def _repr_html_(self) -> str: + return f"{self.name}={self.expr._repr_html_()}" + + def is_elementwise_top_level(self) -> bool: + """Return True if the outermost node is elementwise. 
+ + Based on [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`] + + This check: + - Is not recursive + - Is not valid on `ExprIR` *prior* to being expanded + + [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/properties.rs#L16-L44 + """ + from narwhals._plan.expressions import expr + + ir = self.expr + if is_function_expr(ir): + return ir.options.is_elementwise() + if is_literal(ir): + return ir.is_scalar + return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py new file mode 100644 index 0000000000..332dbfc085 --- /dev/null +++ b/narwhals/_plan/_function.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import dispatch_getter, dispatch_method_name, replace +from narwhals._plan.options import FEOptions, FunctionOptions + +if TYPE_CHECKING: + from typing import Any, Callable, ClassVar + + from typing_extensions import Self, TypeAlias + + from narwhals._plan.expressions import ExprIR, FunctionExpr + from narwhals._plan.typing import Accessor, FunctionT + +__all__ = ["Function", "HorizontalFunction"] + +Incomplete: TypeAlias = "Any" + + +def _dispatch_generate_function( + tp: type[FunctionT], / +) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: + getter = dispatch_getter(tp) + + def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ + + +class Function(Immutable): + """Shared by expr functions and namespace functions. + + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 + """ + + _function_options: ClassVar[staticmethod[[], FunctionOptions]] = staticmethod( + FunctionOptions.default + ) + __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] + ] + + @property + def function_options(self) -> FunctionOptions: + return self._function_options() + + @property + def is_scalar(self) -> bool: + return self.function_options.returns_scalar() + + def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: + from narwhals._plan.expressions.expr import FunctionExpr + + return FunctionExpr(input=inputs, function=self, options=self.function_options) + + def __init_subclass__( + cls: type[Self], + *args: Any, + accessor: Accessor | None = None, + options: Callable[[], FunctionOptions] | None = None, + config: FEOptions | None = None, + **kwds: Any, + ) -> None: + super().__init_subclass__(*args, **kwds) + if accessor: + config = replace(config or FEOptions.default(), accessor_name=accessor) + if options: + cls._function_options = staticmethod(options) + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) + + def __repr__(self) -> str: + return dispatch_method_name(type(self)) + + +class HorizontalFunction( + Function, options=FunctionOptions.horizontal, config=FEOptions.namespaced() +): ... 
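# --- Illustrative sketch (not part of the patch above) ---------------------
# `Function.__init_subclass__` wires each subclass into dispatch by deriving a
# method name from the class name (PascalCase -> snake_case) and, optionally,
# prefixing an accessor namespace passed via the `accessor=` class keyword.
# The toy below reproduces only that naming/registration pattern under
# hypothetical names (`ToyFunction`, `StartsWith`); it is a minimal sketch,
# not the narwhals implementation.
from __future__ import annotations

import re
from typing import Any, ClassVar


def pascal_to_snake(name: str) -> str:
    # "StartsWith" -> "starts_with", "AllHorizontal" -> "all_horizontal"
    return re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name).lower()


class ToyFunction:
    _dispatch_name: ClassVar[str]

    def __init_subclass__(cls, *, accessor: str | None = None, **kwds: Any) -> None:
        # Class-definition-time hook: subclasses configure themselves with
        # keyword arguments, mirroring how `Function` accepts `accessor=`,
        # `options=`, and `config=` in the file above.
        super().__init_subclass__(**kwds)
        name = pascal_to_snake(cls.__name__)
        cls._dispatch_name = f"{accessor}.{name}" if accessor else name


class StartsWith(ToyFunction, accessor="str"): ...


assert StartsWith._dispatch_name == "str.starts_with"
# ---------------------------------------------------------------------------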
diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 867d16d397..0f62942ab8 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -11,9 +11,10 @@ if TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan import expr - from narwhals._plan.dummy import Expr, Series + from narwhals._plan import expressions as ir + from narwhals._plan.expr import Expr from narwhals._plan.protocols import CompliantSeries + from narwhals._plan.series import Series from narwhals._plan.typing import NativeSeriesT, Seq from narwhals.typing import NonNestedLiteral @@ -31,10 +32,10 @@ ) -def _dummy(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 - from narwhals._plan import dummy +def _ir(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import expressions as ir - return dummy + return ir def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 @@ -43,12 +44,18 @@ def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 return expr +def _series(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import series + + return series + + def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) def is_expr(obj: Any) -> TypeIs[Expr]: - return isinstance(obj, _dummy().Expr) + return isinstance(obj, _expr().Expr) def is_column(obj: Any) -> TypeIs[Expr]: @@ -57,7 +64,7 @@ def is_column(obj: Any) -> TypeIs[Expr]: def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: - return isinstance(obj, _dummy().Series) + return isinstance(obj, _series().Series) def is_compliant_series( @@ -67,35 +74,35 @@ def is_compliant_series( def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]: - return isinstance(obj, (str, bytes, _dummy().Series)) or is_compliant_series(obj) + return isinstance(obj, (str, bytes, _series().Series)) or is_compliant_series(obj) -def is_window_expr(obj: Any) -> TypeIs[expr.WindowExpr]: - return isinstance(obj, _expr().WindowExpr) +def is_window_expr(obj: Any) -> TypeIs[ir.WindowExpr]: + return isinstance(obj, _ir().WindowExpr) -def is_function_expr(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]: - return isinstance(obj, _expr().FunctionExpr) +def is_function_expr(obj: Any) -> TypeIs[ir.FunctionExpr[Any]]: + return isinstance(obj, _ir().FunctionExpr) -def is_binary_expr(obj: Any) -> TypeIs[expr.BinaryExpr]: - return isinstance(obj, _expr().BinaryExpr) +def is_binary_expr(obj: Any) -> TypeIs[ir.BinaryExpr]: + return isinstance(obj, _ir().BinaryExpr) -def is_agg_expr(obj: Any) -> TypeIs[expr.AggExpr]: - return isinstance(obj, _expr().AggExpr) +def is_agg_expr(obj: Any) -> TypeIs[ir.AggExpr]: + return isinstance(obj, _ir().AggExpr) -def is_aggregation(obj: Any) -> TypeIs[expr.AggExpr | expr.FunctionExpr[Any]]: +def is_aggregation(obj: Any) -> TypeIs[ir.AggExpr | ir.FunctionExpr[Any]]: """Superset of `ExprIR.is_scalar`, excludes literals & len.""" return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar) -def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]: - return isinstance(obj, _expr().Literal) +def is_literal(obj: Any) -> TypeIs[ir.Literal[Any]]: + return isinstance(obj, _ir().Literal) -def is_horizontal_reduction(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]: +def is_horizontal_reduction(obj: Any) -> TypeIs[ir.FunctionExpr[Any]]: return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() diff --git a/narwhals/_plan/expr_parsing.py 
b/narwhals/_plan/_parse.py similarity index 93% rename from narwhals/_plan/expr_parsing.py rename to narwhals/_plan/_parse.py index 1e450f2307..651166ebee 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/_parse.py @@ -4,7 +4,7 @@ # ruff: noqa: A002 from itertools import chain -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING from narwhals._plan._guards import is_expr, is_iterable_reject from narwhals._plan.exceptions import ( @@ -16,16 +16,16 @@ if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any + from typing import Any, TypeVar import polars as pl from typing_extensions import TypeAlias, TypeIs - from narwhals._plan.common import ExprIR + from narwhals._plan.expressions import ExprIR from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq from narwhals.typing import IntoDType -T = TypeVar("T") + T = TypeVar("T") _RaisesInvalidIntoExprError: TypeAlias = "Any" """ @@ -88,14 +88,14 @@ def parse_into_expr_ir( input: IntoExpr, *, str_as_lit: bool = False, dtype: IntoDType | None = None ) -> ExprIR: """Parse a single input into an `ExprIR` node.""" - from narwhals._plan import demo as nwd + from narwhals._plan import col, lit if is_expr(input): expr = input elif isinstance(input, str) and not str_as_lit: - expr = nwd.col(input) + expr = col(input) else: - expr = nwd.lit(input, dtype=dtype) + expr = lit(input, dtype=dtype) return expr._ir @@ -157,14 +157,14 @@ def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR def _parse_constraints(constraints: dict[str, IntoExpr], /) -> Iterator[ExprIR]: - from narwhals._plan import demo as nwd + from narwhals._plan import col for name, value in constraints.items(): - yield (nwd.col(name) == value)._ir + yield (col(name) == value)._ir def _combine_predicates(predicates: Iterator[ExprIR], /) -> ExprIR: - from narwhals._plan.boolean import AllHorizontal + from narwhals._plan.expressions.boolean import AllHorizontal first = next(predicates, None) if not first: diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/_rewrites.py similarity index 72% rename from narwhals/_plan/expr_rewrites.py rename to narwhals/_plan/_rewrites.py index 597e8afc21..ae23fa4b9b 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -1,25 +1,25 @@ -"""Post-`expr_expansion` rewrites, in a similar style.""" +"""Post-`_expansion` rewrites, in a similar style.""" from __future__ import annotations from typing import TYPE_CHECKING -from narwhals._plan import expr_parsing as parse +from narwhals._plan._expansion import into_named_irs, prepare_projection from narwhals._plan._guards import ( is_aggregation, is_binary_expr, is_function_expr, is_window_expr, ) -from narwhals._plan.common import NamedIR, map_ir, replace -from narwhals._plan.expr_expansion import into_named_irs, prepare_projection +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.common import replace if TYPE_CHECKING: from collections.abc import Sequence - from narwhals._plan.common import ExprIR + from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.schema import IntoFrozenSchema - from narwhals._plan.typing import IntoExpr, MapIR, Seq + from narwhals._plan.typing import IntoExpr, MapIR, NamedOrExprIRT, Seq def rewrite_all( @@ -31,9 +31,7 @@ def rewrite_all( - Currently we do a full traversal of each tree per-rewrite function - There's no caching *after* `prepare_projection` yet """ - out_irs, _, names = 
prepare_projection( - parse.parse_into_seq_of_expr_ir(*exprs), schema - ) + out_irs, _, names = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema) named_irs = into_named_irs(out_irs, names) return tuple(map_ir(ir, *rewrites) for ir in named_irs) @@ -85,3 +83,15 @@ def rewrite_binary_agg_over(window: ExprIR, /) -> ExprIR: binary_expr = window.expr return replace(binary_expr, right=replace(window, expr=binary_expr.right)) return window + + +def map_ir( + origin: NamedOrExprIRT, function: MapIR, *more_functions: MapIR +) -> NamedOrExprIRT: + """Apply one or more functions, sequentially, to all of `origin`'s children.""" + if more_functions: + result = origin + for fn in (function, *more_functions): + result = result.map_ir(fn) + return result + return origin.map_ir(function) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index fc61e69acc..27a02bc2ed 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -20,8 +20,8 @@ from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.dummy import DataFrame as NwDataFrame + from narwhals._plan.dataframe import DataFrame as NwDataFrame + from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import Seq from narwhals.dtypes import DType @@ -50,7 +50,7 @@ def __len__(self) -> int: return self.native.num_rows def to_narwhals(self) -> NwDataFrame[pa.Table, ChunkedArrayAny]: - from narwhals._plan.dummy import DataFrame + from narwhals._plan.dataframe import DataFrame return DataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index d8a163d120..57ec5196d6 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -9,7 +9,7 @@ from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co -from narwhals._plan.common import ExprIR, NamedIR +from narwhals._plan.expressions import NamedIR from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace from narwhals._utils import ( Implementation, @@ -26,8 +26,10 @@ from typing_extensions import Self, TypeAlias from narwhals._arrow.typing import ChunkedArrayAny, Incomplete - from narwhals._plan import boolean, expr - from narwhals._plan.aggregation import ( + from narwhals._plan import expressions as ir + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame + from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, Count, @@ -43,19 +45,16 @@ Sum, Var, ) - from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame - from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.boolean import All, IsBetween, IsFinite, IsNan, IsNull, Not - from narwhals._plan.expr import ( - AnonymousExpr, - BinaryExpr, - FunctionExpr, - OrderedWindowExpr, - RollingExpr, - TernaryExpr, - WindowExpr, + from narwhals._plan.expressions.boolean import ( + All, + IsBetween, + IsFinite, + IsNan, + IsNull, + Not, ) - from narwhals._plan.functions import FillNull, Pow + from narwhals._plan.expressions.expr import 
BinaryExpr, FunctionExpr + from narwhals._plan.expressions.functions import FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" @@ -74,7 +73,7 @@ def __narwhals_namespace__(self) -> ArrowNamespace: return ArrowNamespace(self.version) def _with_native(self, native: Any, name: str, /) -> StoresNativeT_co: ... - def cast(self, node: expr.Cast, frame: Frame, name: str) -> StoresNativeT_co: + def cast(self, node: ir.Cast, frame: Frame, name: str) -> StoresNativeT_co: data_type = narwhals_to_native_dtype(node.dtype, frame.version) native = node.expr.dispatch(self, frame, name).native return self._with_native(fn.cast(native, data_type), name) @@ -119,7 +118,7 @@ def all(self, node: FunctionExpr[All], frame: Frame, name: str) -> StoresNativeT return self._unary_function(fn.all_)(node, frame, name) def any( - self, node: FunctionExpr[boolean.Any], frame: Frame, name: str + self, node: FunctionExpr[ir.boolean.Any], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.any_)(node, frame, name) @@ -147,7 +146,7 @@ def binary_expr(self, node: BinaryExpr, frame: Frame, name: str) -> StoresNative return self._with_native(result, name) def ternary_expr( - self, node: TernaryExpr, frame: Frame, name: str + self, node: ir.TernaryExpr, frame: Frame, name: str ) -> StoresNativeT_co: when = node.predicate.dispatch(self, frame, name) then = node.truthy.dispatch(self, frame, name) @@ -192,7 +191,7 @@ def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Sel return ArrowScalar.from_native(result, name, version=self.version) return self.from_native(result, name or self.name, self.version) - def _dispatch_expr(self, node: ExprIR, frame: Frame, name: str) -> Series: + def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series: """Use instead of `_dispatch` *iff* an operation isn't natively supported on `ChunkedArray`. There is no need to broadcast, as they may have a cheaper impl elsewhere (`CompliantScalar` or `ArrowScalar`). 
@@ -218,12 +217,12 @@ def broadcast(self, length: int, /) -> Series: def __len__(self) -> int: return len(self._evaluated) - def sort(self, node: expr.Sort, frame: Frame, name: str) -> Expr: + def sort(self, node: ir.Sort, frame: Frame, name: str) -> Expr: native = self._dispatch_expr(node.expr, frame, name).native sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) return self._with_native(native.take(sorted_indices), name) - def sort_by(self, node: expr.SortBy, frame: Frame, name: str) -> Expr: + def sort_by(self, node: ir.SortBy, frame: Frame, name: str) -> Expr: series = self._dispatch_expr(node.expr, frame, name) by = ( self._dispatch_expr(e, frame, f"_{idx}") @@ -235,7 +234,7 @@ def sort_by(self, node: expr.SortBy, frame: Frame, name: str) -> Expr: result: ChunkedArrayAny = df.native.column(0).take(indices) return self._with_native(result, name) - def filter(self, node: expr.Filter, frame: Frame, name: str) -> Expr: + def filter(self, node: ir.Filter, frame: Frame, name: str) -> Expr: return self._with_native( self._dispatch_expr(node.expr, frame, name).native.filter( self._dispatch_expr(node.by, frame, name).native @@ -319,11 +318,11 @@ def min(self, node: Min, frame: Frame, name: str) -> Scalar: # - [x] `map_batches` is defined in `EagerExpr`, might be simpler here than on main # - [ ] `rolling_expr` has 4 variants - def over(self, node: WindowExpr, frame: Frame, name: str) -> Self: + def over(self, node: ir.WindowExpr, frame: Frame, name: str) -> Self: raise NotImplementedError def over_ordered( - self, node: OrderedWindowExpr, frame: Frame, name: str + self, node: ir.OrderedWindowExpr, frame: Frame, name: str ) -> Self | Scalar: if node.partition_by: msg = f"Need to implement `group_by`, `join` for:\n{node!r}" @@ -347,7 +346,7 @@ def over_ordered( return self._with_native(result, name) # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` - def map_batches(self, node: AnonymousExpr, frame: Frame, name: str) -> Self: + def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: if node.is_scalar: # NOTE: Just trying to avoid redoing the whole API for `Series` msg = "Only elementwise is currently supported" @@ -361,7 +360,7 @@ def map_batches(self, node: AnonymousExpr, frame: Frame, name: str) -> Self: result = result.cast(dtype) return self.from_series(result) - def rolling_expr(self, node: RollingExpr, frame: Frame, name: str) -> Self: + def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self: raise NotImplementedError @@ -414,7 +413,7 @@ def from_series(cls, series: Series) -> Self: msg = f"Too long {len(series)!r}" raise InvalidOperationError(msg) - def _dispatch_expr(self, node: ExprIR, frame: Frame, name: str) -> Series: + def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series: msg = f"Expected unreachable, but hit at: {node!r}" raise InvalidOperationError(msg) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 83ecafae5f..7a16404d3d 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -13,7 +13,7 @@ chunked_array as _chunked_array, floordiv_compat as floordiv, ) -from narwhals._plan import operators as ops +from narwhals._plan.expressions import operators as ops from narwhals._utils import Implementation if TYPE_CHECKING: diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index f7bfaaa330..e4f68f27db 100644 --- a/narwhals/_plan/arrow/namespace.py +++ 
b/narwhals/_plan/arrow/namespace.py @@ -9,8 +9,7 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan._guards import is_tuple_of from narwhals._plan.arrow import functions as fn -from narwhals._plan.common import collect -from narwhals._plan.literal import is_literal_scalar +from narwhals._plan.expressions.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version from narwhals.exceptions import InvalidOperationError @@ -19,15 +18,15 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from narwhals._arrow.typing import ChunkedArrayAny - from narwhals._plan import expr, functions as F from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.series import ArrowSeries as Series - from narwhals._plan.boolean import AllHorizontal, AnyHorizontal - from narwhals._plan.dummy import Series as NwSeries - from narwhals._plan.expr import FunctionExpr, RangeExpr - from narwhals._plan.ranges import IntRange - from narwhals._plan.strings import ConcatStr + from narwhals._plan.expressions import expr, functions as F + from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal + from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr + from narwhals._plan.expressions.ranges import IntRange + from narwhals._plan.expressions.strings import ConcatStr + from narwhals._plan.series import Series as NwSeries from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral @@ -225,7 +224,7 @@ def gen(objs: Iterable[Frame | Series]) -> Iterator[tuple[ChunkedArrayAny, str]] return self._dataframe.from_native(native, self.version) def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series: - collected = collect(items) + collected = items if isinstance(items, tuple) else tuple(items) if is_tuple_of(collected, self._series): sers = collected chunked = fn.concat_vertical_chunked(ser.native for ser in sers) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index f73f21b26b..0b4267f214 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -6,39 +6,27 @@ from collections.abc import Iterable from decimal import Decimal from operator import attrgetter -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, overload +from typing import TYPE_CHECKING, cast, overload -from narwhals._plan._guards import is_function_expr, is_iterable_reject, is_literal -from narwhals._plan._immutable import Immutable -from narwhals._plan.options import ExprIROptions, FEOptions, FunctionOptions -from narwhals._plan.typing import ( - Accessor, - DTypeT, - ExprIRT, - ExprIRT2, - FunctionT, - IRNamespaceT, - MapIR, - NamedOrExprIRT, - NonNestedDTypeT, - OneOrIterable, - Seq, -) +from narwhals._plan._guards import is_iterable_reject from narwhals.dtypes import DType from narwhals.utils import Version if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any, Callable - - from typing_extensions import Self, TypeAlias - - from narwhals._plan.dummy import Expr, Selector - from narwhals._plan.expr import Alias, Cast, Column, FunctionExpr - from narwhals._plan.meta import IRMetaNamespace - from narwhals._plan.protocols import Ctx, FrameT_contra, R_co + from typing import Any, Callable, TypeVar + + from narwhals._plan.typing import ( + DTypeT, + ExprIRT, + FunctionT, + NonNestedDTypeT, + OneOrIterable, + ) 
from narwhals.typing import NonNestedDType, NonNestedLiteral + T = TypeVar("T") + if sys.version_info >= (3, 13): from copy import replace as replace # noqa: PLC0414 @@ -53,11 +41,7 @@ def replace(obj: T, /, **changes: Any) -> T: return func(obj, **changes) # type: ignore[no-any-return] -T = TypeVar("T") -Incomplete: TypeAlias = "Any" - - -def _pascal_to_snake_case(s: str) -> str: +def pascal_to_snake_case(s: str) -> str: """Convert a PascalCase, camelCase string to snake_case. Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 @@ -76,374 +60,19 @@ def _re_repl_snake(match: re.Match[str], /) -> str: return f"{match.group(1)}_{match.group(2)}" -def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: +def dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: config = tp.__expr_ir_config__ - name = config.override_name or _pascal_to_snake_case(tp.__name__) + name = config.override_name or pascal_to_snake_case(tp.__name__) return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name -def _dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: - getter = attrgetter(_dispatch_method_name(tp)) +def dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: + getter = attrgetter(dispatch_method_name(tp)) if tp.__expr_ir_config__.origin == "expr": return getter return lambda ctx: getter(ctx.__narwhals_namespace__()) -def _dispatch_generate( - tp: type[ExprIRT], / -) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: - if not tp.__expr_ir_config__.allow_dispatch: - - def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: - msg = ( - f"{tp.__name__!r} should not appear at the compliant-level.\n\n" - f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" - ) - raise TypeError(msg) - - return _ - getter = _dispatch_getter(tp) - - def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - - -def _dispatch_generate_function( - tp: type[FunctionT], / -) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: - getter = _dispatch_getter(tp) - - def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - - -class ExprIR(Immutable): - """Anything that can be a node on a graph of expressions.""" - - _child: ClassVar[Seq[str]] = () - """Nested node names, in iteration order.""" - - __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] - ] - - def __init_subclass__( - cls: type[Self], - *args: Any, - child: Seq[str] = (), - config: ExprIROptions | None = None, - **kwds: Any, - ) -> None: - super().__init_subclass__(*args, **kwds) - if child: - cls._child = child - if config: - cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) - - def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / - ) -> R_co: - """Evaluate expression in `frame`, using `ctx` for implementation(s).""" - return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] - - def to_narwhals(self, version: Version = Version.MAIN) -> Expr: - from narwhals._plan import dummy - - tp = dummy.Expr if version is Version.MAIN else dummy.ExprV1 - return tp._from_ir(self) - - 
@property - def is_scalar(self) -> bool: - return False - - def map_ir(self, function: MapIR, /) -> ExprIR: - """Apply `function` to each child node, returning a new `ExprIR`. - - See [`polars_plan::plans::iterator::Expr.map_expr`] and [`polars_plan::plans::visitor::visitors`]. - - [`polars_plan::plans::iterator::Expr.map_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/iterator.rs#L152-L159 - [`polars_plan::plans::visitor::visitors`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/visitor/visitors.rs - """ - if not self._child: - return function(self) - children = ((name, getattr(self, name)) for name in self._child) - changed = {name: _map_ir_child(child, function) for name, child in children} - return function(replace(self, **changed)) - - def iter_left(self) -> Iterator[ExprIR]: - """Yield nodes root->leaf. - - Examples: - >>> from narwhals._plan import demo as nwd - >>> - >>> a = nwd.col("a") - >>> b = a.alias("b") - >>> c = b.min().alias("c") - >>> d = c.over(nwd.col("e"), nwd.col("f")) - >>> - >>> list(a._ir.iter_left()) - [col('a')] - >>> - >>> list(b._ir.iter_left()) - [col('a'), col('a').alias('b')] - >>> - >>> list(c._ir.iter_left()) - [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c')] - >>> - >>> list(d._ir.iter_left()) - [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c'), col('e'), col('f'), col('a').alias('b').min().alias('c').over([col('e'), col('f')])] - """ - for name in self._child: - child: ExprIR | Seq[ExprIR] = getattr(self, name) - if isinstance(child, ExprIR): - yield from child.iter_left() - else: - for node in child: - yield from node.iter_left() - yield self - - def iter_right(self) -> Iterator[ExprIR]: - """Yield nodes leaf->root. - - Note: - Identical to `iter_left` for root nodes. - - Examples: - >>> from narwhals._plan import demo as nwd - >>> - >>> a = nwd.col("a") - >>> b = a.alias("b") - >>> c = b.min().alias("c") - >>> d = c.over(nwd.col("e"), nwd.col("f")) - >>> - >>> list(a._ir.iter_right()) - [col('a')] - >>> - >>> list(b._ir.iter_right()) - [col('a').alias('b'), col('a')] - >>> - >>> list(c._ir.iter_right()) - [col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] - >>> - >>> list(d._ir.iter_right()) - [col('a').alias('b').min().alias('c').over([col('e'), col('f')]), col('f'), col('e'), col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] - """ - yield self - for name in reversed(self._child): - child: ExprIR | Seq[ExprIR] = getattr(self, name) - if isinstance(child, ExprIR): - yield from child.iter_right() - else: - for node in reversed(child): - yield from node.iter_right() - - def iter_root_names(self) -> Iterator[ExprIR]: - """Override for different iteration behavior in `ExprIR.meta.root_names`. - - Note: - Identical to `iter_left` by default. - """ - yield from self.iter_left() - - def iter_output_name(self) -> Iterator[ExprIR]: - """Override for different iteration behavior in `ExprIR.meta.output_name`. - - Note: - Identical to `iter_right` by default. 
- """ - yield from self.iter_right() - - @property - def meta(self) -> IRMetaNamespace: - from narwhals._plan.meta import IRMetaNamespace - - return IRMetaNamespace(_ir=self) - - def cast(self, dtype: DType) -> Cast: - from narwhals._plan.expr import Cast - - return Cast(expr=self, dtype=dtype) - - def alias(self, name: str) -> Alias: - from narwhals._plan.expr import Alias - - return Alias(expr=self, name=name) - - def _repr_html_(self) -> str: - return self.__repr__() - - -class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): - def to_narwhals(self, version: Version = Version.MAIN) -> Selector: - from narwhals._plan import dummy - - if version is Version.MAIN: - return dummy.Selector._from_ir(self) - return dummy.SelectorV1._from_ir(self) - - def matches_column(self, name: str, dtype: DType) -> bool: - """Return True if we can select this column. - - - Thinking that we could get more cache hits on an individual column basis. - - May also be more efficient to not iterate over the schema for every selector - - Instead do one pass, evaluating every selector against a single column at a time - """ - raise NotImplementedError(type(self)) - - -class NamedIR(Immutable, Generic[ExprIRT]): - """Post-projection expansion wrapper for `ExprIR`. - - - Somewhat similar to [`polars_plan::plans::expr_ir::ExprIR`]. - - The [`polars_plan::plans::aexpr::AExpr`] stage has been skipped (*for now*) - - Parts of that will probably be in here too - - `AExpr` seems like too much duplication when we won't get the memory allocation benefits in python - - [`polars_plan::plans::expr_ir::ExprIR`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/expr_ir.rs#L63-L74 - [`polars_plan::plans::aexpr::AExpr`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/mod.rs#L145-L231 - """ - - __slots__ = ("expr", "name") - expr: ExprIRT - name: str - - @staticmethod - def from_name(name: str, /) -> NamedIR[Column]: - """Construct as a simple, unaliased `col(name)` expression. - - Intended to be used in `with_columns` from a `FrozenSchema`'s keys. - """ - from narwhals._plan.expr import col - - return NamedIR(expr=col(name), name=name) - - @staticmethod - def from_ir(expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: - """Construct from an already expanded `ExprIR`. - - Should be cheap to get the output name from cache, but will raise if used - without care. - """ - return NamedIR(expr=expr, name=expr.meta.output_name(raise_if_undetermined=True)) - - def map_ir(self, function: MapIR, /) -> Self: - """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" - return replace(self, expr=function(self.expr.map_ir(function))) - - def __repr__(self) -> str: - return f"{self.name}={self.expr!r}" - - def _repr_html_(self) -> str: - return f"{self.name}={self.expr._repr_html_()}" - - def is_elementwise_top_level(self) -> bool: - """Return True if the outermost node is elementwise. 
- - Based on [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`] - - This check: - - Is not recursive - - Is not valid on `ExprIR` *prior* to being expanded - - [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/properties.rs#L16-L44 - """ - from narwhals._plan import expr - - ir = self.expr - if is_function_expr(ir): - return ir.options.is_elementwise() - if is_literal(ir): - return ir.is_scalar - return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) - - -class IRNamespace(Immutable): - __slots__ = ("_ir",) - _ir: ExprIR - - @classmethod - def from_expr(cls, expr: Expr, /) -> Self: - return cls(_ir=expr._ir) - - -class ExprNamespace(Immutable, Generic[IRNamespaceT]): - __slots__ = ("_expr",) - _expr: Expr - - @property - def _ir_namespace(self) -> type[IRNamespaceT]: - raise NotImplementedError - - @property - def _ir(self) -> IRNamespaceT: - return self._ir_namespace.from_expr(self._expr) - - def _to_narwhals(self, ir: ExprIR, /) -> Expr: - return self._expr._from_ir(ir) - - def _with_unary(self, function: Function, /) -> Expr: - return self._expr._with_unary(function) - - -class Function(Immutable): - """Shared by expr functions and namespace functions. - - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 - """ - - _function_options: ClassVar[staticmethod[[], FunctionOptions]] = staticmethod( - FunctionOptions.default - ) - __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] - ] - - @property - def function_options(self) -> FunctionOptions: - return self._function_options() - - @property - def is_scalar(self) -> bool: - return self.function_options.returns_scalar() - - def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: - from narwhals._plan.expr import FunctionExpr - - return FunctionExpr(input=inputs, function=self, options=self.function_options) - - def __init_subclass__( - cls: type[Self], - *args: Any, - accessor: Accessor | None = None, - options: Callable[[], FunctionOptions] | None = None, - config: FEOptions | None = None, - **kwds: Any, - ) -> None: - super().__init_subclass__(*args, **kwds) - if accessor: - config = replace(config or FEOptions.default(), accessor_name=accessor) - if options: - cls._function_options = staticmethod(options) - if config: - cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) - - def __repr__(self) -> str: - return _dispatch_method_name(type(self)) - - -class HorizontalFunction( - Function, options=FunctionOptions.horizontal, config=FEOptions.namespaced() -): ... - - def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { @@ -467,33 +96,12 @@ def into_dtype(dtype: type[NonNestedDTypeT], /) -> NonNestedDTypeT: ... @overload def into_dtype(dtype: DTypeT, /) -> DTypeT: ... 
def into_dtype(dtype: DTypeT | type[NonNestedDTypeT], /) -> DTypeT | NonNestedDTypeT: + # NOTE: `mypy` needs to learn intersections if isinstance(dtype, type) and issubclass(dtype, DType): - # NOTE: `mypy` needs to learn intersections - return dtype() # type: ignore[return-value] + return cast("NonNestedDTypeT", dtype()) return dtype -def collect(iterable: Seq[T] | Iterable[T], /) -> Seq[T]: - """Collect `iterable` into a `tuple`, *iff* it is not one already.""" - return iterable if isinstance(iterable, tuple) else tuple(iterable) - - -def map_ir( - origin: NamedOrExprIRT, function: MapIR, *more_functions: MapIR -) -> NamedOrExprIRT: - """Apply one or more functions, sequentially, to all of `origin`'s children.""" - if more_functions: - result = origin - for fn in (function, *more_functions): - result = result.map_ir(fn) - return result - return origin.map_ir(function) - - -def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: - return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) - - # TODO @dangotbanned: Review again and try to work around (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021) # The issue is `T` possibly being `Iterable` # Ignoring here still leaks the issue to the caller, where you need to annotate the base case diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py new file mode 100644 index 0000000000..8f06f1e5c9 --- /dev/null +++ b/narwhals/_plan/dataframe.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload + +from narwhals._plan import _expansion, _parse +from narwhals._plan.contexts import ExprContext +from narwhals._plan.expr import _parse_sort_by +from narwhals._plan.series import Series +from narwhals._plan.typing import ( + IntoExpr, + NativeDataFrameT, + NativeFrameT, + NativeSeriesT, + OneOrIterable, +) +from narwhals._utils import Version, generate_repr +from narwhals.dependencies import is_pyarrow_table +from narwhals.schema import Schema + +if TYPE_CHECKING: + import pyarrow as pa + from typing_extensions import Self + + from narwhals._plan.expressions import ExprIR, NamedIR + from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame + from narwhals._plan.schema import FrozenSchema + from narwhals._plan.typing import Seq + from narwhals.typing import NativeFrame + + +class BaseFrame(Generic[NativeFrameT]): + _compliant: CompliantBaseFrame[Any, NativeFrameT] + _version: ClassVar[Version] = Version.MAIN + + @property + def version(self) -> Version: + return self._version + + @property + def schema(self) -> Schema: + return Schema(self._compliant.schema.items()) + + @property + def columns(self) -> list[str]: + return self._compliant.columns + + def __repr__(self) -> str: # pragma: no cover + return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) + + @classmethod + def from_native(cls, native: Any, /) -> Self: + raise NotImplementedError + + @classmethod + def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> Self: + obj = cls.__new__(cls) + obj._compliant = compliant + return obj + + def to_native(self) -> NativeFrameT: + return self._compliant.native + + def _project( + self, + exprs: tuple[OneOrIterable[IntoExpr], ...], + named_exprs: dict[str, Any], + context: ExprContext, + /, + ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: + """Temp, while these parts aren't connected, this is easier for testing.""" + 
irs, schema_frozen, output_names = _expansion.prepare_projection( + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema + ) + named_irs = _expansion.into_named_irs(irs, output_names) + return schema_frozen.project(named_irs, context) + + def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: + named_irs, _ = self._project(exprs, named_exprs, ExprContext.SELECT) + return self._from_compliant(self._compliant.select(named_irs)) + + def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: + named_irs, _ = self._project(exprs, named_exprs, ExprContext.WITH_COLUMNS) + return self._from_compliant(self._compliant.with_columns(named_irs)) + + def sort( + self, + by: OneOrIterable[str], + *more_by: str, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, + ) -> Self: + sort, opts = _parse_sort_by( + by, *more_by, descending=descending, nulls_last=nulls_last + ) + irs, _, output_names = _expansion.prepare_projection(sort, self.schema) + named_irs = _expansion.into_named_irs(irs, output_names) + return self._from_compliant(self._compliant.sort(named_irs, opts)) + + +class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): + _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] + + @property + def _series(self) -> type[Series[NativeSeriesT]]: + return Series[NativeSeriesT] + + # NOTE: Gave up on trying to get typing working for now + @classmethod + def from_native( # type: ignore[override] + cls, native: NativeFrame, / + ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: + if is_pyarrow_table(native): + from narwhals._plan.arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame.from_native(native, cls._version).to_narwhals() + + raise NotImplementedError(type(native)) + + @overload + def to_dict( + self, *, as_series: Literal[True] = ... + ) -> dict[str, Series[NativeSeriesT]]: ... + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload + def to_dict( + self, *, as_series: bool + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: ... 
+ def to_dict( + self, *, as_series: bool = True + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: + if as_series: + return { + key: self._series._from_compliant(value) + for key, value in self._compliant.to_dict(as_series=as_series).items() + } + return self._compliant.to_dict(as_series=as_series) + + def __len__(self) -> int: + return len(self._compliant) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py deleted file mode 100644 index ab89b97c96..0000000000 --- a/narwhals/_plan/demo.py +++ /dev/null @@ -1,161 +0,0 @@ -from __future__ import annotations - -import builtins -import typing as t - -from narwhals._plan import _guards, boolean, expr, expr_parsing as parse, functions as F -from narwhals._plan.common import into_dtype, py_to_narwhals_dtype -from narwhals._plan.expr import All, Len -from narwhals._plan.literal import ScalarLiteral, SeriesLiteral -from narwhals._plan.ranges import IntRange -from narwhals._plan.strings import ConcatStr -from narwhals._plan.when_then import When -from narwhals._utils import Version, flatten - -if t.TYPE_CHECKING: - from narwhals._plan.dummy import Expr, Series - from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT - from narwhals.dtypes import IntegerType - from narwhals.typing import IntoDType, NonNestedLiteral - - -def col(*names: str | t.Iterable[str]) -> Expr: - flat = tuple(flatten(names)) - node = expr.col(flat[0]) if builtins.len(flat) == 1 else expr.cols(*flat) - return node.to_narwhals() - - -def nth(*indices: int | t.Sequence[int]) -> Expr: - flat = tuple(flatten(indices)) - node = expr.nth(flat[0]) if builtins.len(flat) == 1 else expr.index_columns(*flat) - return node.to_narwhals() - - -def lit( - value: NonNestedLiteral | Series[NativeSeriesT], dtype: IntoDType | None = None -) -> Expr: - if _guards.is_series(value): - return SeriesLiteral(value=value).to_literal().to_narwhals() - if not _guards.is_non_nested_literal(value): - msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." 
- raise TypeError(msg) - if dtype is None: - dtype = py_to_narwhals_dtype(value, Version.MAIN) - else: - dtype = into_dtype(dtype) - return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() - - -def len() -> Expr: - return Len().to_narwhals() - - -def all() -> Expr: - return All().to_narwhals() - - -def exclude(*names: str | t.Iterable[str]) -> Expr: - return all().exclude(*names) - - -def max(*columns: str) -> Expr: - return col(columns).max() - - -def mean(*columns: str) -> Expr: - return col(columns).mean() - - -def min(*columns: str) -> Expr: - return col(columns).min() - - -def median(*columns: str) -> Expr: - return col(columns).median() - - -def sum(*columns: str) -> Expr: - return col(columns).sum() - - -def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) - return boolean.AllHorizontal().to_function_expr(*it).to_narwhals() - - -def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) - return boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() - - -def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) - return F.SumHorizontal().to_function_expr(*it).to_narwhals() - - -def min_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) - return F.MinHorizontal().to_function_expr(*it).to_narwhals() - - -def max_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) - return F.MaxHorizontal().to_function_expr(*it).to_narwhals() - - -def mean_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) - return F.MeanHorizontal().to_function_expr(*it).to_narwhals() - - -def concat_str( - exprs: IntoExpr | t.Iterable[IntoExpr], - *more_exprs: IntoExpr, - separator: str = "", - ignore_nulls: bool = False, -) -> Expr: - it = parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) - return ( - ConcatStr(separator=separator, ignore_nulls=ignore_nulls) - .to_function_expr(*it) - .to_narwhals() - ) - - -def when( - *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], **constraints: t.Any -) -> When: - """Start a `when-then-otherwise` expression. 
- - Examples: - >>> from narwhals._plan import demo as nwd - - >>> nwd.when(nwd.col("y") == "b").then(1) - nw._plan.Expr(main): - .when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null)) - """ - condition = parse.parse_predicates_constraints_into_expr_ir( - *predicates, **constraints - ) - return When._from_ir(condition) - - -def int_range( - start: int | IntoExprColumn = 0, - end: int | IntoExprColumn | None = None, - step: int = 1, - *, - dtype: IntegerType | type[IntegerType] = Version.MAIN.dtypes.Int64, - eager: bool = False, -) -> Expr: - if end is None: - end = start - start = 0 - if eager: - msg = f"{eager=}" - raise NotImplementedError(msg) - return ( - IntRange(step=step, dtype=into_dtype(dtype)) - .to_function_expr(*parse.parse_into_seq_of_expr_ir(start, end)) - .to_narwhals() - ) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py deleted file mode 100644 index 0a1e469917..0000000000 --- a/narwhals/_plan/dummy.py +++ /dev/null @@ -1,887 +0,0 @@ -"""Mock version of current narwhals API.""" - -from __future__ import annotations - -import math -from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload - -from narwhals._plan import ( - aggregation as agg, - boolean, - expr, - expr_expansion, - expr_parsing as parse, - functions as F, - operators as ops, -) -from narwhals._plan._guards import is_column, is_expr, is_series -from narwhals._plan.common import into_dtype -from narwhals._plan.contexts import ExprContext -from narwhals._plan.options import ( - EWMOptions, - RankOptions, - SortMultipleOptions, - SortOptions, - rolling_options, -) -from narwhals._plan.selectors import by_name -from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT -from narwhals._plan.window import Over -from narwhals._utils import Version, generate_repr -from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table -from narwhals.exceptions import ComputeError, InvalidOperationError -from narwhals.schema import Schema - -if TYPE_CHECKING: - import pyarrow as pa - from typing_extensions import Never, Self - - from narwhals._plan.categorical import ExprCatNamespace - from narwhals._plan.common import ExprIR, Function, NamedIR - from narwhals._plan.lists import ExprListNamespace - from narwhals._plan.meta import IRMetaNamespace - from narwhals._plan.name import ExprNameNamespace - from narwhals._plan.protocols import ( - CompliantBaseFrame, - CompliantDataFrame, - CompliantSeries, - ) - from narwhals._plan.schema import FrozenSchema - from narwhals._plan.strings import ExprStringNamespace - from narwhals._plan.struct import ExprStructNamespace - from narwhals._plan.temporal import ExprDateTimeNamespace - from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf - from narwhals.dtypes import DType - from narwhals.typing import ( - ClosedInterval, - FillNullStrategy, - IntoDType, - NativeFrame, - NativeSeries, - NumericLiteral, - RankMethod, - RollingInterpolationMethod, - TemporalLiteral, - ) - - -# NOTE: Trying to keep consistent logic between `DataFrame.sort` and `Expr.sort_by` -def _parse_sort_by( - by: OneOrIterable[IntoExpr] = (), - *more_by: IntoExpr, - descending: OneOrIterable[bool] = False, - nulls_last: OneOrIterable[bool] = False, -) -> tuple[Seq[ExprIR], SortMultipleOptions]: - sort_by = parse.parse_into_seq_of_expr_ir(by, *more_by) - if length_changing := next((e for e in sort_by if e.is_scalar), None): - msg = f"All 
expressions sort keys must preserve length, but got:\n{length_changing!r}" - raise InvalidOperationError(msg) - options = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last) - return sort_by, options - - -# NOTE: Overly simplified placeholders for mocking typing -# Entirely ignoring namespace + function binding -class Expr: - _ir: ExprIR - _version: ClassVar[Version] = Version.MAIN - - def __repr__(self) -> str: - return f"nw._plan.Expr({self.version.name.lower()}):\n{self._ir!r}" - - def __str__(self) -> str: - """Use `print(self)` for formatting.""" - return f"nw._plan.Expr({self.version.name.lower()}):\n{self._ir!s}" - - def _repr_html_(self) -> str: - return self._ir._repr_html_() - - @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Self: - obj = cls.__new__(cls) - obj._ir = ir - return obj - - @property - def version(self) -> Version: - return self._version - - def alias(self, name: str) -> Self: - return self._from_ir(self._ir.alias(name)) - - def cast(self, dtype: IntoDType) -> Self: - return self._from_ir(self._ir.cast(into_dtype(dtype))) - - def exclude(self, *names: OneOrIterable[str]) -> Self: - return self._from_ir(expr.Exclude.from_names(self._ir, *names)) - - def count(self) -> Self: - return self._from_ir(agg.Count(expr=self._ir)) - - def max(self) -> Self: - return self._from_ir(agg.Max(expr=self._ir)) - - def mean(self) -> Self: - return self._from_ir(agg.Mean(expr=self._ir)) - - def min(self) -> Self: - return self._from_ir(agg.Min(expr=self._ir)) - - def median(self) -> Self: - return self._from_ir(agg.Median(expr=self._ir)) - - def n_unique(self) -> Self: - return self._from_ir(agg.NUnique(expr=self._ir)) - - def sum(self) -> Self: - return self._from_ir(agg.Sum(expr=self._ir)) - - def arg_min(self) -> Self: - return self._from_ir(agg.ArgMin(expr=self._ir)) - - def arg_max(self) -> Self: - return self._from_ir(agg.ArgMax(expr=self._ir)) - - def first(self) -> Self: - return self._from_ir(agg.First(expr=self._ir)) - - def last(self) -> Self: - return self._from_ir(agg.Last(expr=self._ir)) - - def var(self, *, ddof: int = 1) -> Self: - return self._from_ir(agg.Var(expr=self._ir, ddof=ddof)) - - def std(self, *, ddof: int = 1) -> Self: - return self._from_ir(agg.Std(expr=self._ir, ddof=ddof)) - - def quantile( - self, quantile: float, interpolation: RollingInterpolationMethod - ) -> Self: - return self._from_ir( - agg.Quantile(expr=self._ir, quantile=quantile, interpolation=interpolation) - ) - - def over( - self, - *partition_by: OneOrIterable[IntoExpr], - order_by: OneOrIterable[IntoExpr] = None, - descending: bool = False, - nulls_last: bool = False, - ) -> Self: - node: expr.WindowExpr | expr.OrderedWindowExpr - partition: Seq[ExprIR] = () - if not (partition_by) and order_by is None: - msg = "At least one of `partition_by` or `order_by` must be specified." 
- raise TypeError(msg) - if partition_by: - partition = parse.parse_into_seq_of_expr_ir(*partition_by) - if order_by is not None: - by = parse.parse_into_seq_of_expr_ir(order_by) - options = SortOptions(descending=descending, nulls_last=nulls_last) - node = Over().to_ordered_window_expr(self._ir, partition, by, options) - else: - node = Over().to_window_expr(self._ir, partition) - return self._from_ir(node) - - def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: - options = SortOptions(descending=descending, nulls_last=nulls_last) - return self._from_ir(expr.Sort(expr=self._ir, options=options)) - - def sort_by( - self, - by: OneOrIterable[IntoExpr], - *more_by: IntoExpr, - descending: OneOrIterable[bool] = False, - nulls_last: OneOrIterable[bool] = False, - ) -> Self: - keys, opts = _parse_sort_by( - by, *more_by, descending=descending, nulls_last=nulls_last - ) - return self._from_ir(expr.SortBy(expr=self._ir, by=keys, options=opts)) - - def filter( - self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any - ) -> Self: - by = parse.parse_predicates_constraints_into_expr_ir(*predicates, **constraints) - return self._from_ir(expr.Filter(expr=self._ir, by=by)) - - def _with_unary(self, function: Function, /) -> Self: - return self._from_ir(function.to_function_expr(self._ir)) - - def abs(self) -> Self: - return self._with_unary(F.Abs()) - - def hist( - self, - bins: Sequence[float] | None = None, - *, - bin_count: int | None = None, - include_breakpoint: bool = True, - ) -> Self: - node: F.Hist - if bins is not None: - if bin_count is not None: - msg = "can only provide one of `bin_count` or `bins`" - raise ComputeError(msg) - node = F.HistBins(bins=tuple(bins), include_breakpoint=include_breakpoint) - elif bin_count is not None: - node = F.HistBinCount( - bin_count=bin_count, include_breakpoint=include_breakpoint - ) - else: - node = F.HistBinCount(include_breakpoint=include_breakpoint) - return self._with_unary(node) - - def log(self, base: float = math.e) -> Self: - return self._with_unary(F.Log(base=base)) - - def exp(self) -> Self: - return self._with_unary(F.Exp()) - - def sqrt(self) -> Self: - return self._with_unary(F.Sqrt()) - - def kurtosis(self, *, fisher: bool = True, bias: bool = True) -> Self: - return self._with_unary(F.Kurtosis(fisher=fisher, bias=bias)) - - def null_count(self) -> Self: - return self._with_unary(F.NullCount()) - - def fill_null( - self, - value: IntoExpr = None, - strategy: FillNullStrategy | None = None, - limit: int | None = None, - ) -> Self: - if strategy is None: - ir = parse.parse_into_expr_ir(value, str_as_lit=True) - return self._from_ir(F.FillNull().to_function_expr(self._ir, ir)) - return self._with_unary(F.FillNullWithStrategy(strategy=strategy, limit=limit)) - - def shift(self, n: int) -> Self: - return self._with_unary(F.Shift(n=n)) - - def drop_nulls(self) -> Self: - return self._with_unary(F.DropNulls()) - - def mode(self) -> Self: - return self._with_unary(F.Mode()) - - def skew(self) -> Self: - return self._with_unary(F.Skew()) - - def rank(self, method: RankMethod = "average", *, descending: bool = False) -> Self: - options = RankOptions(method=method, descending=descending) - return self._with_unary(F.Rank(options=options)) - - def clip( - self, - lower_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, - upper_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, - ) -> Self: - return self._from_ir( - F.Clip().to_function_expr( - self._ir, 
*parse.parse_into_seq_of_expr_ir(lower_bound, upper_bound) - ) - ) - - def cum_count(self, *, reverse: bool = False) -> Self: - return self._with_unary(F.CumCount(reverse=reverse)) - - def cum_min(self, *, reverse: bool = False) -> Self: - return self._with_unary(F.CumMin(reverse=reverse)) - - def cum_max(self, *, reverse: bool = False) -> Self: - return self._with_unary(F.CumMax(reverse=reverse)) - - def cum_prod(self, *, reverse: bool = False) -> Self: - return self._with_unary(F.CumProd(reverse=reverse)) - - def cum_sum(self, *, reverse: bool = False) -> Self: - return self._with_unary(F.CumSum(reverse=reverse)) - - def rolling_sum( - self, window_size: int, *, min_samples: int | None = None, center: bool = False - ) -> Self: - options = rolling_options(window_size, min_samples, center=center) - return self._with_unary(F.RollingSum(options=options)) - - def rolling_mean( - self, window_size: int, *, min_samples: int | None = None, center: bool = False - ) -> Self: - options = rolling_options(window_size, min_samples, center=center) - return self._with_unary(F.RollingMean(options=options)) - - def rolling_var( - self, - window_size: int, - *, - min_samples: int | None = None, - center: bool = False, - ddof: int = 1, - ) -> Self: - options = rolling_options(window_size, min_samples, center=center, ddof=ddof) - return self._with_unary(F.RollingVar(options=options)) - - def rolling_std( - self, - window_size: int, - *, - min_samples: int | None = None, - center: bool = False, - ddof: int = 1, - ) -> Self: - options = rolling_options(window_size, min_samples, center=center, ddof=ddof) - return self._with_unary(F.RollingStd(options=options)) - - def diff(self) -> Self: - return self._with_unary(F.Diff()) - - def unique(self) -> Self: - return self._with_unary(F.Unique()) - - def round(self, decimals: int = 0) -> Self: - return self._with_unary(F.Round(decimals=decimals)) - - def ewm_mean( - self, - *, - com: float | None = None, - span: float | None = None, - half_life: float | None = None, - alpha: float | None = None, - adjust: bool = True, - min_samples: int = 1, - ignore_nulls: bool = False, - ) -> Self: - options = EWMOptions( - com=com, - span=span, - half_life=half_life, - alpha=alpha, - adjust=adjust, - min_samples=min_samples, - ignore_nulls=ignore_nulls, - ) - return self._with_unary(F.EwmMean(options=options)) - - def replace_strict( - self, - old: Sequence[Any] | Mapping[Any, Any], - new: Sequence[Any] | None = None, - *, - return_dtype: IntoDType | None = None, - ) -> Self: - before: Seq[Any] - after: Seq[Any] - if new is None: - if not isinstance(old, Mapping): - msg = "`new` argument is required if `old` argument is not a Mapping type" - raise TypeError(msg) - before = tuple(old) - after = tuple(old.values()) - elif isinstance(old, Mapping): - msg = "`new` argument cannot be used if `old` argument is a Mapping type" - raise TypeError(msg) - else: - before = tuple(old) - after = tuple(new) - if return_dtype is not None: - return_dtype = into_dtype(return_dtype) - function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) - return self._with_unary(function) - - def gather_every(self, n: int, offset: int = 0) -> Self: - return self._with_unary(F.GatherEvery(n=n, offset=offset)) - - def map_batches( - self, - function: Udf, - return_dtype: IntoDType | None = None, - *, - is_elementwise: bool = False, - returns_scalar: bool = False, - ) -> Self: - if return_dtype is not None: - return_dtype = into_dtype(return_dtype) - return self._with_unary( - F.MapBatches( - 
function=function, - return_dtype=return_dtype, - is_elementwise=is_elementwise, - returns_scalar=returns_scalar, - ) - ) - - def any(self) -> Self: - return self._with_unary(boolean.Any()) - - def all(self) -> Self: - return self._with_unary(boolean.All()) - - def is_duplicated(self) -> Self: - return self._with_unary(boolean.IsDuplicated()) - - def is_finite(self) -> Self: - return self._with_unary(boolean.IsFinite()) - - def is_nan(self) -> Self: - return self._with_unary(boolean.IsNan()) - - def is_null(self) -> Self: - return self._with_unary(boolean.IsNull()) - - def is_first_distinct(self) -> Self: - return self._with_unary(boolean.IsFirstDistinct()) - - def is_last_distinct(self) -> Self: - return self._with_unary(boolean.IsLastDistinct()) - - def is_unique(self) -> Self: - return self._with_unary(boolean.IsUnique()) - - def is_between( - self, - lower_bound: IntoExpr, - upper_bound: IntoExpr, - closed: ClosedInterval = "both", - ) -> Self: - it = parse.parse_into_seq_of_expr_ir(lower_bound, upper_bound) - return self._from_ir( - boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) - ) - - def is_in(self, other: Iterable[Any]) -> Self: - if is_series(other): - return self._with_unary(boolean.IsInSeries.from_series(other)) - if isinstance(other, Iterable): - return self._with_unary(boolean.IsInSeq.from_iterable(other)) - if is_expr(other): - return self._with_unary(boolean.IsInExpr(other=other._ir)) - msg = f"`is_in` only supports iterables, got: {type(other).__name__}" - raise TypeError(msg) - - def _with_binary( - self, - op: type[ops.Operator], - other: IntoExpr, - *, - str_as_lit: bool = False, - reflect: bool = False, - ) -> Self: - other_ir = parse.parse_into_expr_ir(other, str_as_lit=str_as_lit) - args = (self._ir, other_ir) if not reflect else (other_ir, self._ir) - return self._from_ir(op().to_binary_expr(*args)) - - def __eq__(self, other: IntoExpr) -> Self: # type: ignore[override] - return self._with_binary(ops.Eq, other, str_as_lit=True) - - def __ne__(self, other: IntoExpr) -> Self: # type: ignore[override] - return self._with_binary(ops.NotEq, other, str_as_lit=True) - - def __lt__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Lt, other, str_as_lit=True) - - def __le__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.LtEq, other, str_as_lit=True) - - def __gt__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Gt, other, str_as_lit=True) - - def __ge__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.GtEq, other, str_as_lit=True) - - def __add__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Add, other, str_as_lit=True) - - def __radd__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Add, other, str_as_lit=True, reflect=True) - - def __sub__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Sub, other) - - def __rsub__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Sub, other, reflect=True) - - def __mul__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Multiply, other) - - def __rmul__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Multiply, other, reflect=True) - - def __truediv__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.TrueDivide, other) - - def __rtruediv__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.TrueDivide, other, reflect=True) - - def __floordiv__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.FloorDivide, other) - - def __rfloordiv__(self, 
other: IntoExpr) -> Self: - return self._with_binary(ops.FloorDivide, other, reflect=True) - - def __mod__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Modulus, other) - - def __rmod__(self, other: IntoExpr) -> Self: - return self._with_binary(ops.Modulus, other, reflect=True) - - def __and__(self, other: IntoExprColumn | int | bool) -> Self: - return self._with_binary(ops.And, other) - - def __rand__(self, other: IntoExprColumn | int | bool) -> Self: - return self._with_binary(ops.And, other, reflect=True) - - def __or__(self, other: IntoExprColumn | int | bool) -> Self: - return self._with_binary(ops.Or, other) - - def __ror__(self, other: IntoExprColumn | int | bool) -> Self: - return self._with_binary(ops.Or, other, reflect=True) - - def __xor__(self, other: IntoExprColumn | int | bool) -> Self: - return self._with_binary(ops.ExclusiveOr, other) - - def __rxor__(self, other: IntoExprColumn | int | bool) -> Self: - return self._with_binary(ops.ExclusiveOr, other, reflect=True) - - def __pow__(self, exponent: IntoExprColumn | float) -> Self: - exp = parse.parse_into_expr_ir(exponent) - return self._from_ir(F.Pow().to_function_expr(self._ir, exp)) - - def __rpow__(self, base: IntoExprColumn | float) -> Self: - base_ = parse.parse_into_expr_ir(base) - return self._from_ir(F.Pow().to_function_expr(base_, self._ir)) - - def __invert__(self) -> Self: - return self._with_unary(boolean.Not()) - - @property - def meta(self) -> IRMetaNamespace: - from narwhals._plan.meta import IRMetaNamespace - - return IRMetaNamespace.from_expr(self) - - @property - def name(self) -> ExprNameNamespace: - """Specialized expressions for modifying the name of existing expressions. - - Examples: - >>> from narwhals._plan import demo as nw - >>> - >>> renamed = nw.col("a", "b").name.suffix("_changed") - >>> str(renamed._ir) - "RenameAlias(expr=Columns(names=[a, b]), function=Suffix(suffix='_changed'))" - """ - from narwhals._plan.name import ExprNameNamespace - - return ExprNameNamespace(_expr=self) - - @property - def cat(self) -> ExprCatNamespace: - from narwhals._plan.categorical import ExprCatNamespace - - return ExprCatNamespace(_expr=self) - - @property - def struct(self) -> ExprStructNamespace: - from narwhals._plan.struct import ExprStructNamespace - - return ExprStructNamespace(_expr=self) - - @property - def dt(self) -> ExprDateTimeNamespace: - from narwhals._plan.temporal import ExprDateTimeNamespace - - return ExprDateTimeNamespace(_expr=self) - - @property - def list(self) -> ExprListNamespace: - from narwhals._plan.lists import ExprListNamespace - - return ExprListNamespace(_expr=self) - - @property - def str(self) -> ExprStringNamespace: - from narwhals._plan.strings import ExprStringNamespace - - return ExprStringNamespace(_expr=self) - - -class Selector(Expr): - _ir: expr.SelectorIR - - def __repr__(self) -> str: - return f"nw._plan.Selector({self.version.name.lower()}):\n{self._ir!r}" - - @classmethod - def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] - obj = cls.__new__(cls) - obj._ir = ir - return obj - - def _to_expr(self) -> Expr: - return self._ir.to_narwhals(self.version) - - @overload # type: ignore[override] - def __or__(self, other: Self) -> Self: ... - @overload - def __or__(self, other: IntoExprColumn | int | bool) -> Expr: ... 
- def __or__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if isinstance(other, type(self)): - op = ops.Or() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) - return self._to_expr() | other - - @overload # type: ignore[override] - def __and__(self, other: Self) -> Self: ... - @overload - def __and__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __and__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if is_column(other) and (name := other.meta.output_name()): - other = by_name(name) - if isinstance(other, type(self)): - op = ops.And() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) - return self._to_expr() & other - - @overload # type: ignore[override] - def __sub__(self, other: Self) -> Self: ... - @overload - def __sub__(self, other: IntoExpr) -> Expr: ... - def __sub__(self, other: IntoExpr) -> Self | Expr: - if isinstance(other, type(self)): - op = ops.Sub() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) - return self._to_expr() - other - - @overload # type: ignore[override] - def __xor__(self, other: Self) -> Self: ... - @overload - def __xor__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if isinstance(other, type(self)): - op = ops.ExclusiveOr() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) - return self._to_expr() ^ other - - def __invert__(self) -> Self: - return self._from_ir(expr.InvertSelector(selector=self._ir)) - - def __add__(self, other: Any) -> Expr: # type: ignore[override] - if isinstance(other, type(self)): - msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" - raise TypeError(msg) - return self._to_expr() + other # type: ignore[no-any-return] - - def __radd__(self, other: Any) -> Never: - msg = "unsupported operand type(s) for op: ('Expr' + 'Selector')" - raise TypeError(msg) - - def __rsub__(self, other: Any) -> Never: - msg = "unsupported operand type(s) for op: ('Expr' - 'Selector')" - raise TypeError(msg) - - @overload # type: ignore[override] - def __rand__(self, other: Self) -> Self: ... - @overload - def __rand__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __rand__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if is_column(other) and (name := other.meta.output_name()): - return by_name(name) & self - return self._to_expr().__rand__(other) - - @overload # type: ignore[override] - def __ror__(self, other: Self) -> Self: ... - @overload - def __ror__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __ror__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if is_column(other) and (name := other.meta.output_name()): - return by_name(name) | self - return self._to_expr().__ror__(other) - - @overload # type: ignore[override] - def __rxor__(self, other: Self) -> Self: ... - @overload - def __rxor__(self, other: IntoExprColumn | int | bool) -> Expr: ... 
- def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if is_column(other) and (name := other.meta.output_name()): - return by_name(name) ^ self - return self._to_expr().__rxor__(other) - - -class ExprV1(Expr): - _version: ClassVar[Version] = Version.V1 - - -class SelectorV1(Selector): - _version: ClassVar[Version] = Version.V1 - - -class BaseFrame(Generic[NativeFrameT]): - _compliant: CompliantBaseFrame[Any, NativeFrameT] - _version: ClassVar[Version] = Version.MAIN - - @property - def version(self) -> Version: - return self._version - - @property - def schema(self) -> Schema: - return Schema(self._compliant.schema.items()) - - @property - def columns(self) -> list[str]: - return self._compliant.columns - - def __repr__(self) -> str: # pragma: no cover - return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) - - @classmethod - def from_native(cls, native: Any, /) -> Self: - raise NotImplementedError - - @classmethod - def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> Self: - obj = cls.__new__(cls) - obj._compliant = compliant - return obj - - def to_native(self) -> NativeFrameT: - return self._compliant.native - - def _project( - self, - exprs: tuple[OneOrIterable[IntoExpr], ...], - named_exprs: dict[str, Any], - context: ExprContext, - /, - ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: - """Temp, while these parts aren't connected, this is easier for testing.""" - irs, schema_frozen, output_names = expr_expansion.prepare_projection( - parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema - ) - named_irs = expr_expansion.into_named_irs(irs, output_names) - return schema_frozen.project(named_irs, context) - - def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, schema_projected = self._project( - exprs, named_exprs, ExprContext.SELECT - ) - return self._from_compliant(self._compliant.select(named_irs)) - - def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, schema_projected = self._project( - exprs, named_exprs, ExprContext.WITH_COLUMNS - ) - return self._from_compliant(self._compliant.with_columns(named_irs)) - - def sort( - self, - by: OneOrIterable[str], - *more_by: str, - descending: OneOrIterable[bool] = False, - nulls_last: OneOrIterable[bool] = False, - ) -> Self: - sort, opts = _parse_sort_by( - by, *more_by, descending=descending, nulls_last=nulls_last - ) - irs, schema_frozen, output_names = expr_expansion.prepare_projection( - sort, self.schema - ) - named_irs = expr_expansion.into_named_irs(irs, output_names) - return self._from_compliant(self._compliant.sort(named_irs, opts)) - - -class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): - _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] - - @property - def _series(self) -> type[Series[NativeSeriesT]]: - return Series[NativeSeriesT] - - # NOTE: Gave up on trying to get typing working for now - @classmethod - def from_native( # type: ignore[override] - cls, native: NativeFrame, / - ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: - if is_pyarrow_table(native): - from narwhals._plan.arrow.dataframe import ArrowDataFrame - - return ArrowDataFrame.from_native(native, cls._version).to_narwhals() - - raise NotImplementedError(type(native)) - - @overload - def to_dict( - self, *, as_series: Literal[True] = ... - ) -> dict[str, Series[NativeSeriesT]]: ... 
- @overload - def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... - @overload - def to_dict( - self, *, as_series: bool - ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: ... - def to_dict( - self, *, as_series: bool = True - ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: - if as_series: - return { - key: self._series._from_compliant(value) - for key, value in self._compliant.to_dict(as_series=as_series).items() - } - return self._compliant.to_dict(as_series=as_series) - - def __len__(self) -> int: - return len(self._compliant) - - -class Series(Generic[NativeSeriesT]): - _compliant: CompliantSeries[NativeSeriesT] - _version: ClassVar[Version] = Version.MAIN - - @property - def version(self) -> Version: - return self._version - - @property - def dtype(self) -> DType: - return self._compliant.dtype - - @property - def name(self) -> str: - return self._compliant.name - - # NOTE: Gave up on trying to get typing working for now - @classmethod - def from_native( - cls, native: NativeSeries, name: str = "", / - ) -> Series[pa.ChunkedArray[Any]]: - if is_pyarrow_chunked_array(native): - from narwhals._plan.arrow.series import ArrowSeries - - return ArrowSeries.from_native( - native, name, version=cls._version - ).to_narwhals() - - raise NotImplementedError(type(native)) - - @classmethod - def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self: - obj = cls.__new__(cls) - obj._compliant = compliant - return obj - - def to_native(self) -> NativeSeriesT: - return self._compliant.native - - def to_list(self) -> list[Any]: - return self._compliant.to_list() - - def __iter__(self) -> Iterator[Any]: - yield from self.to_native() - - -class SeriesV1(Series[NativeSeriesT]): - _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 127471720d..8f4348aaa3 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -24,10 +24,9 @@ import pandas as pd import polars as pl - from narwhals._plan.aggregation import AggExpr - from narwhals._plan.common import ExprIR, Function - from narwhals._plan.expr import FunctionExpr, WindowExpr - from narwhals._plan.operators import Operator + from narwhals._plan import expressions as ir + from narwhals._plan._function import Function + from narwhals._plan.expressions.operators import Operator from narwhals._plan.options import SortOptions from narwhals._plan.typing import IntoExpr, Seq @@ -37,13 +36,13 @@ # TODO @dangotbanned: Use arguments in error message -def agg_scalar_error(agg: AggExpr, scalar: ExprIR, /) -> InvalidOperationError: # noqa: ARG001 +def agg_scalar_error(agg: ir.AggExpr, scalar: ir.ExprIR, /) -> InvalidOperationError: # noqa: ARG001 msg = "Can't apply aggregations to scalar-like expressions." return InvalidOperationError(msg) def function_expr_invalid_operation_error( - function: Function, parent: ExprIR + function: Function, parent: ir.ExprIR ) -> InvalidOperationError: msg = f"Cannot use `{function!r}()` on aggregated expression `{parent!r}`." 
return InvalidOperationError(msg) @@ -57,7 +56,9 @@ def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001 # NOTE: Always underlining `right`, since the message refers to both types of exprs # Assuming the most recent as the issue -def binary_expr_shape_error(left: ExprIR, op: Operator, right: ExprIR) -> ShapeError: +def binary_expr_shape_error( + left: ir.ExprIR, op: Operator, right: ir.ExprIR +) -> ShapeError: lhs_op = f"{left!r} {op!r} " rhs = repr(right) indent = len(lhs_op) * " " @@ -71,7 +72,7 @@ def binary_expr_shape_error(left: ExprIR, op: Operator, right: ExprIR) -> ShapeE # TODO @dangotbanned: Share the right underline code w/ `binary_expr_shape_error` def binary_expr_multi_output_error( - left: ExprIR, op: Operator, right: ExprIR + left: ir.ExprIR, op: Operator, right: ir.ExprIR ) -> MultiOutputExpressionError: lhs_op = f"{left!r} {op!r} " rhs = repr(right) @@ -86,7 +87,7 @@ def binary_expr_multi_output_error( def binary_expr_length_changing_error( - left: ExprIR, op: Operator, right: ExprIR + left: ir.ExprIR, op: Operator, right: ir.ExprIR ) -> LengthChangingExprError: lhs, rhs = repr(left), repr(right) op_s = f" {op!r} " @@ -103,9 +104,9 @@ def binary_expr_length_changing_error( # TODO @dangotbanned: Use arguments in error message def over_nested_error( - expr: WindowExpr, # noqa: ARG001 - partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: Seq[ExprIR] = (), # noqa: ARG001 + expr: ir.WindowExpr, # noqa: ARG001 + partition_by: Seq[ir.ExprIR], # noqa: ARG001 + order_by: Seq[ir.ExprIR] = (), # noqa: ARG001 sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = "Cannot nest `over` statements." @@ -114,9 +115,9 @@ def over_nested_error( # TODO @dangotbanned: Use arguments in error message def over_elementwise_error( - expr: FunctionExpr[Function], - partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: Seq[ExprIR] = (), # noqa: ARG001 + expr: ir.FunctionExpr, + partition_by: Seq[ir.ExprIR], # noqa: ARG001 + order_by: Seq[ir.ExprIR] = (), # noqa: ARG001 sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}" @@ -125,9 +126,9 @@ def over_elementwise_error( # TODO @dangotbanned: Use arguments in error message def over_row_separable_error( - expr: FunctionExpr[Function], - partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: Seq[ExprIR] = (), # noqa: ARG001 + expr: ir.FunctionExpr, + partition_by: Seq[ir.ExprIR], # noqa: ARG001 + order_by: Seq[ir.ExprIR] = (), # noqa: ARG001 sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = f"Cannot use `over` on expressions which change length.\n{expr!r}" @@ -169,7 +170,7 @@ def is_iterable_polars_error( return TypeError(msg) -def duplicate_error(exprs: Seq[ExprIR]) -> DuplicateError: +def duplicate_error(exprs: Seq[ir.ExprIR]) -> DuplicateError: INDENT = "\n " # noqa: N806 names = [_output_name(expr) for expr in exprs] duplicates = {k for k, v in Counter(names).items() if v > 1} @@ -184,7 +185,7 @@ def duplicate_error(exprs: Seq[ExprIR]) -> DuplicateError: return DuplicateError(msg) -def _output_name(expr: ExprIR) -> str: +def _output_name(expr: ir.ExprIR) -> str: return expr.meta.output_name() diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index ecf4efe136..b0f369bd77 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -1,511 +1,697 @@ -"""Top-level `Expr` nodes.""" - from __future__ import annotations -# NOTE: Needed 
to avoid naming collisions -# - Literal -import typing as t - -from narwhals._plan.aggregation import AggExpr, OrderableAggExpr -from narwhals._plan.common import ExprIR, SelectorIR, collect -from narwhals._plan.exceptions import function_expr_invalid_operation_error -from narwhals._plan.name import KeepName, RenameAlias -from narwhals._plan.options import ExprIROptions -from narwhals._plan.typing import ( - FunctionT_co, - LeftSelectorT, - LeftT, - LiteralT, - OperatorT, - RangeT_co, - RightSelectorT, - RightT, - RollingT_co, - SelectorOperatorT, - SelectorT, - Seq, -) -from narwhals._utils import flatten -from narwhals.exceptions import InvalidOperationError - -if t.TYPE_CHECKING: - from typing_extensions import Self - - from narwhals._plan.functions import MapBatches # noqa: F401 - from narwhals._plan.literal import LiteralValue - from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions - from narwhals._plan.protocols import Ctx, FrameT_contra, R_co - from narwhals._plan.selectors import Selector - from narwhals._plan.window import Window - from narwhals.dtypes import DType - -__all__ = [ - "AggExpr", - "Alias", - "All", - "AnonymousExpr", - "BinaryExpr", - "BinarySelector", - "Cast", - "Column", - "Columns", - "Exclude", - "Filter", - "FunctionExpr", - "IndexColumns", - "KeepName", - "Len", - "Literal", - "Nth", - "OrderableAggExpr", - "RenameAlias", - "RollingExpr", - "RootSelector", - "SelectorIR", - "Sort", - "SortBy", - "TernaryExpr", - "WindowExpr", - "col", -] - - -def col(name: str, /) -> Column: - return Column(name=name) - - -def cols(*names: str) -> Columns: - return Columns(names=names) - - -def nth(index: int, /) -> Nth: - return Nth(index=index) - - -def index_columns(*indices: int) -> IndexColumns: - return IndexColumns(indices=indices) - - -class Alias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): - __slots__ = ("expr", "name") - expr: ExprIR - name: str +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, overload - @property - def is_scalar(self) -> bool: - return self.expr.is_scalar +from narwhals._plan import common, expressions as ir +from narwhals._plan._guards import is_column, is_expr, is_series +from narwhals._plan._parse import ( + parse_into_expr_ir, + parse_into_seq_of_expr_ir, + parse_predicates_constraints_into_expr_ir, +) +from narwhals._plan.expressions import ( + aggregation as agg, + functions as F, + operators as ops, +) +from narwhals._plan.expressions.selectors import by_name +from narwhals._plan.options import ( + EWMOptions, + RankOptions, + SortMultipleOptions, + SortOptions, + rolling_options, +) +from narwhals._utils import Version +from narwhals.exceptions import ComputeError, InvalidOperationError + +if TYPE_CHECKING: + from typing_extensions import Never, Self + + from narwhals._plan._function import Function + from narwhals._plan.expressions.categorical import ExprCatNamespace + from narwhals._plan.expressions.lists import ExprListNamespace + from narwhals._plan.expressions.name import ExprNameNamespace + from narwhals._plan.expressions.strings import ExprStringNamespace + from narwhals._plan.expressions.struct import ExprStructNamespace + from narwhals._plan.expressions.temporal import ExprDateTimeNamespace + from narwhals._plan.meta import MetaNamespace + from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf + from narwhals.typing import ( + ClosedInterval, + FillNullStrategy, + IntoDType, + 
NumericLiteral, + RankMethod, + RollingInterpolationMethod, + TemporalLiteral, + ) + + +# NOTE: Trying to keep consistent logic between `DataFrame.sort` and `Expr.sort_by` +def _parse_sort_by( + by: OneOrIterable[IntoExpr] = (), + *more_by: IntoExpr, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, +) -> tuple[Seq[ir.ExprIR], SortMultipleOptions]: + sort_by = parse_into_seq_of_expr_ir(by, *more_by) + if length_changing := next((e for e in sort_by if e.is_scalar), None): + msg = f"All expressions sort keys must preserve length, but got:\n{length_changing!r}" + raise InvalidOperationError(msg) + options = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last) + return sort_by, options + + +# NOTE: Overly simplified placeholders for mocking typing +# Entirely ignoring namespace + function binding +class Expr: + _ir: ir.ExprIR + _version: ClassVar[Version] = Version.MAIN def __repr__(self) -> str: - return f"{self.expr!r}.alias({self.name!r})" + return f"nw._plan.Expr({self.version.name.lower()}):\n{self._ir!r}" + def __str__(self) -> str: + """Use `print(self)` for formatting.""" + return f"nw._plan.Expr({self.version.name.lower()}):\n{self._ir!s}" -class Column(ExprIR, config=ExprIROptions.namespaced("col")): - __slots__ = ("name",) - name: str + def _repr_html_(self) -> str: + return self._ir._repr_html_() - def __repr__(self) -> str: - return f"col({self.name!r})" + @classmethod + def _from_ir(cls, expr_ir: ir.ExprIR, /) -> Self: + obj = cls.__new__(cls) + obj._ir = expr_ir + return obj + @property + def version(self) -> Version: + return self._version -class _ColumnSelection(ExprIR, config=ExprIROptions.no_dispatch()): - """Nodes which can resolve to `Column`(s) with a `Schema`.""" + def alias(self, name: str) -> Self: + return self._from_ir(self._ir.alias(name)) + def cast(self, dtype: IntoDType) -> Self: + return self._from_ir(self._ir.cast(common.into_dtype(dtype))) -class Columns(_ColumnSelection): - __slots__ = ("names",) - names: Seq[str] + def exclude(self, *names: OneOrIterable[str]) -> Self: + return self._from_ir(ir.Exclude.from_names(self._ir, *names)) - def __repr__(self) -> str: - return f"cols({list(self.names)!r})" + def count(self) -> Self: + return self._from_ir(agg.Count(expr=self._ir)) + def max(self) -> Self: + return self._from_ir(agg.Max(expr=self._ir)) -class Nth(_ColumnSelection): - __slots__ = ("index",) - index: int + def mean(self) -> Self: + return self._from_ir(agg.Mean(expr=self._ir)) - def __repr__(self) -> str: - return f"nth({self.index})" + def min(self) -> Self: + return self._from_ir(agg.Min(expr=self._ir)) + def median(self) -> Self: + return self._from_ir(agg.Median(expr=self._ir)) -class IndexColumns(_ColumnSelection): - __slots__ = ("indices",) - indices: Seq[int] + def n_unique(self) -> Self: + return self._from_ir(agg.NUnique(expr=self._ir)) - def __repr__(self) -> str: - return f"index_columns({self.indices!r})" + def sum(self) -> Self: + return self._from_ir(agg.Sum(expr=self._ir)) + def arg_min(self) -> Self: + return self._from_ir(agg.ArgMin(expr=self._ir)) -class All(_ColumnSelection): - def __repr__(self) -> str: - return "all()" + def arg_max(self) -> Self: + return self._from_ir(agg.ArgMax(expr=self._ir)) + def first(self) -> Self: + return self._from_ir(agg.First(expr=self._ir)) -class Exclude(_ColumnSelection, child=("expr",)): - __slots__ = ("expr", "names") - expr: ExprIR - """Default is `all()`.""" - names: Seq[str] - """Excluded names.""" + def last(self) -> Self: + return 
self._from_ir(agg.Last(expr=self._ir)) - @staticmethod - def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: - flat = flatten(names) - return Exclude(expr=expr, names=collect(flat)) + def var(self, *, ddof: int = 1) -> Self: + return self._from_ir(agg.Var(expr=self._ir, ddof=ddof)) - def __repr__(self) -> str: - return f"{self.expr!r}.exclude({list(self.names)!r})" + def std(self, *, ddof: int = 1) -> Self: + return self._from_ir(agg.Std(expr=self._ir, ddof=ddof)) + def quantile( + self, quantile: float, interpolation: RollingInterpolationMethod + ) -> Self: + return self._from_ir( + agg.Quantile(expr=self._ir, quantile=quantile, interpolation=interpolation) + ) -class Literal(ExprIR, t.Generic[LiteralT], config=ExprIROptions.namespaced("lit")): - """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" + def over( + self, + *partition_by: OneOrIterable[IntoExpr], + order_by: OneOrIterable[IntoExpr] = None, + descending: bool = False, + nulls_last: bool = False, + ) -> Self: + if not (partition_by) and order_by is None: + msg = "At least one of `partition_by` or `order_by` must be specified." + raise TypeError(msg) + parse = parse_into_seq_of_expr_ir + fn = self._ir + group = parse(*partition_by) if partition_by else () + if order_by is None: + return self._from_ir(ir.over(fn, group)) + over = ir.over_ordered + order = parse(order_by) + desc, nulls = descending, nulls_last + return self._from_ir(over(fn, group, order, descending=desc, nulls_last=nulls)) + + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: + options = SortOptions(descending=descending, nulls_last=nulls_last) + return self._from_ir(ir.Sort(expr=self._ir, options=options)) + + def sort_by( + self, + by: OneOrIterable[IntoExpr], + *more_by: IntoExpr, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, + ) -> Self: + keys, opts = _parse_sort_by( + by, *more_by, descending=descending, nulls_last=nulls_last + ) + return self._from_ir(ir.SortBy(expr=self._ir, by=keys, options=opts)) - __slots__ = ("value",) - value: LiteralValue[LiteralT] + def filter( + self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any + ) -> Self: + by = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) + return self._from_ir(ir.Filter(expr=self._ir, by=by)) - @property - def is_scalar(self) -> bool: - return self.value.is_scalar + def _with_unary(self, function: Function, /) -> Self: + return self._from_ir(function.to_function_expr(self._ir)) - @property - def dtype(self) -> DType: - return self.value.dtype + def abs(self) -> Self: + return self._with_unary(F.Abs()) - @property - def name(self) -> str: - return self.value.name + def hist( + self, + bins: Sequence[float] | None = None, + *, + bin_count: int | None = None, + include_breakpoint: bool = True, + ) -> Self: + node: F.Hist + if bins is not None: + if bin_count is not None: + msg = "can only provide one of `bin_count` or `bins`" + raise ComputeError(msg) + node = F.HistBins(bins=tuple(bins), include_breakpoint=include_breakpoint) + elif bin_count is not None: + node = F.HistBinCount( + bin_count=bin_count, include_breakpoint=include_breakpoint + ) + else: + node = F.HistBinCount(include_breakpoint=include_breakpoint) + return self._with_unary(node) - def __repr__(self) -> str: - return f"lit({self.value!r})" + def log(self, base: float = math.e) -> Self: + return self._with_unary(F.Log(base=base)) - def 
unwrap(self) -> LiteralT: - return self.value.unwrap() + def exp(self) -> Self: + return self._with_unary(F.Exp()) + def sqrt(self) -> Self: + return self._with_unary(F.Sqrt()) -class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): - __slots__ = ("left", "op", "right") - left: LeftT - op: OperatorT - right: RightT + def kurtosis(self, *, fisher: bool = True, bias: bool = True) -> Self: + return self._with_unary(F.Kurtosis(fisher=fisher, bias=bias)) - @property - def is_scalar(self) -> bool: - return self.left.is_scalar and self.right.is_scalar + def null_count(self) -> Self: + return self._with_unary(F.NullCount()) - def __repr__(self) -> str: - return f"[({self.left!r}) {self.op!r} ({self.right!r})]" + def fill_null( + self, + value: IntoExpr = None, + strategy: FillNullStrategy | None = None, + limit: int | None = None, + ) -> Self: + if strategy is None: + e = parse_into_expr_ir(value, str_as_lit=True) + return self._from_ir(F.FillNull().to_function_expr(self._ir, e)) + return self._with_unary(F.FillNullWithStrategy(strategy=strategy, limit=limit)) + def shift(self, n: int) -> Self: + return self._with_unary(F.Shift(n=n)) -class BinaryExpr( - _BinaryOp[LeftT, OperatorT, RightT], - t.Generic[LeftT, OperatorT, RightT], - child=("left", "right"), -): - """Application of two exprs via an `Operator`.""" + def drop_nulls(self) -> Self: + return self._with_unary(F.DropNulls()) - def iter_output_name(self) -> t.Iterator[ExprIR]: - yield from self.left.iter_output_name() + def mode(self) -> Self: + return self._with_unary(F.Mode()) + def skew(self) -> Self: + return self._with_unary(F.Skew()) -class Cast(ExprIR, child=("expr",)): - __slots__ = ("expr", "dtype") # noqa: RUF023 - expr: ExprIR - dtype: DType + def rank(self, method: RankMethod = "average", *, descending: bool = False) -> Self: + options = RankOptions(method=method, descending=descending) + return self._with_unary(F.Rank(options=options)) - @property - def is_scalar(self) -> bool: - return self.expr.is_scalar + def clip( + self, + lower_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, + upper_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, + ) -> Self: + it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) + return self._from_ir(F.Clip().to_function_expr(self._ir, *it)) - def __repr__(self) -> str: - return f"{self.expr!r}.cast({self.dtype!r})" + def cum_count(self, *, reverse: bool = False) -> Self: + return self._with_unary(F.CumCount(reverse=reverse)) - def iter_output_name(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_output_name() + def cum_min(self, *, reverse: bool = False) -> Self: + return self._with_unary(F.CumMin(reverse=reverse)) + def cum_max(self, *, reverse: bool = False) -> Self: + return self._with_unary(F.CumMax(reverse=reverse)) -class Sort(ExprIR, child=("expr",)): - __slots__ = ("expr", "options") - expr: ExprIR - options: SortOptions + def cum_prod(self, *, reverse: bool = False) -> Self: + return self._with_unary(F.CumProd(reverse=reverse)) - @property - def is_scalar(self) -> bool: - return self.expr.is_scalar + def cum_sum(self, *, reverse: bool = False) -> Self: + return self._with_unary(F.CumSum(reverse=reverse)) - def __repr__(self) -> str: - direction = "desc" if self.options.descending else "asc" - return f"{self.expr!r}.sort({direction})" + def rolling_sum( + self, window_size: int, *, min_samples: int | None = None, center: bool = False + ) -> Self: + options = rolling_options(window_size, min_samples, center=center) + return 
self._with_unary(F.RollingSum(options=options)) - def iter_output_name(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_output_name() + def rolling_mean( + self, window_size: int, *, min_samples: int | None = None, center: bool = False + ) -> Self: + options = rolling_options(window_size, min_samples, center=center) + return self._with_unary(F.RollingMean(options=options)) + def rolling_var( + self, + window_size: int, + *, + min_samples: int | None = None, + center: bool = False, + ddof: int = 1, + ) -> Self: + options = rolling_options(window_size, min_samples, center=center, ddof=ddof) + return self._with_unary(F.RollingVar(options=options)) + + def rolling_std( + self, + window_size: int, + *, + min_samples: int | None = None, + center: bool = False, + ddof: int = 1, + ) -> Self: + options = rolling_options(window_size, min_samples, center=center, ddof=ddof) + return self._with_unary(F.RollingStd(options=options)) -class SortBy(ExprIR, child=("expr", "by")): - """https://github.com/narwhals-dev/narwhals/issues/2534.""" + def diff(self) -> Self: + return self._with_unary(F.Diff()) - __slots__ = ("expr", "by", "options") # noqa: RUF023 - expr: ExprIR - by: Seq[ExprIR] - options: SortMultipleOptions + def unique(self) -> Self: + return self._with_unary(F.Unique()) - @property - def is_scalar(self) -> bool: - return self.expr.is_scalar + def round(self, decimals: int = 0) -> Self: + return self._with_unary(F.Round(decimals=decimals)) - def __repr__(self) -> str: - return f"{self.expr!r}.sort_by(by={self.by!r}, options={self.options!r})" + def ewm_mean( + self, + *, + com: float | None = None, + span: float | None = None, + half_life: float | None = None, + alpha: float | None = None, + adjust: bool = True, + min_samples: int = 1, + ignore_nulls: bool = False, + ) -> Self: + options = EWMOptions( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_samples=min_samples, + ignore_nulls=ignore_nulls, + ) + return self._with_unary(F.EwmMean(options=options)) - def iter_output_name(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_output_name() + def replace_strict( + self, + old: Sequence[Any] | Mapping[Any, Any], + new: Sequence[Any] | None = None, + *, + return_dtype: IntoDType | None = None, + ) -> Self: + before: Seq[Any] + after: Seq[Any] + if new is None: + if not isinstance(old, Mapping): + msg = "`new` argument is required if `old` argument is not a Mapping type" + raise TypeError(msg) + before = tuple(old) + after = tuple(old.values()) + elif isinstance(old, Mapping): + msg = "`new` argument cannot be used if `old` argument is a Mapping type" + raise TypeError(msg) + else: + before = tuple(old) + after = tuple(new) + if return_dtype is not None: + return_dtype = common.into_dtype(return_dtype) + function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) + return self._with_unary(function) + def gather_every(self, n: int, offset: int = 0) -> Self: + return self._with_unary(F.GatherEvery(n=n, offset=offset)) -# mypy: disable-error-code="misc" -class FunctionExpr(ExprIR, t.Generic[FunctionT_co], child=("input",)): - """**Representing `Expr::Function`**. 
+ def map_batches( + self, + function: Udf, + return_dtype: IntoDType | None = None, + *, + is_elementwise: bool = False, + returns_scalar: bool = False, + ) -> Self: + if return_dtype is not None: + return_dtype = common.into_dtype(return_dtype) + return self._with_unary( + F.MapBatches( + function=function, + return_dtype=return_dtype, + is_elementwise=is_elementwise, + returns_scalar=returns_scalar, + ) + ) - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 - """ + def any(self) -> Self: + return self._with_unary(ir.boolean.Any()) - __slots__ = ("function", "input", "options") - input: Seq[ExprIR] - # NOTE: mypy being mypy - the top error can't be silenced 🤦‍♂️ - # narwhals/_plan/expr.py: error: Cannot use a covariant type variable as a parameter [misc] - # narwhals/_plan/expr.py:272:15: error: Cannot use a covariant type variable as a parameter [misc] - # function: FunctionT_co # noqa: ERA001 - # ^ - # Found 2 errors in 1 file (checked 476 source files) - function: FunctionT_co - """Operation applied to each element of `input`.""" + def all(self) -> Self: + return self._with_unary(ir.boolean.All()) - options: FunctionOptions - """Combined flags from chained operations.""" + def is_duplicated(self) -> Self: + return self._with_unary(ir.boolean.IsDuplicated()) - @property - def is_scalar(self) -> bool: - return self.function.is_scalar + def is_finite(self) -> Self: + return self._with_unary(ir.boolean.IsFinite()) - def __repr__(self) -> str: - if self.input: - first = self.input[0] - if len(self.input) >= 2: - return f"{first!r}.{self.function!r}({list(self.input[1:])!r})" - return f"{first!r}.{self.function!r}()" - return f"{self.function!r}()" + def is_nan(self) -> Self: + return self._with_unary(ir.boolean.IsNan()) - def iter_output_name(self) -> t.Iterator[ExprIR]: - """When we have multiple inputs, we want the name of the left-most expression. 
+ def is_null(self) -> Self: + return self._with_unary(ir.boolean.IsNull()) - For expr: + def is_first_distinct(self) -> Self: + return self._with_unary(ir.boolean.IsFirstDistinct()) - col("c").alias("x").fill_null(50) + def is_last_distinct(self) -> Self: + return self._with_unary(ir.boolean.IsLastDistinct()) - We are interested in the name which comes from the root: + def is_unique(self) -> Self: + return self._with_unary(ir.boolean.IsUnique()) - FunctionExpr(..., [Alias(..., name='...'), Literal(...), ...]) - # ^^^^^ ^^^ - """ - for e in self.input[:1]: - yield from e.iter_output_name() + def is_between( + self, + lower_bound: IntoExpr, + upper_bound: IntoExpr, + closed: ClosedInterval = "both", + ) -> Self: + it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) + return self._from_ir( + ir.boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) + ) - def __init__( + def is_in(self, other: Iterable[Any]) -> Self: + if is_series(other): + return self._with_unary(ir.boolean.IsInSeries.from_series(other)) + if isinstance(other, Iterable): + return self._with_unary(ir.boolean.IsInSeq.from_iterable(other)) + if is_expr(other): + return self._with_unary(ir.boolean.IsInExpr(other=other._ir)) + msg = f"`is_in` only supports iterables, got: {type(other).__name__}" + raise TypeError(msg) + + def _with_binary( self, + op: type[ops.Operator], + other: IntoExpr, *, - input: Seq[ExprIR], # noqa: A002 - function: FunctionT_co, - options: FunctionOptions, - **kwds: t.Any, - ) -> None: - parent = input[0] - if parent.is_scalar and not options.is_elementwise(): - raise function_expr_invalid_operation_error(function, parent) - super().__init__(**dict(input=input, function=function, options=options, **kwds)) + str_as_lit: bool = False, + reflect: bool = False, + ) -> Self: + other_ir = parse_into_expr_ir(other, str_as_lit=str_as_lit) + args = (self._ir, other_ir) if not reflect else (other_ir, self._ir) + return self._from_ir(op().to_binary_expr(*args)) - def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str - ) -> R_co: - return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + def __eq__(self, other: IntoExpr) -> Self: # type: ignore[override] + return self._with_binary(ops.Eq, other, str_as_lit=True) + def __ne__(self, other: IntoExpr) -> Self: # type: ignore[override] + return self._with_binary(ops.NotEq, other, str_as_lit=True) -class RollingExpr(FunctionExpr[RollingT_co]): ... + def __lt__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Lt, other, str_as_lit=True) + def __le__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.LtEq, other, str_as_lit=True) -class AnonymousExpr( - FunctionExpr["MapBatches"], config=ExprIROptions.renamed("map_batches") -): - """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" + def __gt__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Gt, other, str_as_lit=True) - def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str - ) -> R_co: - return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + def __ge__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.GtEq, other, str_as_lit=True) + def __add__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Add, other, str_as_lit=True) -class RangeExpr(FunctionExpr[RangeT_co]): - """E.g. `int_range(...)`. 
+ def __radd__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Add, other, str_as_lit=True, reflect=True) - Special-cased as it is only allowed scalar inputs, and is row_separable. - """ + def __sub__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Sub, other) - def __init__( - self, - *, - input: Seq[ExprIR], # noqa: A002 - function: RangeT_co, - options: FunctionOptions, - **kwds: t.Any, - ) -> None: - # NOTE: `IntRange` has 2x scalar inputs, so always triggered error in parent - if len(input) < 2: - msg = f"Expected at least 2 inputs for `{function!r}()`, but got `{len(input)}`.\n`{input}`" - raise InvalidOperationError(msg) - if not all(e.is_scalar for e in input): - msg = f"All inputs for `{function!r}()` must be scalar or aggregations, but got \n`{input}`" - raise InvalidOperationError(msg) - super(ExprIR, self).__init__( - **dict(input=input, function=function, options=options, **kwds) - ) + def __rsub__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Sub, other, reflect=True) - def __repr__(self) -> str: - return f"{self.function!r}({list(self.input)!r})" + def __mul__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Multiply, other) + def __rmul__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Multiply, other, reflect=True) -class Filter(ExprIR, child=("expr", "by")): - __slots__ = ("expr", "by") # noqa: RUF023 - expr: ExprIR - by: ExprIR + def __truediv__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.TrueDivide, other) - @property - def is_scalar(self) -> bool: - return self.expr.is_scalar and self.by.is_scalar + def __rtruediv__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.TrueDivide, other, reflect=True) - def __repr__(self) -> str: - return f"{self.expr!r}.filter({self.by!r})" + def __floordiv__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.FloorDivide, other) - def iter_output_name(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_output_name() + def __rfloordiv__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.FloorDivide, other, reflect=True) + def __mod__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Modulus, other) -class WindowExpr( - ExprIR, child=("expr", "partition_by"), config=ExprIROptions.renamed("over") -): - """A fully specified `.over()`, that occurred after another expression. 
+ def __rmod__(self, other: IntoExpr) -> Self: + return self._with_binary(ops.Modulus, other, reflect=True) - Related: - - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136 - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L835-L838 - - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L840-L876 - """ + def __and__(self, other: IntoExprColumn | int | bool) -> Self: + return self._with_binary(ops.And, other) - __slots__ = ("expr", "partition_by", "options") # noqa: RUF023 - expr: ExprIR - """For lazy backends, this should be the only place we allow `rolling_*`, `cum_*`.""" - partition_by: Seq[ExprIR] - options: Window + def __rand__(self, other: IntoExprColumn | int | bool) -> Self: + return self._with_binary(ops.And, other, reflect=True) - def __repr__(self) -> str: - return f"{self.expr!r}.over({list(self.partition_by)!r})" + def __or__(self, other: IntoExprColumn | int | bool) -> Self: + return self._with_binary(ops.Or, other) - def iter_output_name(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_output_name() + def __ror__(self, other: IntoExprColumn | int | bool) -> Self: + return self._with_binary(ops.Or, other, reflect=True) + def __xor__(self, other: IntoExprColumn | int | bool) -> Self: + return self._with_binary(ops.ExclusiveOr, other) -class OrderedWindowExpr( - WindowExpr, - child=("expr", "partition_by", "order_by"), - config=ExprIROptions.renamed("over_ordered"), -): - __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 - expr: ExprIR - partition_by: Seq[ExprIR] - order_by: Seq[ExprIR] - sort_options: SortOptions - options: Window + def __rxor__(self, other: IntoExprColumn | int | bool) -> Self: + return self._with_binary(ops.ExclusiveOr, other, reflect=True) - def __repr__(self) -> str: - order = self.order_by - if not self.partition_by: - args = f"order_by={list(order)!r}" - else: - args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" - return f"{self.expr!r}.over({args})" + def __pow__(self, exponent: IntoExprColumn | float) -> Self: + exp = parse_into_expr_ir(exponent) + return self._from_ir(F.Pow().to_function_expr(self._ir, exp)) - def iter_root_names(self) -> t.Iterator[ExprIR]: - # NOTE: `order_by` is never considered in `polars` - # To match that behavior for `root_names` - but still expand in all other cases - # - this little escape hatch exists - # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 - yield from self.expr.iter_left() - for e in self.partition_by: - yield from e.iter_left() - yield self + def __rpow__(self, base: IntoExprColumn | float) -> Self: + return self._from_ir(F.Pow().to_function_expr(parse_into_expr_ir(base), self._ir)) + def __invert__(self) -> Self: + return self._with_unary(ir.boolean.Not()) -class Len(ExprIR, config=ExprIROptions.namespaced()): @property - def is_scalar(self) -> bool: - return True - - @property - def name(self) -> str: - return "len" - - def __repr__(self) -> str: - return "len()" - + def meta(self) -> MetaNamespace: + from narwhals._plan.meta import MetaNamespace -class RootSelector(SelectorIR): - """A single selector expression.""" + return MetaNamespace.from_expr(self) - __slots__ = ("selector",) - selector: Selector + @property + def name(self) -> ExprNameNamespace: 
+ """Specialized expressions for modifying the name of existing expressions. + + Examples: + >>> from narwhals import _plan as nw + >>> + >>> renamed = nw.col("a", "b").name.suffix("_changed") + >>> str(renamed._ir) + "RenameAlias(expr=Columns(names=[a, b]), function=Suffix(suffix='_changed'))" + """ + from narwhals._plan.expressions.name import ExprNameNamespace - def __repr__(self) -> str: - return f"{self.selector!r}" + return ExprNameNamespace(_expr=self) - def matches_column(self, name: str, dtype: DType) -> bool: - return self.selector.matches_column(name, dtype) + @property + def cat(self) -> ExprCatNamespace: + from narwhals._plan.expressions.categorical import ExprCatNamespace + return ExprCatNamespace(_expr=self) -class BinarySelector( - _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], - SelectorIR, - t.Generic[LeftSelectorT, SelectorOperatorT, RightSelectorT], -): - """Application of two selector exprs via a set operator.""" + @property + def struct(self) -> ExprStructNamespace: + from narwhals._plan.expressions.struct import ExprStructNamespace - def matches_column(self, name: str, dtype: DType) -> bool: - left = self.left.matches_column(name, dtype) - right = self.right.matches_column(name, dtype) - return bool(self.op(left, right)) + return ExprStructNamespace(_expr=self) + @property + def dt(self) -> ExprDateTimeNamespace: + from narwhals._plan.expressions.temporal import ExprDateTimeNamespace -class InvertSelector(SelectorIR, t.Generic[SelectorT]): - __slots__ = ("selector",) - selector: SelectorT + return ExprDateTimeNamespace(_expr=self) - def __repr__(self) -> str: - return f"~{self.selector!r}" + @property + def list(self) -> ExprListNamespace: + from narwhals._plan.expressions.lists import ExprListNamespace - def matches_column(self, name: str, dtype: DType) -> bool: - return not self.selector.matches_column(name, dtype) + return ExprListNamespace(_expr=self) + @property + def str(self) -> ExprStringNamespace: + from narwhals._plan.expressions.strings import ExprStringNamespace -class TernaryExpr(ExprIR, child=("truthy", "falsy", "predicate")): - """When-Then-Otherwise.""" + return ExprStringNamespace(_expr=self) - __slots__ = ("truthy", "falsy", "predicate") # noqa: RUF023 - predicate: ExprIR - truthy: ExprIR - falsy: ExprIR - @property - def is_scalar(self) -> bool: - return self.predicate.is_scalar and self.truthy.is_scalar and self.falsy.is_scalar +class Selector(Expr): + _ir: ir.SelectorIR def __repr__(self) -> str: - return ( - f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" - ) - - def iter_output_name(self) -> t.Iterator[ExprIR]: - yield from self.truthy.iter_output_name() + return f"nw._plan.Selector({self.version.name.lower()}):\n{self._ir!r}" + + @classmethod + def _from_ir(cls, selector_ir: ir.SelectorIR, /) -> Self: # type: ignore[override] + obj = cls.__new__(cls) + obj._ir = selector_ir + return obj + + def _to_expr(self) -> Expr: + return self._ir.to_narwhals(self.version) + + @overload # type: ignore[override] + def __or__(self, other: Self) -> Self: ... + @overload + def __or__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __or__(self, other: IntoExprColumn | int | bool) -> Self | Expr: + if isinstance(other, type(self)): + op = ops.Or() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) + return self._to_expr() | other + + @overload # type: ignore[override] + def __and__(self, other: Self) -> Self: ... 
+ @overload + def __and__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __and__(self, other: IntoExprColumn | int | bool) -> Self | Expr: + if is_column(other) and (name := other.meta.output_name()): + other = by_name(name) + if isinstance(other, type(self)): + op = ops.And() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) + return self._to_expr() & other + + @overload # type: ignore[override] + def __sub__(self, other: Self) -> Self: ... + @overload + def __sub__(self, other: IntoExpr) -> Expr: ... + def __sub__(self, other: IntoExpr) -> Self | Expr: + if isinstance(other, type(self)): + op = ops.Sub() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) + return self._to_expr() - other + + @overload # type: ignore[override] + def __xor__(self, other: Self) -> Self: ... + @overload + def __xor__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: + if isinstance(other, type(self)): + op = ops.ExclusiveOr() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) + return self._to_expr() ^ other + + def __invert__(self) -> Self: + return self._from_ir(ir.InvertSelector(selector=self._ir)) + + def __add__(self, other: Any) -> Expr: # type: ignore[override] + if isinstance(other, type(self)): + msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" + raise TypeError(msg) + return self._to_expr() + other # type: ignore[no-any-return] + + def __radd__(self, other: Any) -> Never: + msg = "unsupported operand type(s) for op: ('Expr' + 'Selector')" + raise TypeError(msg) + + def __rsub__(self, other: Any) -> Never: + msg = "unsupported operand type(s) for op: ('Expr' - 'Selector')" + raise TypeError(msg) + + @overload # type: ignore[override] + def __rand__(self, other: Self) -> Self: ... + @overload + def __rand__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __rand__(self, other: IntoExprColumn | int | bool) -> Self | Expr: + if is_column(other) and (name := other.meta.output_name()): + return by_name(name) & self + return self._to_expr().__rand__(other) + + @overload # type: ignore[override] + def __ror__(self, other: Self) -> Self: ... + @overload + def __ror__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __ror__(self, other: IntoExprColumn | int | bool) -> Self | Expr: + if is_column(other) and (name := other.meta.output_name()): + return by_name(name) | self + return self._to_expr().__ror__(other) + + @overload # type: ignore[override] + def __rxor__(self, other: Self) -> Self: ... + @overload + def __rxor__(self, other: IntoExprColumn | int | bool) -> Expr: ... + def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: + if is_column(other) and (name := other.meta.output_name()): + return by_name(name) ^ self + return self._to_expr().__rxor__(other) + + +class ExprV1(Expr): + _version: ClassVar[Version] = Version.V1 + + +class SelectorV1(Selector): + _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py new file mode 100644 index 0000000000..237ee36e81 --- /dev/null +++ b/narwhals/_plan/expressions/__init__.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from narwhals._plan._expr_ir import ( # prob should move into package? 
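# A rough sketch of the selector set algebra defined above (illustrative; the selector
# constructors live in `narwhals/_plan/expressions/selectors.py`, renamed later in this
# patch):
#
#     from narwhals import _plan as nw
#     from narwhals._plan.expressions import selectors as ncs
#
#     ncs.numeric() | ncs.boolean()      # Selector | Selector -> still a Selector (BinarySelector)
#     ncs.numeric() - ncs.by_name("a")   # set difference, still a Selector
#     ncs.numeric() & nw.col("a")        # a bare column is promoted via by_name("a"), still a Selector
#     ncs.numeric() | (nw.col("a") > 0)  # mixing with a non-column Expr falls back to Expr boolean ops
#     ncs.numeric() + ncs.boolean()      # raises TypeError: addition is not a set operation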
+ ExprIR, + NamedIR, + SelectorIR, +) +from narwhals._plan.expressions import ( + aggregation, + boolean, + functions, + operators, + selectors, +) +from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr +from narwhals._plan.expressions.expr import ( + Alias, + All, + AnonymousExpr, + BinaryExpr, + BinarySelector, + Cast, + Column, + Columns, + Exclude, + Filter, + FunctionExpr, + IndexColumns, + InvertSelector, + Len, + Literal, + Nth, + OrderedWindowExpr, + RangeExpr, + RollingExpr, + RootSelector, + Sort, + SortBy, + TernaryExpr, + WindowExpr, + _ColumnSelection, # if needs exposing, make it public! + col, + cols, + index_columns, + nth, +) +from narwhals._plan.expressions.name import KeepName, RenameAlias +from narwhals._plan.expressions.window import over, over_ordered + +__all__ = [ + "AggExpr", + "Alias", + "All", + "AnonymousExpr", + "BinaryExpr", + "BinarySelector", + "Cast", + "Column", + "Columns", + "Exclude", + "ExprIR", + "Filter", + "FunctionExpr", + "IndexColumns", + "InvertSelector", + "KeepName", + "Len", + "Literal", + "NamedIR", + "Nth", + "OrderableAggExpr", + "OrderedWindowExpr", + "RangeExpr", + "RenameAlias", + "RollingExpr", + "RootSelector", + "SelectorIR", + "Sort", + "SortBy", + "TernaryExpr", + "WindowExpr", + "_ColumnSelection", + "aggregation", + "boolean", + "col", + "cols", + "functions", + "index_columns", + "nth", + "operators", + "over", + "over_ordered", + "selectors", +] diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/expressions/aggregation.py similarity index 89% rename from narwhals/_plan/aggregation.py rename to narwhals/_plan/expressions/aggregation.py index b1f47ca1d7..263ca300e5 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import ExprIR, _pascal_to_snake_case +from narwhals._plan._expr_ir import ExprIR +from narwhals._plan.common import pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: @@ -20,7 +21,7 @@ def is_scalar(self) -> bool: return True def __repr__(self) -> str: - return f"{self.expr!r}.{_pascal_to_snake_case(type(self).__name__)}()" + return f"{self.expr!r}.{pascal_to_snake_case(type(self).__name__)}()" def iter_output_name(self) -> Iterator[ExprIR]: yield from self.expr.iter_output_name() diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/expressions/boolean.py similarity index 90% rename from narwhals/_plan/boolean.py rename to narwhals/_plan/expressions/boolean.py index 23f7d27dd3..ebc2a8643b 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -4,16 +4,16 @@ # - Any import typing as t -from narwhals._plan.common import Function, HorizontalFunction +from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.common import ExprIR - from narwhals._plan.dummy import Series - from narwhals._plan.expr import FunctionExpr, Literal # noqa: F401 + from narwhals._plan._expr_ir import ExprIR + from narwhals._plan.expressions.expr import FunctionExpr, Literal # noqa: F401 + from narwhals._plan.series import Series from narwhals._plan.typing import NativeSeriesT, Seq # noqa: F401 from narwhals.typing import ClosedInterval @@ -68,7 +68,7 @@ def from_iterable(cls, other: t.Iterable[t.Any], /) -> IsInSeq: class 
IsInSeries(IsIn["Literal[Series[NativeSeriesT]]"]): @classmethod def from_series(cls, other: Series[NativeSeriesT], /) -> IsInSeries[NativeSeriesT]: - from narwhals._plan.literal import SeriesLiteral + from narwhals._plan.expressions.literal import SeriesLiteral return IsInSeries(other=SeriesLiteral(value=other).to_literal()) diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/expressions/categorical.py similarity index 77% rename from narwhals/_plan/categorical.py rename to narwhals/_plan/expressions/categorical.py index 13791bed16..7c59fd4443 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/expressions/categorical.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan._function import Function +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr # fmt: off diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py new file mode 100644 index 0000000000..a898bf879b --- /dev/null +++ b/narwhals/_plan/expressions/expr.py @@ -0,0 +1,505 @@ +"""Top-level `Expr` nodes.""" + +from __future__ import annotations + +# NOTE: Needed to avoid naming collisions +# - Literal +import typing as t + +from narwhals._plan._expr_ir import ExprIR, SelectorIR +from narwhals._plan.common import flatten_hash_safe +from narwhals._plan.exceptions import function_expr_invalid_operation_error +from narwhals._plan.options import ExprIROptions +from narwhals._plan.typing import ( + FunctionT_co, + LeftSelectorT, + LeftT, + LiteralT, + OperatorT, + RangeT_co, + RightSelectorT, + RightT, + RollingT_co, + SelectorOperatorT, + SelectorT, + Seq, +) +from narwhals.exceptions import InvalidOperationError + +if t.TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan.expressions.functions import MapBatches # noqa: F401 + from narwhals._plan.expressions.literal import LiteralValue + from narwhals._plan.expressions.selectors import Selector + from narwhals._plan.expressions.window import Window + from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions + from narwhals._plan.protocols import Ctx, FrameT_contra, R_co + from narwhals.dtypes import DType + +__all__ = [ + "Alias", + "All", + "AnonymousExpr", + "BinaryExpr", + "BinarySelector", + "Cast", + "Column", + "Columns", + "Exclude", + "Filter", + "FunctionExpr", + "IndexColumns", + "Len", + "Literal", + "Nth", + "RollingExpr", + "RootSelector", + "SelectorIR", + "Sort", + "SortBy", + "TernaryExpr", + "WindowExpr", + "col", +] + + +def col(name: str, /) -> Column: + return Column(name=name) + + +def cols(*names: str) -> Columns: + return Columns(names=names) + + +def nth(index: int, /) -> Nth: + return Nth(index=index) + + +def index_columns(*indices: int) -> IndexColumns: + return IndexColumns(indices=indices) + + +class Alias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): + __slots__ = ("expr", "name") + expr: ExprIR + name: str + + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + + def __repr__(self) -> str: + return f"{self.expr!r}.alias({self.name!r})" + + +class Column(ExprIR, config=ExprIROptions.namespaced("col")): + __slots__ = ("name",) + name: str + + def __repr__(self) -> str: + return f"col({self.name!r})" + + +class _ColumnSelection(ExprIR, config=ExprIROptions.no_dispatch()): + """Nodes which can resolve to 
`Column`(s) with a `Schema`.""" + + +class Columns(_ColumnSelection): + __slots__ = ("names",) + names: Seq[str] + + def __repr__(self) -> str: + return f"cols({list(self.names)!r})" + + +class Nth(_ColumnSelection): + __slots__ = ("index",) + index: int + + def __repr__(self) -> str: + return f"nth({self.index})" + + +class IndexColumns(_ColumnSelection): + __slots__ = ("indices",) + indices: Seq[int] + + def __repr__(self) -> str: + return f"index_columns({self.indices!r})" + + +class All(_ColumnSelection): + def __repr__(self) -> str: + return "all()" + + +class Exclude(_ColumnSelection, child=("expr",)): + __slots__ = ("expr", "names") + expr: ExprIR + """Default is `all()`.""" + names: Seq[str] + """Excluded names.""" + + @staticmethod + def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: + flat: t.Iterator[str] = flatten_hash_safe(names) + return Exclude(expr=expr, names=tuple(flat)) + + def __repr__(self) -> str: + return f"{self.expr!r}.exclude({list(self.names)!r})" + + +class Literal(ExprIR, t.Generic[LiteralT], config=ExprIROptions.namespaced("lit")): + """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" + + __slots__ = ("value",) + value: LiteralValue[LiteralT] + + @property + def is_scalar(self) -> bool: + return self.value.is_scalar + + @property + def dtype(self) -> DType: + return self.value.dtype + + @property + def name(self) -> str: + return self.value.name + + def __repr__(self) -> str: + return f"lit({self.value!r})" + + def unwrap(self) -> LiteralT: + return self.value.unwrap() + + +class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): + __slots__ = ("left", "op", "right") + left: LeftT + op: OperatorT + right: RightT + + @property + def is_scalar(self) -> bool: + return self.left.is_scalar and self.right.is_scalar + + def __repr__(self) -> str: + return f"[({self.left!r}) {self.op!r} ({self.right!r})]" + + +class BinaryExpr( + _BinaryOp[LeftT, OperatorT, RightT], + t.Generic[LeftT, OperatorT, RightT], + child=("left", "right"), +): + """Application of two exprs via an `Operator`.""" + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.left.iter_output_name() + + +class Cast(ExprIR, child=("expr",)): + __slots__ = ("expr", "dtype") # noqa: RUF023 + expr: ExprIR + dtype: DType + + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + + def __repr__(self) -> str: + return f"{self.expr!r}.cast({self.dtype!r})" + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + + +class Sort(ExprIR, child=("expr",)): + __slots__ = ("expr", "options") + expr: ExprIR + options: SortOptions + + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + + def __repr__(self) -> str: + direction = "desc" if self.options.descending else "asc" + return f"{self.expr!r}.sort({direction})" + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + + +class SortBy(ExprIR, child=("expr", "by")): + """https://github.com/narwhals-dev/narwhals/issues/2534.""" + + __slots__ = ("expr", "by", "options") # noqa: RUF023 + expr: ExprIR + by: Seq[ExprIR] + options: SortMultipleOptions + + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar + + def __repr__(self) -> str: + return f"{self.expr!r}.sort_by(by={self.by!r}, options={self.options!r})" + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + + +# mypy: 
disable-error-code="misc" +class FunctionExpr(ExprIR, t.Generic[FunctionT_co], child=("input",)): + """**Representing `Expr::Function`**. + + https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 + """ + + __slots__ = ("function", "input", "options") + input: Seq[ExprIR] + # NOTE: mypy being mypy - the top error can't be silenced 🤦‍♂️ + # narwhals/_plan/expr.py: error: Cannot use a covariant type variable as a parameter [misc] + # narwhals/_plan/expr.py:272:15: error: Cannot use a covariant type variable as a parameter [misc] + # function: FunctionT_co # noqa: ERA001 + # ^ + # Found 2 errors in 1 file (checked 476 source files) + function: FunctionT_co + """Operation applied to each element of `input`.""" + + options: FunctionOptions + """Combined flags from chained operations.""" + + @property + def is_scalar(self) -> bool: + return self.function.is_scalar + + def __repr__(self) -> str: + if self.input: + first = self.input[0] + if len(self.input) >= 2: + return f"{first!r}.{self.function!r}({list(self.input[1:])!r})" + return f"{first!r}.{self.function!r}()" + return f"{self.function!r}()" + + def iter_output_name(self) -> t.Iterator[ExprIR]: + """When we have multiple inputs, we want the name of the left-most expression. + + For expr: + + col("c").alias("x").fill_null(50) + + We are interested in the name which comes from the root: + + FunctionExpr(..., [Alias(..., name='...'), Literal(...), ...]) + # ^^^^^ ^^^ + """ + for e in self.input[:1]: + yield from e.iter_output_name() + + def __init__( + self, + *, + input: Seq[ExprIR], # noqa: A002 + function: FunctionT_co, + options: FunctionOptions, + **kwds: t.Any, + ) -> None: + parent = input[0] + if parent.is_scalar and not options.is_elementwise(): + raise function_expr_invalid_operation_error(function, parent) + super().__init__(**dict(input=input, function=function, options=options, **kwds)) + + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + + +class RollingExpr(FunctionExpr[RollingT_co]): ... + + +class AnonymousExpr( + FunctionExpr["MapBatches"], config=ExprIROptions.renamed("map_batches") +): + """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" + + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + + +class RangeExpr(FunctionExpr[RangeT_co]): + """E.g. `int_range(...)`. + + Special-cased as it is only allowed scalar inputs, and is row_separable. 
+ """ + + def __init__( + self, + *, + input: Seq[ExprIR], # noqa: A002 + function: RangeT_co, + options: FunctionOptions, + **kwds: t.Any, + ) -> None: + # NOTE: `IntRange` has 2x scalar inputs, so always triggered error in parent + if len(input) < 2: + msg = f"Expected at least 2 inputs for `{function!r}()`, but got `{len(input)}`.\n`{input}`" + raise InvalidOperationError(msg) + if not all(e.is_scalar for e in input): + msg = f"All inputs for `{function!r}()` must be scalar or aggregations, but got \n`{input}`" + raise InvalidOperationError(msg) + super(ExprIR, self).__init__( + **dict(input=input, function=function, options=options, **kwds) + ) + + def __repr__(self) -> str: + return f"{self.function!r}({list(self.input)!r})" + + +class Filter(ExprIR, child=("expr", "by")): + __slots__ = ("expr", "by") # noqa: RUF023 + expr: ExprIR + by: ExprIR + + @property + def is_scalar(self) -> bool: + return self.expr.is_scalar and self.by.is_scalar + + def __repr__(self) -> str: + return f"{self.expr!r}.filter({self.by!r})" + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + + +class WindowExpr( + ExprIR, child=("expr", "partition_by"), config=ExprIROptions.renamed("over") +): + """A fully specified `.over()`, that occurred after another expression. + + Related: + - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L129-L136 + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L835-L838 + - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L840-L876 + """ + + __slots__ = ("expr", "partition_by", "options") # noqa: RUF023 + expr: ExprIR + """For lazy backends, this should be the only place we allow `rolling_*`, `cum_*`.""" + partition_by: Seq[ExprIR] + options: Window + + def __repr__(self) -> str: + return f"{self.expr!r}.over({list(self.partition_by)!r})" + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.expr.iter_output_name() + + +class OrderedWindowExpr( + WindowExpr, + child=("expr", "partition_by", "order_by"), + config=ExprIROptions.renamed("over_ordered"), +): + __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 + expr: ExprIR + partition_by: Seq[ExprIR] + order_by: Seq[ExprIR] + sort_options: SortOptions + options: Window + + def __repr__(self) -> str: + order = self.order_by + if not self.partition_by: + args = f"order_by={list(order)!r}" + else: + args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" + return f"{self.expr!r}.over({args})" + + def iter_root_names(self) -> t.Iterator[ExprIR]: + # NOTE: `order_by` is never considered in `polars` + # To match that behavior for `root_names` - but still expand in all other cases + # - this little escape hatch exists + # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 + yield from self.expr.iter_left() + for e in self.partition_by: + yield from e.iter_left() + yield self + + +class Len(ExprIR, config=ExprIROptions.namespaced()): + @property + def is_scalar(self) -> bool: + return True + + @property + def name(self) -> str: + return "len" + + def __repr__(self) -> str: + return "len()" + + +class RootSelector(SelectorIR): + """A single selector expression.""" + + __slots__ = ("selector",) + selector: Selector + + def __repr__(self) -> str: + 
return f"{self.selector!r}" + + def matches_column(self, name: str, dtype: DType) -> bool: + return self.selector.matches_column(name, dtype) + + +class BinarySelector( + _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], + SelectorIR, + t.Generic[LeftSelectorT, SelectorOperatorT, RightSelectorT], +): + """Application of two selector exprs via a set operator.""" + + def matches_column(self, name: str, dtype: DType) -> bool: + left = self.left.matches_column(name, dtype) + right = self.right.matches_column(name, dtype) + return bool(self.op(left, right)) + + +class InvertSelector(SelectorIR, t.Generic[SelectorT]): + __slots__ = ("selector",) + selector: SelectorT + + def __repr__(self) -> str: + return f"~{self.selector!r}" + + def matches_column(self, name: str, dtype: DType) -> bool: + return not self.selector.matches_column(name, dtype) + + +class TernaryExpr(ExprIR, child=("truthy", "falsy", "predicate")): + """When-Then-Otherwise.""" + + __slots__ = ("truthy", "falsy", "predicate") # noqa: RUF023 + predicate: ExprIR + truthy: ExprIR + falsy: ExprIR + + @property + def is_scalar(self) -> bool: + return self.predicate.is_scalar and self.truthy.is_scalar and self.falsy.is_scalar + + def __repr__(self) -> str: + return ( + f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" + ) + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.truthy.iter_output_name() diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py new file mode 100644 index 0000000000..f7ff80abe5 --- /dev/null +++ b/narwhals/_plan/expressions/functions.py @@ -0,0 +1,182 @@ +"""General functions that aren't namespaced.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan._function import Function, HorizontalFunction +from narwhals._plan.exceptions import hist_bins_monotonic_error +from narwhals._plan.options import FunctionFlags, FunctionOptions + +if TYPE_CHECKING: + from typing import Any + + from typing_extensions import Self + + from narwhals._plan._expr_ir import ExprIR + from narwhals._plan.expressions.expr import AnonymousExpr, FunctionExpr, RollingExpr + from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow + from narwhals._plan.typing import Seq, Udf + from narwhals.dtypes import DType + from narwhals.typing import FillNullStrategy + + +class CumAgg(Function, options=FunctionOptions.length_preserving): + __slots__ = ("reverse",) + reverse: bool + + +class RollingWindow(Function, options=FunctionOptions.length_preserving): + __slots__ = ("options",) + options: RollingOptionsFixedWindow + + def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: + from narwhals._plan.expressions.expr import RollingExpr + + options = self.function_options + return RollingExpr(input=inputs, function=self, options=options) + + +# fmt: off +class Abs(Function, options=FunctionOptions.elementwise): ... +class NullCount(Function, options=FunctionOptions.aggregation): ... +class Exp(Function, options=FunctionOptions.elementwise): ... +class Sqrt(Function, options=FunctionOptions.elementwise): ... +class DropNulls(Function, options=FunctionOptions.row_separable): ... +class Mode(Function): ... +class Skew(Function, options=FunctionOptions.aggregation): ... +class Clip(Function, options=FunctionOptions.elementwise): ... +class CumCount(CumAgg): ... +class CumMin(CumAgg): ... +class CumMax(CumAgg): ... +class CumProd(CumAgg): ... +class CumSum(CumAgg): ... 
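# A rough sketch of how the `options=` class keyword on these one-liner functions
# interacts with the `FunctionExpr` constructor validation above (illustrative; assumes
# `length_preserving` does not imply `elementwise`, so a scalar input is rejected):
#
#     from narwhals._plan import expressions as ir
#     from narwhals._plan.expressions import functions as F
#
#     F.Abs().to_function_expr(ir.Len())   # ok: Abs is elementwise, a scalar input is fine
#     F.Diff().to_function_expr(ir.Len())  # raises: Diff only preserves length, so a scalar
#                                          # input has nothing to diff against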
+class RollingSum(RollingWindow): ... +class RollingMean(RollingWindow): ... +class RollingVar(RollingWindow): ... +class RollingStd(RollingWindow): ... +class Diff(Function, options=FunctionOptions.length_preserving): ... +class Unique(Function): ... +class SumHorizontal(HorizontalFunction): ... +class MinHorizontal(HorizontalFunction): ... +class MaxHorizontal(HorizontalFunction): ... +class MeanHorizontal(HorizontalFunction): ... +# fmt: on +class Hist(Function): + """Only supported for `Series` so far.""" + + __slots__ = ("include_breakpoint",) + include_breakpoint: bool + + def __repr__(self) -> str: + return "hist" + + +class HistBins(Hist): + __slots__ = ("bins", *Hist.__slots__) + bins: Seq[float] + + def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None: + for i in range(1, len(bins)): + if bins[i - 1] >= bins[i]: + raise hist_bins_monotonic_error(bins) + object.__setattr__(self, "bins", bins) + object.__setattr__(self, "include_breakpoint", include_breakpoint) + + +class HistBinCount(Hist): + __slots__ = ("bin_count", *Hist.__slots__) + bin_count: int + + def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> None: + object.__setattr__(self, "bin_count", bin_count) + object.__setattr__(self, "include_breakpoint", include_breakpoint) + + +class Log(Function, options=FunctionOptions.elementwise): + __slots__ = ("base",) + base: float + + +class Pow(Function, options=FunctionOptions.elementwise): + """N-ary (base, exponent).""" + + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + base, exponent = node.input + return base, exponent + + +class Kurtosis(Function, options=FunctionOptions.aggregation): + __slots__ = ("bias", "fisher") + fisher: bool + bias: bool + + +class FillNull(Function, options=FunctionOptions.elementwise): + """N-ary (expr, value).""" + + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, value = node.input + return expr, value + + +class FillNullWithStrategy(Function): + __slots__ = ("limit", "strategy") + strategy: FillNullStrategy + limit: int | None + + +class Shift(Function, options=FunctionOptions.length_preserving): + __slots__ = ("n",) + n: int + + +class Rank(Function): + __slots__ = ("options",) + options: RankOptions + + +class Round(Function, options=FunctionOptions.elementwise): + __slots__ = ("decimals",) + decimals: int + + +class EwmMean(Function, options=FunctionOptions.length_preserving): + __slots__ = ("options",) + options: EWMOptions + + +class ReplaceStrict(Function, options=FunctionOptions.elementwise): + __slots__ = ("new", "old", "return_dtype") + old: Seq[Any] + new: Seq[Any] + return_dtype: DType | None + + +class GatherEvery(Function): + __slots__ = ("n", "offset") + n: int + offset: int + + +class MapBatches(Function): + __slots__ = ("function", "is_elementwise", "return_dtype", "returns_scalar") + function: Udf + return_dtype: DType | None + is_elementwise: bool + returns_scalar: bool + + @property + def function_options(self) -> FunctionOptions: + options = super().function_options + if self.is_elementwise: + options = options.with_elementwise() + if self.returns_scalar: + options = options.with_flags(FunctionFlags.RETURNS_SCALAR) + return options + + def to_function_expr(self, *inputs: ExprIR) -> AnonymousExpr: + from narwhals._plan.expressions.expr import AnonymousExpr + + options = self.function_options + return AnonymousExpr(input=inputs, function=self, options=options) diff --git a/narwhals/_plan/lists.py 
b/narwhals/_plan/expressions/lists.py similarity index 78% rename from narwhals/_plan/lists.py rename to narwhals/_plan/expressions/lists.py index f4a45f217f..604e054a5e 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -2,11 +2,12 @@ from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan._function import Function +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr # fmt: off diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/expressions/literal.py similarity index 92% rename from narwhals/_plan/literal.py rename to narwhals/_plan/expressions/literal.py index 22c170508c..7d46c8436c 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/expressions/literal.py @@ -9,8 +9,8 @@ if TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan.dummy import Series - from narwhals._plan.expr import Literal + from narwhals._plan.expressions.expr import Literal + from narwhals._plan.series import Series from narwhals.dtypes import DType @@ -30,7 +30,7 @@ def is_scalar(self) -> bool: return False def to_literal(self) -> Literal[LiteralT]: - from narwhals._plan.expr import Literal + from narwhals._plan.expressions.expr import Literal return Literal(value=self) diff --git a/narwhals/_plan/name.py b/narwhals/_plan/expressions/name.py similarity index 85% rename from narwhals/_plan/name.py rename to narwhals/_plan/expressions/name.py index 4147f20450..7f1e71fb09 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/expressions/name.py @@ -2,18 +2,19 @@ from typing import TYPE_CHECKING -from narwhals._plan import common +from narwhals._plan._expr_ir import ExprIR from narwhals._plan._immutable import Immutable +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import ExprIROptions if TYPE_CHECKING: from narwhals._compliant.typing import AliasName - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr -class KeepName(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): +class KeepName(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr",) - expr: common.ExprIR + expr: ExprIR @property def is_scalar(self) -> bool: @@ -23,9 +24,9 @@ def __repr__(self) -> str: return f"{self.expr!r}.name.keep()" -class RenameAlias(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): +class RenameAlias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "function") - expr: common.ExprIR + expr: ExprIR function: AliasName @property @@ -52,7 +53,7 @@ def __call__(self, name: str, /) -> str: return f"{name}{self.suffix}" -class IRNameNamespace(common.IRNamespace): +class IRNameNamespace(IRNamespace): def keep(self) -> KeepName: return KeepName(expr=self._ir) @@ -72,7 +73,7 @@ def to_uppercase(self) -> RenameAlias: return self.map(str.upper) -class ExprNameNamespace(common.ExprNamespace[IRNameNamespace]): +class ExprNameNamespace(ExprNamespace[IRNameNamespace]): @property def _ir_namespace(self) -> type[IRNameNamespace]: return IRNameNamespace diff --git a/narwhals/_plan/expressions/namespace.py b/narwhals/_plan/expressions/namespace.py new file mode 100644 index 0000000000..be548dba4c --- /dev/null +++ b/narwhals/_plan/expressions/namespace.py @@ -0,0 +1,41 
@@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic + +from narwhals._plan._immutable import Immutable +from narwhals._plan.typing import IRNamespaceT + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan._function import Function + from narwhals._plan.expr import Expr + from narwhals._plan.expressions import ExprIR + + +class IRNamespace(Immutable): + __slots__ = ("_ir",) + _ir: ExprIR + + @classmethod + def from_expr(cls, expr: Expr, /) -> Self: + return cls(_ir=expr._ir) + + +class ExprNamespace(Immutable, Generic[IRNamespaceT]): + __slots__ = ("_expr",) + _expr: Expr + + @property + def _ir_namespace(self) -> type[IRNamespaceT]: + raise NotImplementedError + + @property + def _ir(self) -> IRNamespaceT: + return self._ir_namespace.from_expr(self._expr) + + def _to_narwhals(self, ir: ExprIR, /) -> Expr: + return self._expr._from_ir(ir) + + def _with_unary(self, function: Function, /) -> Expr: + return self._expr._with_unary(function) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/expressions/operators.py similarity index 93% rename from narwhals/_plan/operators.py rename to narwhals/_plan/expressions/operators.py index 78b33b042f..9ecc45737a 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/expressions/operators.py @@ -16,8 +16,7 @@ from typing_extensions import Self - from narwhals._plan.common import ExprIR - from narwhals._plan.expr import BinaryExpr, BinarySelector + from narwhals._plan.expressions import BinaryExpr, BinarySelector, ExprIR from narwhals._plan.typing import ( LeftSelectorT, LeftT, @@ -45,7 +44,7 @@ def __init_subclass__( def to_binary_expr( self, left: LeftT, right: RightT, / ) -> BinaryExpr[LeftT, Self, RightT]: - from narwhals._plan.expr import BinaryExpr + from narwhals._plan.expressions.expr import BinaryExpr if right.meta.has_multiple_outputs(): raise binary_expr_multi_output_error(left, self, right) @@ -74,7 +73,7 @@ class SelectorOperator(Operator, func=None): def to_binary_selector( self, left: LeftSelectorT, right: RightSelectorT, / ) -> BinarySelector[LeftSelectorT, Self, RightSelectorT]: - from narwhals._plan.expr import BinarySelector + from narwhals._plan.expressions.expr import BinarySelector return BinarySelector(left=left, op=self, right=right) diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/expressions/ranges.py similarity index 82% rename from narwhals/_plan/ranges.py rename to narwhals/_plan/expressions/ranges.py index 4f8e49b531..6befe3fa6d 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/expressions/ranges.py @@ -2,19 +2,19 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIR, Function +from narwhals._plan._function import Function from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.expr import RangeExpr + from narwhals._plan.expressions import ExprIR, RangeExpr from narwhals.dtypes import IntegerType class RangeFunction(Function, config=FEOptions.namespaced()): def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: - from narwhals._plan.expr import RangeExpr + from narwhals._plan.expressions.expr import RangeExpr return RangeExpr(input=inputs, function=self, options=self.function_options) diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/expressions/selectors.py similarity index 90% rename from narwhals/_plan/selectors.py rename to narwhals/_plan/expressions/selectors.py index 4aa5f58a3d..5d6bfa8292 100644 --- 
a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/expressions/selectors.py @@ -18,8 +18,8 @@ from datetime import timezone from typing import TypeVar - from narwhals._plan import dummy - from narwhals._plan.expr import RootSelector + from narwhals._plan import expr + from narwhals._plan.expressions.expr import RootSelector from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType from narwhals.typing import TimeUnit @@ -31,7 +31,7 @@ class Selector(Immutable): def to_selector(self) -> RootSelector: - from narwhals._plan.expr import RootSelector + from narwhals._plan.expressions.expr import RootSelector return RootSelector(selector=self) @@ -153,30 +153,30 @@ def matches_column(self, name: str, dtype: DType) -> bool: return isinstance(dtype, dtypes.String) -def all() -> dummy.Selector: +def all() -> expr.Selector: return All().to_selector().to_narwhals() -def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> dummy.Selector: +def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> expr.Selector: return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals() -def by_name(*names: OneOrIterable[str]) -> dummy.Selector: +def by_name(*names: OneOrIterable[str]) -> expr.Selector: return Matches.from_names(*names).to_selector().to_narwhals() -def boolean() -> dummy.Selector: +def boolean() -> expr.Selector: return Boolean().to_selector().to_narwhals() -def categorical() -> dummy.Selector: +def categorical() -> expr.Selector: return Categorical().to_selector().to_narwhals() def datetime( time_unit: OneOrIterable[TimeUnit] | None = None, time_zone: OneOrIterable[str | timezone | None] = ("*", None), -) -> dummy.Selector: +) -> expr.Selector: return ( Datetime.from_time_unit_and_time_zone(time_unit, time_zone) .to_selector() @@ -184,13 +184,13 @@ def datetime( ) -def matches(pattern: str) -> dummy.Selector: +def matches(pattern: str) -> expr.Selector: return Matches.from_string(pattern).to_selector().to_narwhals() -def numeric() -> dummy.Selector: +def numeric() -> expr.Selector: return Numeric().to_selector().to_narwhals() -def string() -> dummy.Selector: +def string() -> expr.Selector: return String().to_selector().to_narwhals() diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/expressions/strings.py similarity index 96% rename from narwhals/_plan/strings.py rename to narwhals/_plan/expressions/strings.py index 4c1f4af303..6e60a7b530 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -2,11 +2,12 @@ from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, HorizontalFunction, IRNamespace +from narwhals._plan._function import Function, HorizontalFunction +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr # fmt: off diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/expressions/struct.py similarity index 83% rename from narwhals/_plan/struct.py rename to narwhals/_plan/expressions/struct.py index 2a3eca0b27..e3625adb8a 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/expressions/struct.py @@ -2,11 +2,12 @@ from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan._function import Function +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FEOptions, 
FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr class StructFunction(Function, accessor="struct"): ... diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/expressions/temporal.py similarity index 97% rename from narwhals/_plan/temporal.py rename to narwhals/_plan/expressions/temporal.py index bd21388728..11a87599ab 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -3,14 +3,15 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal from narwhals._duration import Interval -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan._function import Function +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: from typing_extensions import TypeAlias, TypeIs from narwhals._duration import IntervalUnit - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr from narwhals.typing import TimeUnit PolarsTimeUnit: TypeAlias = Literal["ns", "us", "ms"] diff --git a/narwhals/_plan/expressions/window.py b/narwhals/_plan/expressions/window.py new file mode 100644 index 0000000000..772af084f0 --- /dev/null +++ b/narwhals/_plan/expressions/window.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._plan._guards import is_function_expr, is_window_expr +from narwhals._plan._immutable import Immutable +from narwhals._plan.exceptions import ( + over_elementwise_error as elementwise_error, + over_nested_error as nested_error, + over_row_separable_error as row_separable_error, +) +from narwhals._plan.expressions.expr import OrderedWindowExpr, WindowExpr +from narwhals._plan.options import SortOptions + +if TYPE_CHECKING: + from narwhals._plan.expressions import ExprIR + from narwhals._plan.typing import Seq + + +class Window(Immutable): + """Renamed from `WindowType` https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/options/mod.rs#L139.""" + + +class Over(Window): + @staticmethod + def _validate_over( + expr: ExprIR, + partition_by: Seq[ExprIR], + order_by: Seq[ExprIR] = (), + sort_options: SortOptions | None = None, + /, + ) -> ValueError | None: + if is_window_expr(expr): + return nested_error(expr, partition_by, order_by, sort_options) + if is_function_expr(expr): + if expr.options.is_elementwise(): + return elementwise_error(expr, partition_by, order_by, sort_options) + if expr.options.is_row_separable(): + return row_separable_error(expr, partition_by, order_by, sort_options) + return None + + +def over(expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowExpr: + if err := Over._validate_over(expr, partition_by): + raise err + return WindowExpr(expr=expr, partition_by=partition_by, options=Over()) + + +def over_ordered( + expr: ExprIR, + partition_by: Seq[ExprIR], + order_by: Seq[ExprIR], + /, + *, + descending: bool = False, + nulls_last: bool = False, +) -> OrderedWindowExpr: + sort_options = SortOptions(descending=descending, nulls_last=nulls_last) + if err := Over._validate_over(expr, partition_by, order_by, sort_options): + raise err + return OrderedWindowExpr( + expr=expr, + partition_by=partition_by, + order_by=order_by, + sort_options=sort_options, + options=Over(), + ) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 570d75b4d0..c07fe92c29 100644 --- a/narwhals/_plan/functions.py +++ 
b/narwhals/_plan/functions.py @@ -1,182 +1,161 @@ -"""General functions that aren't namespaced.""" - from __future__ import annotations -from typing import TYPE_CHECKING - -from narwhals._plan.common import Function, HorizontalFunction -from narwhals._plan.exceptions import hist_bins_monotonic_error -from narwhals._plan.options import FunctionFlags, FunctionOptions - -if TYPE_CHECKING: - from typing import Any - - from typing_extensions import Self - - from narwhals._plan.common import ExprIR - from narwhals._plan.expr import AnonymousExpr, FunctionExpr, RollingExpr - from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow - from narwhals._plan.typing import Seq, Udf - from narwhals.dtypes import DType - from narwhals.typing import FillNullStrategy - +import builtins +import typing as t -class CumAgg(Function, options=FunctionOptions.length_preserving): - __slots__ = ("reverse",) - reverse: bool +from narwhals._plan import _guards, _parse, common, expressions as ir +from narwhals._plan.expressions import functions as F +from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral +from narwhals._plan.expressions.ranges import IntRange +from narwhals._plan.expressions.strings import ConcatStr +from narwhals._plan.when_then import When +from narwhals._utils import Version, flatten +if t.TYPE_CHECKING: + from narwhals._plan.expr import Expr + from narwhals._plan.series import Series + from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT + from narwhals.dtypes import IntegerType + from narwhals.typing import IntoDType, NonNestedLiteral -class RollingWindow(Function, options=FunctionOptions.length_preserving): - __slots__ = ("options",) - options: RollingOptionsFixedWindow - def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: - from narwhals._plan.expr import RollingExpr +def col(*names: str | t.Iterable[str]) -> Expr: + flat = tuple(flatten(names)) + node = ir.col(flat[0]) if builtins.len(flat) == 1 else ir.cols(*flat) + return node.to_narwhals() - options = self.function_options - return RollingExpr(input=inputs, function=self, options=options) +def nth(*indices: int | t.Sequence[int]) -> Expr: + flat = tuple(flatten(indices)) + node = ir.nth(flat[0]) if builtins.len(flat) == 1 else ir.index_columns(*flat) + return node.to_narwhals() -# fmt: off -class Abs(Function, options=FunctionOptions.elementwise): ... -class NullCount(Function, options=FunctionOptions.aggregation): ... -class Exp(Function, options=FunctionOptions.elementwise): ... -class Sqrt(Function, options=FunctionOptions.elementwise): ... -class DropNulls(Function, options=FunctionOptions.row_separable): ... -class Mode(Function): ... -class Skew(Function, options=FunctionOptions.aggregation): ... -class Clip(Function, options=FunctionOptions.elementwise): ... -class CumCount(CumAgg): ... -class CumMin(CumAgg): ... -class CumMax(CumAgg): ... -class CumProd(CumAgg): ... -class CumSum(CumAgg): ... -class RollingSum(RollingWindow): ... -class RollingMean(RollingWindow): ... -class RollingVar(RollingWindow): ... -class RollingStd(RollingWindow): ... -class Diff(Function, options=FunctionOptions.length_preserving): ... -class Unique(Function): ... -class SumHorizontal(HorizontalFunction): ... -class MinHorizontal(HorizontalFunction): ... -class MaxHorizontal(HorizontalFunction): ... -class MeanHorizontal(HorizontalFunction): ... 
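# A rough sketch of the `over`/`over_ordered` validation added above in
# `narwhals/_plan/expressions/window.py` (illustrative; assumes the `Immutable` base
# accepts keyword construction, as it is used elsewhere in this patch):
#
#     from narwhals._plan import expressions as ir
#     from narwhals._plan.expressions import functions as F
#     from narwhals._plan.expressions.window import over
#
#     cum = F.CumSum(reverse=False).to_function_expr(ir.col("a"))
#     win = over(cum, [ir.col("b")])   # ok: cum_*/rolling_* are exactly what `.over()` is for
#     over(F.Abs().to_function_expr(ir.col("a")), [ir.col("b")])   # raises: elementwise exprs
#                                                                  # gain nothing from a window
#     over(win, [ir.col("c")])         # raises: nested window expressions are rejected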
-# fmt: on -class Hist(Function): - """Only supported for `Series` so far.""" - __slots__ = ("include_breakpoint",) - include_breakpoint: bool +def lit( + value: NonNestedLiteral | Series[NativeSeriesT], dtype: IntoDType | None = None +) -> Expr: + if _guards.is_series(value): + return SeriesLiteral(value=value).to_literal().to_narwhals() + if not _guards.is_non_nested_literal(value): + msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." + raise TypeError(msg) + if dtype is None: + dtype = common.py_to_narwhals_dtype(value, Version.MAIN) + else: + dtype = common.into_dtype(dtype) + return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() - def __repr__(self) -> str: - return "hist" +def len() -> Expr: + return ir.Len().to_narwhals() -class HistBins(Hist): - __slots__ = ("bins", *Hist.__slots__) - bins: Seq[float] - def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None: - for i in range(1, len(bins)): - if bins[i - 1] >= bins[i]: - raise hist_bins_monotonic_error(bins) - object.__setattr__(self, "bins", bins) - object.__setattr__(self, "include_breakpoint", include_breakpoint) +def all() -> Expr: + return ir.All().to_narwhals() -class HistBinCount(Hist): - __slots__ = ("bin_count", *Hist.__slots__) - bin_count: int +def exclude(*names: str | t.Iterable[str]) -> Expr: + return all().exclude(*names) - def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> None: - object.__setattr__(self, "bin_count", bin_count) - object.__setattr__(self, "include_breakpoint", include_breakpoint) +def max(*columns: str) -> Expr: + return col(columns).max() -class Log(Function, options=FunctionOptions.elementwise): - __slots__ = ("base",) - base: float +def mean(*columns: str) -> Expr: + return col(columns).mean() -class Pow(Function, options=FunctionOptions.elementwise): - """N-ary (base, exponent).""" - def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: - base, exponent = node.input - return base, exponent +def min(*columns: str) -> Expr: + return col(columns).min() -class Kurtosis(Function, options=FunctionOptions.aggregation): - __slots__ = ("bias", "fisher") - fisher: bool - bias: bool +def median(*columns: str) -> Expr: + return col(columns).median() -class FillNull(Function, options=FunctionOptions.elementwise): - """N-ary (expr, value).""" +def sum(*columns: str) -> Expr: + return col(columns).sum() - def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: - expr, value = node.input - return expr, value +def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(*exprs) + return ir.boolean.AllHorizontal().to_function_expr(*it).to_narwhals() -class FillNullWithStrategy(Function): - __slots__ = ("limit", "strategy") - strategy: FillNullStrategy - limit: int | None +def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(*exprs) + return ir.boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() -class Shift(Function, options=FunctionOptions.length_preserving): - __slots__ = ("n",) - n: int +def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(*exprs) + return F.SumHorizontal().to_function_expr(*it).to_narwhals() -class Rank(Function): - __slots__ = ("options",) - options: RankOptions +def min_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(*exprs) + 
return F.MinHorizontal().to_function_expr(*it).to_narwhals() -class Round(Function, options=FunctionOptions.elementwise): - __slots__ = ("decimals",) - decimals: int +def max_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(*exprs) + return F.MaxHorizontal().to_function_expr(*it).to_narwhals() -class EwmMean(Function, options=FunctionOptions.length_preserving): - __slots__ = ("options",) - options: EWMOptions +def mean_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(*exprs) + return F.MeanHorizontal().to_function_expr(*it).to_narwhals() -class ReplaceStrict(Function, options=FunctionOptions.elementwise): - __slots__ = ("new", "old", "return_dtype") - old: Seq[Any] - new: Seq[Any] - return_dtype: DType | None +def concat_str( + exprs: IntoExpr | t.Iterable[IntoExpr], + *more_exprs: IntoExpr, + separator: str = "", + ignore_nulls: bool = False, +) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) + return ( + ConcatStr(separator=separator, ignore_nulls=ignore_nulls) + .to_function_expr(*it) + .to_narwhals() + ) -class GatherEvery(Function): - __slots__ = ("n", "offset") - n: int - offset: int +def when( + *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], **constraints: t.Any +) -> When: + """Start a `when-then-otherwise` expression. -class MapBatches(Function): - __slots__ = ("function", "is_elementwise", "return_dtype", "returns_scalar") - function: Udf - return_dtype: DType | None - is_elementwise: bool - returns_scalar: bool + Examples: + >>> from narwhals import _plan as nw - @property - def function_options(self) -> FunctionOptions: - options = super().function_options - if self.is_elementwise: - options = options.with_elementwise() - if self.returns_scalar: - options = options.with_flags(FunctionFlags.RETURNS_SCALAR) - return options + >>> nw.when(nw.col("y") == "b").then(1) + nw._plan.Expr(main): + .when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null)) + """ + condition = _parse.parse_predicates_constraints_into_expr_ir( + *predicates, **constraints + ) + return When._from_ir(condition) - def to_function_expr(self, *inputs: ExprIR) -> AnonymousExpr: - from narwhals._plan.expr import AnonymousExpr - options = self.function_options - return AnonymousExpr(input=inputs, function=self, options=options) +def int_range( + start: int | IntoExprColumn = 0, + end: int | IntoExprColumn | None = None, + step: int = 1, + *, + dtype: IntegerType | type[IntegerType] = Version.MAIN.dtypes.Int64, + eager: bool = False, +) -> Expr: + if end is None: + end = start + start = 0 + if eager: + msg = f"{eager=}" + raise NotImplementedError(msg) + return ( + IntRange(step=step, dtype=common.into_dtype(dtype)) + .to_function_expr(*_parse.parse_into_seq_of_expr_ir(start, end)) + .to_narwhals() + ) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index ce78165c00..bb7a4315b3 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -10,27 +10,25 @@ from itertools import chain from typing import TYPE_CHECKING, Literal, overload -from narwhals._plan.common import IRNamespace +from narwhals._plan import expressions as ir +from narwhals._plan._guards import is_literal +from narwhals._plan.expressions.literal import is_literal_scalar +from narwhals._plan.expressions.namespace import IRNamespace from narwhals.exceptions import ComputeError from narwhals.utils import Version if TYPE_CHECKING: from collections.abc import Iterable, Iterator - from 
typing_extensions import TypeIs - from narwhals._plan.common import ExprIR - from narwhals._plan.expr import Column - - -class IRMetaNamespace(IRNamespace): +class MetaNamespace(IRNamespace): """Methods to modify and traverse existing expressions.""" def has_multiple_outputs(self) -> bool: return any(_has_multiple_outputs(e) for e in self._ir.iter_left()) def is_column(self) -> bool: - return is_column(self._ir) + return isinstance(self._ir, ir.Column) def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: return all( @@ -51,9 +49,9 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: """Get the output name of this expression. Examples: - >>> from narwhals._plan import demo as nwd + >>> from narwhals import _plan as nw >>> - >>> a = nwd.col("a") + >>> a = nw.col("a") >>> b = a.alias("b") >>> c = b.min().alias("c") >>> @@ -78,23 +76,21 @@ def root_names(self) -> list[str]: return list(_expr_to_leaf_column_names_iter(self._ir)) -def _expr_to_leaf_column_names_iter(ir: ExprIR) -> Iterator[str]: - for e in _expr_to_leaf_column_exprs_iter(ir): +def _expr_to_leaf_column_names_iter(expr: ir.ExprIR, /) -> Iterator[str]: + for e in _expr_to_leaf_column_exprs_iter(expr): result = _expr_to_leaf_column_name(e) if isinstance(result, str): yield result -def _expr_to_leaf_column_exprs_iter(ir: ExprIR) -> Iterator[ExprIR]: - from narwhals._plan import expr - - for outer in ir.iter_root_names(): - if isinstance(outer, (expr.Column, expr.All)): +def _expr_to_leaf_column_exprs_iter(expr: ir.ExprIR, /) -> Iterator[ir.ExprIR]: + for outer in expr.iter_root_names(): + if isinstance(outer, (ir.Column, ir.All)): yield outer -def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: - leaves = list(_expr_to_leaf_column_exprs_iter(ir)) +def _expr_to_leaf_column_name(expr: ir.ExprIR, /) -> str | ComputeError: + leaves = list(_expr_to_leaf_column_exprs_iter(expr)) if not len(leaves) <= 1: msg = "found more than one root column name" return ComputeError(msg) @@ -102,40 +98,38 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: msg = "no root column name found" return ComputeError(msg) leaf = leaves[0] - from narwhals._plan import expr - - if isinstance(leaf, expr.Column): + if isinstance(leaf, ir.Column): return leaf.name - if isinstance(leaf, expr.All): + if isinstance(leaf, ir.All): msg = "wildcard has no root column name" return ComputeError(msg) msg = f"Expected unreachable, got {type(leaf).__name__!r}\n\n{leaf}" return ComputeError(msg) -def root_names_unique(irs: Iterable[ExprIR], /) -> set[str]: - return set(chain.from_iterable(_expr_to_leaf_column_names_iter(e) for e in irs)) +def root_names_unique(exprs: Iterable[ir.ExprIR], /) -> set[str]: + return set(chain.from_iterable(_expr_to_leaf_column_names_iter(e) for e in exprs)) @lru_cache(maxsize=32) -def _expr_output_name(ir: ExprIR) -> str | ComputeError: - from narwhals._plan import expr - - for e in ir.iter_output_name(): - if isinstance(e, (expr.Column, expr.Alias, expr.Literal, expr.Len)): +def _expr_output_name(expr: ir.ExprIR, /) -> str | ComputeError: + for e in expr.iter_output_name(): + if isinstance(e, (ir.Column, ir.Alias, ir.Literal, ir.Len)): return e.name - if isinstance(e, (expr.All, expr.KeepName, expr.RenameAlias)): + if isinstance(e, (ir.All, ir.KeepName, ir.RenameAlias)): msg = "cannot determine output column without a context for this expression" return ComputeError(msg) - if isinstance(e, (expr.Columns, expr.IndexColumns, expr.Nth)): + if isinstance(e, (ir.Columns, 
ir.IndexColumns, ir.Nth)): msg = "this expression may produce multiple output names" return ComputeError(msg) continue - msg = f"unable to find root column name for expr '{ir!r}' when calling 'output_name'" + msg = ( + f"unable to find root column name for expr '{expr!r}' when calling 'output_name'" + ) return ComputeError(msg) -def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: +def get_single_leaf_name(expr: ir.ExprIR, /) -> str | ComputeError: """Find the name at the start of an expression. Normal iteration would just return the first root column it found. @@ -144,60 +138,45 @@ def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: [`polars_plan::utils::get_single_leaf`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168 """ - from narwhals._plan import expr - - for e in ir.iter_right(): - if isinstance(e, (expr.WindowExpr, expr.SortBy, expr.Filter)): + for e in expr.iter_right(): + if isinstance(e, (ir.WindowExpr, ir.SortBy, ir.Filter)): return get_single_leaf_name(e.expr) - if isinstance(e, expr.BinaryExpr): + if isinstance(e, ir.BinaryExpr): return get_single_leaf_name(e.left) # NOTE: `polars` doesn't include `Literal` here - if isinstance(e, (expr.Column, expr.Len)): + if isinstance(e, (ir.Column, ir.Len)): return e.name - msg = f"unable to find a single leaf column in expr '{ir!r}'" + msg = f"unable to find a single leaf column in expr '{expr!r}'" return ComputeError(msg) -def _has_multiple_outputs(ir: ExprIR) -> bool: - from narwhals._plan import expr +def _has_multiple_outputs(expr: ir.ExprIR, /) -> bool: + return isinstance(expr, (ir.Columns, ir.IndexColumns, ir.SelectorIR, ir.All)) - return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.SelectorIR, expr.All)) - -def has_expr_ir(ir: ExprIR, *matches: type[ExprIR]) -> bool: +def has_expr_ir(expr: ir.ExprIR, *matches: type[ir.ExprIR]) -> bool: """Return True if any node in the tree is in type `matches`. 
Based on [`polars_plan::utils::has_expr`] [`polars_plan::utils::has_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L70-L77 """ - return any(isinstance(e, matches) for e in ir.iter_right()) - - -def is_column(ir: ExprIR) -> TypeIs[Column]: - from narwhals._plan.expr import Column + return any(isinstance(e, matches) for e in expr.iter_right()) - return isinstance(ir, Column) - - -def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: - from narwhals._plan import expr - from narwhals._plan.literal import is_literal_scalar +def _is_literal(expr: ir.ExprIR, /, *, allow_aliasing: bool) -> bool: return ( - isinstance(ir, expr.Literal) - or (allow_aliasing and isinstance(ir, expr.Alias)) + is_literal(expr) + or (allow_aliasing and isinstance(expr, ir.Alias)) or ( - isinstance(ir, expr.Cast) - and is_literal_scalar(ir.expr) - and isinstance(ir.expr.dtype, Version.MAIN.dtypes.Datetime) + isinstance(expr, ir.Cast) + and is_literal_scalar(expr.expr) + and isinstance(expr.expr.dtype, Version.MAIN.dtypes.Datetime) ) ) -def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: - from narwhals._plan import expr - - return isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)) or ( - allow_aliasing and isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)) +def _is_column_selection(expr: ir.ExprIR, /, *, allow_aliasing: bool) -> bool: + return isinstance(expr, (ir.Column, ir._ColumnSelection, ir.SelectorIR)) or ( + allow_aliasing and isinstance(expr, (ir.Alias, ir.KeepName, ir.RenameAlias)) ) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 951dfa850e..11a17eb081 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe +from narwhals._plan.common import flatten_hash_safe from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar from narwhals._utils import Version @@ -11,13 +11,21 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan import aggregation as agg, boolean, expr, functions as F - from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull, Not - from narwhals._plan.dummy import BaseFrame, DataFrame, Series - from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr + from narwhals._plan import expressions as ir + from narwhals._plan.dataframe import BaseFrame, DataFrame + from narwhals._plan.expressions import ( + BinaryExpr, + FunctionExpr, + NamedIR, + aggregation as agg, + boolean, + functions as F, + ) + from narwhals._plan.expressions.boolean import IsBetween, IsFinite, IsNan, IsNull, Not + from narwhals._plan.expressions.ranges import IntRange + from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.options import SortMultipleOptions - from narwhals._plan.ranges import IntRange - from narwhals._plan.strings import ConcatStr + from narwhals._plan.series import Series from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType from narwhals.typing import ( @@ -162,13 +170,13 @@ def _length_required( class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): @classmethod - def from_ir(cls, node: ExprIR, frame: 
FrameT_contra, name: str) -> R_co: + def from_ir(cls, node: ir.ExprIR, frame: FrameT_contra, name: str) -> R_co: obj = cls.__new__(cls) obj._version = frame.version return node.dispatch(obj, frame, name) @classmethod - def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: + def from_named_ir(cls, named_ir: NamedIR[ir.ExprIR], frame: FrameT_contra) -> R_co: return cls.from_ir(named_ir.expr, frame, named_ir.name) # NOTE: Needs to stay `covariant` and never be used as a parameter @@ -191,7 +199,7 @@ def _with_native(self, native: Any, name: str, /) -> Self: return self.from_native(native, name or self.name, self.version) # series & scalar - def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... + def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... def fill_null( @@ -211,24 +219,24 @@ def is_null( ) -> Self: ... def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... def ternary_expr( - self, node: expr.TernaryExpr, frame: FrameT_contra, name: str + self, node: ir.TernaryExpr, frame: FrameT_contra, name: str ) -> Self: ... - def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... + def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` # e.g. `nw.col("a").first().over(order_by="b")` def over_ordered( - self, node: expr.OrderedWindowExpr, frame: FrameT_contra, name: str + self, node: ir.OrderedWindowExpr, frame: FrameT_contra, name: str ) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ... def map_batches( - self, node: expr.AnonymousExpr, frame: FrameT_contra, name: str + self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str ) -> Self: ... def rolling_expr( - self, node: expr.RollingExpr, frame: FrameT_contra, name: str + self, node: ir.RollingExpr, frame: FrameT_contra, name: str ) -> Self: ... # series only (section 3) - def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: ... - def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: ... - def filter(self, node: expr.Filter, frame: FrameT_contra, name: str) -> Self: ... + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: ... + def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: ... + def filter(self, node: ir.Filter, frame: FrameT_contra, name: str) -> Self: ... # series -> scalar def first( self, node: agg.First, frame: FrameT_contra, name: str @@ -328,7 +336,7 @@ def last(self, node: agg.Last, frame: FrameT_contra, name: str) -> Self: """Returns self.""" return self._with_evaluated(self._evaluated, name) - def _cast_float(self, node: ExprIR, frame: FrameT_contra, name: str) -> Self: + def _cast_float(self, node: ir.ExprIR, frame: FrameT_contra, name: str) -> Self: """`polars` interpolates a single scalar as a float.""" dtype = self.version.dtypes.Float64() return self.cast(node.cast(dtype), frame, name) @@ -366,10 +374,10 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" ... 
- def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) - def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: + def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) # NOTE: `Filter` behaves the same, (maybe) no need to override @@ -439,11 +447,11 @@ def _frame(self) -> type[FrameT]: ... def _expr(self) -> type[ExprT_co]: ... @property def _scalar(self) -> type[ScalarT_co]: ... - def col(self, node: expr.Column, frame: FrameT, name: str) -> ExprT_co: ... + def col(self, node: ir.Column, frame: FrameT, name: str) -> ExprT_co: ... def lit( - self, node: expr.Literal[Any], frame: FrameT, name: str + self, node: ir.Literal[Any], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... - def len(self, node: expr.Len, frame: FrameT, name: str) -> ScalarT_co: ... + def len(self, node: ir.Len, frame: FrameT, name: str) -> ScalarT_co: ... def any_horizontal( self, node: FunctionExpr[boolean.AnyHorizontal], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... @@ -466,7 +474,7 @@ def concat_str( self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def int_range( - self, node: RangeExpr[IntRange], frame: FrameT, name: str + self, node: ir.RangeExpr[IntRange], frame: FrameT, name: str ) -> ExprT_co: ... @@ -491,16 +499,16 @@ def _is_dataframe(self, obj: Any) -> TypeIs[EagerDataFrameT]: @overload def lit( - self, node: expr.Literal[NonNestedLiteral], frame: EagerDataFrameT, name: str + self, node: ir.Literal[NonNestedLiteral], frame: EagerDataFrameT, name: str ) -> EagerScalarT_co: ... @overload def lit( - self, node: expr.Literal[Series[Any]], frame: EagerDataFrameT, name: str + self, node: ir.Literal[Series[Any]], frame: EagerDataFrameT, name: str ) -> EagerExprT_co: ... def lit( - self, node: expr.Literal[Any], frame: EagerDataFrameT, name: str + self, node: ir.Literal[Any], frame: EagerDataFrameT, name: str ) -> EagerExprT_co | EagerScalarT_co: ... - def len(self, node: expr.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: + def len(self, node: ir.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: return self._scalar.from_python( len(frame), name or node.name, dtype=None, version=frame.version ) @@ -542,7 +550,7 @@ def _with_native(self, native: NativeFrameT) -> Self: @property def schema(self) -> Mapping[str, DType]: ... def _evaluate_irs( - self, nodes: Iterable[NamedIR[ExprIR]], / + self, nodes: Iterable[NamedIR[ir.ExprIR]], / ) -> Iterator[ColumnT_co]: ... def select(self, irs: Seq[NamedIR]) -> Self: ... def with_columns(self, irs: Seq[NamedIR]) -> Self: ... 
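
For orientation, a minimal sketch of the eager flow these protocols back, assuming the pyarrow-backed `DataFrame` used in the tests below (data and expected values mirror `tests/plan/compliant_test.py`; pyarrow must be installed):

import pyarrow as pa
from narwhals import _plan as nwp

# `from_native` wraps a pyarrow table in the plan-level DataFrame.
df = nwp.DataFrame.from_native(pa.table({"a": ["A", "B", "A"], "b": [1, 2, 3]}))
# `select` parses each input into an ExprIR, expands/validates it against the
# schema, then evaluates the resulting NamedIRs via the ExprDispatch protocol.
result = df.select([nwp.col("a", "b").first(), nwp.len()])
# Expected, per the test data above: {"a": ["A"], "b": [1], "len": [3]}
print(result.to_dict(as_series=False))
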
@@ -603,7 +611,7 @@ def name(self) -> str: return self._name def to_narwhals(self) -> Series[NativeSeriesT]: - from narwhals._plan.dummy import Series + from narwhals._plan.series import Series return Series[NativeSeriesT]._from_compliant(self) diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 69c1b5a2b3..4dbf5e6ef3 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -7,8 +7,8 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, TypeVar, overload +from narwhals._plan._expr_ir import NamedIR from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable -from narwhals._plan.common import NamedIR from narwhals.dtypes import Unknown if TYPE_CHECKING: diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py new file mode 100644 index 0000000000..1ab9366ea3 --- /dev/null +++ b/narwhals/_plan/series.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, Generic + +from narwhals._plan.typing import NativeSeriesT +from narwhals._utils import Version +from narwhals.dependencies import is_pyarrow_chunked_array + +if TYPE_CHECKING: + from collections.abc import Iterator + + import pyarrow as pa + from typing_extensions import Self + + from narwhals._plan.protocols import CompliantSeries + from narwhals.dtypes import DType + from narwhals.typing import NativeSeries + + +class Series(Generic[NativeSeriesT]): + _compliant: CompliantSeries[NativeSeriesT] + _version: ClassVar[Version] = Version.MAIN + + @property + def version(self) -> Version: + return self._version + + @property + def dtype(self) -> DType: + return self._compliant.dtype + + @property + def name(self) -> str: + return self._compliant.name + + # NOTE: Gave up on trying to get typing working for now + @classmethod + def from_native( + cls, native: NativeSeries, name: str = "", / + ) -> Series[pa.ChunkedArray[Any]]: + if is_pyarrow_chunked_array(native): + from narwhals._plan.arrow.series import ArrowSeries + + return ArrowSeries.from_native( + native, name, version=cls._version + ).to_narwhals() + + raise NotImplementedError(type(native)) + + @classmethod + def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self: + obj = cls.__new__(cls) + obj._compliant = compliant + return obj + + def to_native(self) -> NativeSeriesT: + return self._compliant.native + + def to_list(self) -> list[Any]: + return self._compliant.to_list() + + def __iter__(self) -> Iterator[Any]: + yield from self.to_native() + + +class SeriesV1(Series[NativeSeriesT]): + _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index b7e0736e15..0efb81ea81 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -8,11 +8,14 @@ from typing_extensions import TypeAlias from narwhals import dtypes - from narwhals._plan import operators as ops - from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR - from narwhals._plan.dummy import Expr, Series - from narwhals._plan.functions import RollingWindow - from narwhals._plan.ranges import RangeFunction + from narwhals._plan._expr_ir import ExprIR, NamedIR, SelectorIR + from narwhals._plan._function import Function + from narwhals._plan.expr import Expr + from narwhals._plan.expressions import operators as ops + from narwhals._plan.expressions.functions import RollingWindow + from narwhals._plan.expressions.namespace import IRNamespace + from narwhals._plan.expressions.ranges import 
RangeFunction + from narwhals._plan.series import Series from narwhals.typing import ( NativeDataFrame, NativeFrame, diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 62e0da3d2a..18ae514f0d 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -4,15 +4,14 @@ from narwhals._plan._guards import is_expr from narwhals._plan._immutable import Immutable -from narwhals._plan.dummy import Expr -from narwhals._plan.expr_parsing import ( +from narwhals._plan._parse import ( parse_into_expr_ir, parse_predicates_constraints_into_expr_ir, ) +from narwhals._plan.expr import Expr if TYPE_CHECKING: - from narwhals._plan.common import ExprIR - from narwhals._plan.expr import TernaryExpr + from narwhals._plan.expressions import ExprIR, TernaryExpr from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq @@ -28,8 +27,8 @@ def _from_expr(expr: Expr, /) -> When: return When(condition=expr._ir) @staticmethod - def _from_ir(ir: ExprIR, /) -> When: - return When(condition=ir) + def _from_ir(expr_ir: ExprIR, /) -> When: + return When(condition=expr_ir) class Then(Immutable, Expr): @@ -56,8 +55,8 @@ def _ir(self) -> ExprIR: # type: ignore[override] return self._otherwise() @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Expr: # type: ignore[override] - return Expr._from_ir(ir) + def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] + return Expr._from_ir(expr_ir) def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] if is_expr(value): @@ -104,8 +103,8 @@ def _ir(self) -> ExprIR: # type: ignore[override] return self._otherwise() @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Expr: # type: ignore[override] - return Expr._from_ir(ir) + def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] + return Expr._from_ir(expr_ir) def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] if is_expr(value): @@ -114,6 +113,6 @@ def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: - from narwhals._plan.expr import TernaryExpr + from narwhals._plan.expressions.expr import TernaryExpr return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py deleted file mode 100644 index fd27743948..0000000000 --- a/narwhals/_plan/window.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from narwhals._plan._guards import is_function_expr, is_window_expr -from narwhals._plan._immutable import Immutable -from narwhals._plan.exceptions import ( - over_elementwise_error, - over_nested_error, - over_row_separable_error, -) - -if TYPE_CHECKING: - from narwhals._plan.common import ExprIR - from narwhals._plan.expr import OrderedWindowExpr, WindowExpr - from narwhals._plan.options import SortOptions - from narwhals._plan.typing import Seq - from narwhals.exceptions import InvalidOperationError - - -class Window(Immutable): - """Renamed from `WindowType` https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/options/mod.rs#L139.""" - - -class Over(Window): - @staticmethod - def _validate_over( - expr: ExprIR, - partition_by: Seq[ExprIR], - order_by: Seq[ExprIR] = (), - sort_options: SortOptions | None = None, - /, - ) -> InvalidOperationError | None: - if is_window_expr(expr): - return over_nested_error(expr, partition_by, 
order_by, sort_options) - if is_function_expr(expr): - if expr.options.is_elementwise(): - return over_elementwise_error(expr, partition_by, order_by, sort_options) - if expr.options.is_row_separable(): - return over_row_separable_error( - expr, partition_by, order_by, sort_options - ) - return None - - def to_window_expr(self, expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowExpr: - from narwhals._plan.expr import WindowExpr - - if err := self._validate_over(expr, partition_by): - raise err - return WindowExpr(expr=expr, partition_by=partition_by, options=self) - - def to_ordered_window_expr( - self, - expr: ExprIR, - partition_by: Seq[ExprIR], - order_by: Seq[ExprIR], - sort_options: SortOptions, - /, - ) -> OrderedWindowExpr: - from narwhals._plan.expr import OrderedWindowExpr - - if err := self._validate_over(expr, partition_by, order_by, sort_options): - raise err - return OrderedWindowExpr( - expr=expr, - partition_by=partition_by, - order_by=order_by, - sort_options=sort_options, - options=self, - ) diff --git a/pyproject.toml b/pyproject.toml index 3b5db19672..fadf125896 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,6 @@ ignore = [ "FBT003", # boolean-positional-value-in-call (We enforce at definition site when it is a flag, not a value (e.g. `lit(False)`)) "FIX", # flake8-fixme "PD010", # pandas-use-of-dot-pivot-or-unstack - "PD901", # pandas-df-variable-name (This is a auxiliary library so dataframe variables have no concrete business meaning) "PLC0415", # `import` should be at the top-level of a file "PLR0913", # too-many-arguments "PLR2004", # magic-value-comparison @@ -209,6 +208,7 @@ extend-ignore-names = [ "C901", # complex-structure "PLR0912", # too-many-branches "PLR0916", # too-many-boolean-expressions + "RUF043", # temp ignore until sync ] "tpch/tests/*" = ["S101"] "utils/*" = ["S311"] diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index dc548968a4..ffada70747 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -4,15 +4,15 @@ import pytest +from narwhals._plan import selectors as ndcs + pytest.importorskip("pyarrow") pytest.importorskip("numpy") import numpy as np import pyarrow as pa import narwhals as nw -from narwhals._plan import demo as nwd, selectors as ndcs -from narwhals._plan._guards import is_expr -from narwhals._plan.dummy import DataFrame +from narwhals import _plan as nwp from narwhals._utils import Version from narwhals.exceptions import ComputeError from tests.utils import assert_equal_data @@ -20,7 +20,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from narwhals._plan.dummy import Expr from narwhals.typing import PythonLiteral @@ -64,8 +63,8 @@ def data_indexed() -> dict[str, Any]: } -def _ids_ir(expr: Expr | Any) -> str: - if is_expr(expr): +def _ids_ir(expr: nwp.Expr | Any) -> str: + if isinstance(expr, nwp.Expr): return repr(expr._ir) return repr(expr) @@ -87,60 +86,60 @@ def _ids_ir(expr: Expr | Any) -> str: @pytest.mark.parametrize( ("expr", "expected"), [ - (nwd.col("a"), {"a": ["A", "B", "A"]}), - (nwd.col("a", "b"), {"a": ["A", "B", "A"], "b": [1, 2, 3]}), - (nwd.lit(1), {"literal": [1]}), - (nwd.lit(2.0), {"literal": [2.0]}), - (nwd.lit(None, nw.String), {"literal": [None]}), - (nwd.col("a", "b").first(), {"a": ["A"], "b": [1]}), - (nwd.col("d").max(), {"d": [8]}), - ([nwd.len(), nwd.nth(3).last()], {"len": [3], "d": [8]}), + (nwp.col("a"), {"a": ["A", "B", "A"]}), + (nwp.col("a", "b"), {"a": ["A", "B", "A"], "b": [1, 2, 3]}), + (nwp.lit(1), {"literal": 
[1]}), + (nwp.lit(2.0), {"literal": [2.0]}), + (nwp.lit(None, nw.String), {"literal": [None]}), + (nwp.col("a", "b").first(), {"a": ["A"], "b": [1]}), + (nwp.col("d").max(), {"d": [8]}), + ([nwp.len(), nwp.nth(3).last()], {"len": [3], "d": [8]}), ( - [nwd.len().alias("e"), nwd.nth(3).last(), nwd.nth(2)], + [nwp.len().alias("e"), nwp.nth(3).last(), nwp.nth(2)], {"e": [3, 3, 3], "d": [8, 8, 8], "c": [9, 2, 4]}, ), - (nwd.col("b").sort(descending=True).alias("b_desc"), {"b_desc": [3, 2, 1]}), - (nwd.col("c").filter(a="B"), {"c": [2]}), + (nwp.col("b").sort(descending=True).alias("b_desc"), {"b_desc": [3, 2, 1]}), + (nwp.col("c").filter(a="B"), {"c": [2]}), ( - [nwd.nth(0, 1).filter(nwd.col("c") >= 4), nwd.col("d").last() - 4], + [nwp.nth(0, 1).filter(nwp.col("c") >= 4), nwp.col("d").last() - 4], {"a": ["A", "A"], "b": [1, 3], "d": [4, 4]}, ), - (nwd.col("b").cast(nw.Float64()), {"b": [1.0, 2.0, 3.0]}), - (nwd.lit(1).cast(nw.Float64).alias("literal_cast"), {"literal_cast": [1.0]}), + (nwp.col("b").cast(nw.Float64()), {"b": [1.0, 2.0, 3.0]}), + (nwp.lit(1).cast(nw.Float64).alias("literal_cast"), {"literal_cast": [1.0]}), pytest.param( - nwd.lit(1).cast(nw.Float64()).name.suffix("_cast"), + nwp.lit(1).cast(nw.Float64()).name.suffix("_cast"), {"literal_cast": [1.0]}, marks=XFAIL_REWRITE_SPECIAL_ALIASES, ), - ([ndcs.string().first(), nwd.col("b")], {"a": ["A", "A", "A"], "b": [1, 2, 3]}), + ([ndcs.string().first(), nwp.col("b")], {"a": ["A", "A", "A"], "b": [1, 2, 3]}), ( - nwd.col("c", "d") + nwp.col("c", "d") .sort_by("a", "b", descending=[True, False]) .cast(nw.Float32()) .name.to_uppercase(), {"C": [2.0, 9.0, 4.0], "D": [7.0, 8.0, 8.0]}, ), - ([nwd.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), - ([nwd.int_range(nwd.len())], {"literal": [0, 1, 2]}), - (nwd.int_range(nwd.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), - (nwd.int_range(nwd.col("b").min() + 4, nwd.col("d").last()), {"b": [5, 6, 7]}), - (nwd.col("b") ** 2, {"b": [1, 4, 9]}), + ([nwp.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), + ([nwp.int_range(nwp.len())], {"literal": [0, 1, 2]}), + (nwp.int_range(nwp.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), + (nwp.int_range(nwp.col("b").min() + 4, nwp.col("d").last()), {"b": [5, 6, 7]}), + (nwp.col("b") ** 2, {"b": [1, 4, 9]}), ( - [2 ** nwd.col("b"), (nwd.lit(2.0) ** nwd.nth(1)).alias("lit")], + [2 ** nwp.col("b"), (nwp.lit(2.0) ** nwp.nth(1)).alias("lit")], {"literal": [2, 4, 8], "lit": [2, 4, 8]}, ), pytest.param( [ - nwd.col("b").is_between(2, 3, "left").alias("left"), - nwd.col("b").is_between(2, 3, "right").alias("right"), - nwd.col("b").is_between(2, 3, "none").alias("none"), - nwd.col("b").is_between(2, 3, "both").alias("both"), - nwd.col("c").is_between( - nwd.col("c").mean() - 1, 7 - nwd.col("b"), "both" + nwp.col("b").is_between(2, 3, "left").alias("left"), + nwp.col("b").is_between(2, 3, "right").alias("right"), + nwp.col("b").is_between(2, 3, "none").alias("none"), + nwp.col("b").is_between(2, 3, "both").alias("both"), + nwp.col("c").is_between( + nwp.col("c").mean() - 1, 7 - nwp.col("b"), "both" ), - nwd.col("c") + nwp.col("c") .alias("c_right") - .is_between(nwd.col("c").mean() - 1, 7 - nwd.col("b"), "right"), + .is_between(nwp.col("c").mean() - 1, 7 - nwp.col("b"), "right"), ], { "left": [False, True, False], @@ -154,12 +153,12 @@ def _ids_ir(expr: Expr | Any) -> str: ), pytest.param( [ - nwd.col("e").fill_null(0).alias("e_0"), - nwd.col("e").fill_null(nwd.col("b")).alias("e_b"), - nwd.col("e").fill_null(nwd.col("b").last()).alias("e_b_last"), - 
nwd.col("e") + nwp.col("e").fill_null(0).alias("e_0"), + nwp.col("e").fill_null(nwp.col("b")).alias("e_b"), + nwp.col("e").fill_null(nwp.col("b").last()).alias("e_b_last"), + nwp.col("e") .sort(nulls_last=True) - .fill_null(nwd.col("d").last() - nwd.col("c")) + .fill_null(nwp.col("d").last() - nwp.col("c")) .alias("e_sort_wild"), ], { @@ -170,88 +169,88 @@ def _ids_ir(expr: Expr | Any) -> str: }, id="sort", ), - (nwd.col("e", "d").is_null().any(), {"e": [True], "d": [False]}), + (nwp.col("e", "d").is_null().any(), {"e": [True], "d": [False]}), ( - [(~nwd.col("e", "d").is_null()).all(), "b"], + [(~nwp.col("e", "d").is_null()).all(), "b"], {"e": [False, False, False], "d": [True, True, True], "b": [1, 2, 3]}, ), pytest.param( - nwd.when(d=8).then("c"), {"c": [9, None, 4]}, id="When-otherwise-none" + nwp.when(d=8).then("c"), {"c": [9, None, 4]}, id="When-otherwise-none" ), pytest.param( - nwd.when(nwd.col("e").is_null()) - .then(nwd.col("b") + nwd.col("c")) + nwp.when(nwp.col("e").is_null()) + .then(nwp.col("b") + nwp.col("c")) .otherwise(50), {"b": [10, 50, 50]}, id="When-otherwise-native-broadcast", ), pytest.param( - nwd.when(nwd.col("a") == nwd.lit("C")) - .then(nwd.lit("c")) - .when(nwd.col("a") == nwd.lit("D")) - .then(nwd.lit("d")) - .when(nwd.col("a") == nwd.lit("B")) - .then(nwd.lit("b")) - .when(nwd.col("a") == nwd.lit("A")) - .then(nwd.lit("a")) + nwp.when(nwp.col("a") == nwp.lit("C")) + .then(nwp.lit("c")) + .when(nwp.col("a") == nwp.lit("D")) + .then(nwp.lit("d")) + .when(nwp.col("a") == nwp.lit("B")) + .then(nwp.lit("b")) + .when(nwp.col("a") == nwp.lit("A")) + .then(nwp.lit("a")) .alias("A"), {"A": ["a", "b", "a"]}, id="When-then-x4", ), pytest.param( - nwd.when(nwd.col("c") > 5, b=1).then(999), + nwp.when(nwp.col("c") > 5, b=1).then(999), {"literal": [999, None, None]}, id="When-multiple-predicates", ), pytest.param( - nwd.when(nwd.col("b") == nwd.col("c"), nwd.col("d").mean() > nwd.col("d")) + nwp.when(nwp.col("b") == nwp.col("c"), nwp.col("d").mean() > nwp.col("d")) .then(123) - .when(nwd.lit(True), ~nwd.nth(4).is_null()) + .when(nwp.lit(True), ~nwp.nth(4).is_null()) .then(456) - .otherwise(nwd.col("c")), + .otherwise(nwp.col("c")), {"literal": [9, 123, 456]}, id="When-multiple-predicates-mixed-broadcast", ), pytest.param( - nwd.when(nwd.lit(True)).then("c"), + nwp.when(nwp.lit(True)).then("c"), {"c": [9, 2, 4]}, id="When-literal-then-column", ), pytest.param( - nwd.when(nwd.lit(True)).then(nwd.col("c").mean()), + nwp.when(nwp.lit(True)).then(nwp.col("c").mean()), {"c": [5.0]}, id="When-literal-then-agg", ), pytest.param( [ - nwd.when(nwd.lit(True)).then(nwd.col("e").last()), - nwd.col("b").sort(descending=True), + nwp.when(nwp.lit(True)).then(nwp.col("e").last()), + nwp.col("b").sort(descending=True), ], {"e": [7, 7, 7], "b": [3, 2, 1]}, id="When-literal-then-agg-broadcast", ), pytest.param( [ - nwd.all_horizontal( - nwd.col("b") < nwd.col("c"), - nwd.col("a") != nwd.lit("B"), - nwd.col("e").cast(nw.Boolean), - nwd.lit(True), + nwp.all_horizontal( + nwp.col("b") < nwp.col("c"), + nwp.col("a") != nwp.lit("B"), + nwp.col("e").cast(nw.Boolean), + nwp.lit(True), ), - nwd.nth(1).last().name.suffix("_last"), + nwp.nth(1).last().name.suffix("_last"), ], {"b": [None, False, True], "b_last": [3, 3, 3]}, id="all-horizontal-mixed-broadcast", ), pytest.param( [ - nwd.all_horizontal(nwd.lit(True), nwd.lit(True)).alias("a"), - nwd.all_horizontal(nwd.lit(False), nwd.lit(True)).alias("b"), - nwd.all_horizontal(nwd.lit(False), nwd.lit(False)).alias("c"), - nwd.all_horizontal(nwd.lit(None, 
nw.Boolean), nwd.lit(True)).alias("d"), - nwd.all_horizontal(nwd.lit(None, nw.Boolean), nwd.lit(False)).alias("e"), - nwd.all_horizontal( - nwd.lit(None, nw.Boolean), nwd.lit(None, nw.Boolean) + nwp.all_horizontal(nwp.lit(True), nwp.lit(True)).alias("a"), + nwp.all_horizontal(nwp.lit(False), nwp.lit(True)).alias("b"), + nwp.all_horizontal(nwp.lit(False), nwp.lit(False)).alias("c"), + nwp.all_horizontal(nwp.lit(None, nw.Boolean), nwp.lit(True)).alias("d"), + nwp.all_horizontal(nwp.lit(None, nw.Boolean), nwp.lit(False)).alias("e"), + nwp.all_horizontal( + nwp.lit(None, nw.Boolean), nwp.lit(None, nw.Boolean) ).alias("f"), ], { @@ -266,9 +265,9 @@ def _ids_ir(expr: Expr | Any) -> str: ), pytest.param( [ - nwd.any_horizontal("f", "g"), - nwd.any_horizontal("g", "h"), - nwd.any_horizontal(nwd.lit(False), nwd.col("g").last()).alias( + nwp.any_horizontal("f", "g"), + nwp.any_horizontal("g", "h"), + nwp.any_horizontal(nwp.lit(False), nwp.col("g").last()).alias( "False-False" ), ], @@ -281,9 +280,9 @@ def _ids_ir(expr: Expr | Any) -> str: ), pytest.param( [ - nwd.any_horizontal(nwd.lit(None, nw.Boolean), "i").alias("None-None"), - nwd.any_horizontal(nwd.lit(True), "i").alias("True-None"), - nwd.any_horizontal(nwd.lit(False), "i").alias("False-None"), + nwp.any_horizontal(nwp.lit(None, nw.Boolean), "i").alias("None-None"), + nwp.any_horizontal(nwp.lit(True), "i").alias("True-None"), + nwp.any_horizontal(nwp.lit(False), "i").alias("False-None"), ], { "None-None": [None, None, None], @@ -295,15 +294,15 @@ def _ids_ir(expr: Expr | Any) -> str: ), pytest.param( [ - nwd.col("b").alias("a"), - nwd.col("l").alias("b"), - nwd.col("m").alias("i"), - nwd.any_horizontal(nwd.sum("b", "l").cast(nw.Boolean)).alias("any"), - nwd.all_horizontal(nwd.sum("b", "l").cast(nw.Boolean)).alias("all"), - nwd.max_horizontal(nwd.sum("b"), nwd.sum("l")).alias("max"), - nwd.min_horizontal(nwd.sum("b"), nwd.sum("l")).alias("min"), - nwd.sum_horizontal(nwd.sum("b"), nwd.sum("l")).alias("sum"), - nwd.mean_horizontal(nwd.sum("b"), nwd.sum("l")).alias("mean"), + nwp.col("b").alias("a"), + nwp.col("l").alias("b"), + nwp.col("m").alias("i"), + nwp.any_horizontal(nwp.sum("b", "l").cast(nw.Boolean)).alias("any"), + nwp.all_horizontal(nwp.sum("b", "l").cast(nw.Boolean)).alias("all"), + nwp.max_horizontal(nwp.sum("b"), nwp.sum("l")).alias("max"), + nwp.min_horizontal(nwp.sum("b"), nwp.sum("l")).alias("min"), + nwp.sum_horizontal(nwp.sum("b"), nwp.sum("l")).alias("sum"), + nwp.mean_horizontal(nwp.sum("b"), nwp.sum("l")).alias("mean"), ], { "a": [1, 2, 3], @@ -319,39 +318,39 @@ def _ids_ir(expr: Expr | Any) -> str: id="sumh_broadcasting", ), pytest.param( - nwd.mean_horizontal("j", nwd.col("k"), "e"), + nwp.mean_horizontal("j", nwp.col("k"), "e"), {"j": [27.05, 9.5, 5.5]}, id="mean_horizontal-null", ), pytest.param( - nwd.sum_horizontal("j", nwd.col("k"), "e"), + nwp.sum_horizontal("j", nwp.col("k"), "e"), {"j": [54.1, 19.0, 11.0]}, id="sum_horizontal-null", ), pytest.param( - nwd.concat_str(nwd.col("b") * 2, "n", nwd.col("o"), separator=" "), + nwp.concat_str(nwp.col("b") * 2, "n", nwp.col("o"), separator=" "), {"b": ["2 dogs play", "4 cats swim", None]}, id="concat_str-preserve_nulls", ), pytest.param( - nwd.concat_str( - nwd.col("b") * 2, "n", nwd.col("o"), separator=" ", ignore_nulls=True + nwp.concat_str( + nwp.col("b") * 2, "n", nwp.col("o"), separator=" ", ignore_nulls=True ), {"b": ["2 dogs play", "4 cats swim", "6 walk"]}, id="concat_str-ignore_nulls", ), pytest.param( - nwd.concat_str("a", nwd.lit("a")), + nwp.concat_str("a", 
nwp.lit("a")), {"a": ["Aa", "Ba", "Aa"]}, id="concat_str-lit", ), pytest.param( - nwd.concat_str( - nwd.lit("a"), - nwd.lit("b"), - nwd.lit("c"), - nwd.lit("d"), - nwd.col("e").last() + 13, + nwp.concat_str( + nwp.lit("a"), + nwp.lit("b"), + nwp.lit("c"), + nwp.lit("d"), + nwp.col("e").last() + 13, separator="|", ), {"literal": ["a|b|c|d|20"]}, @@ -359,7 +358,7 @@ def _ids_ir(expr: Expr | Any) -> str: ), pytest.param( [ - nwd.col("a") + nwp.col("a") .alias("...") .map_batches( lambda s: s.from_iterable( @@ -369,13 +368,13 @@ def _ids_ir(expr: Expr | Any) -> str: ), is_elementwise=True, ), - nwd.col("a"), + nwp.col("a"), ], {"funky": ["string", "string", "last"], "a": ["A", "B", "A"]}, id="map_batches-series", ), pytest.param( - nwd.col("b") + nwp.col("b") .map_batches(lambda s: s.to_numpy() + 1, nw.Float64(), is_elementwise=True) .sum(), {"b": [9.0]}, @@ -389,7 +388,7 @@ def _ids_ir(expr: Expr | Any) -> str: id="map_batches-selector", ), pytest.param( - nwd.col("j", "k") + nwp.col("j", "k") .fill_null(15) .map_batches(lambda s: (s.to_numpy().max()), returns_scalar=True), {"j": [15], "k": [42]}, @@ -403,10 +402,12 @@ def _ids_ir(expr: Expr | Any) -> str: ids=_ids_ir, ) def test_select( - expr: Expr | Sequence[Expr], expected: dict[str, Any], data_small: dict[str, Any] + expr: nwp.Expr | Sequence[nwp.Expr], + expected: dict[str, Any], + data_small: dict[str, Any], ) -> None: frame = pa.table(data_small) - df = DataFrame.from_native(frame) + df = nwp.DataFrame.from_native(frame) result = df.select(expr).to_dict(as_series=False) assert_equal_data(result, expected) @@ -415,7 +416,7 @@ def test_select( ("expr", "expected"), [ ( - ["d", nwd.col("a"), "b", nwd.col("e")], + ["d", nwp.col("a"), "b", nwp.col("e")], { "a": ["A", "B", "A"], "b": [1, 2, 3], @@ -438,9 +439,9 @@ def test_select( ), ( [ - nwd.col("e").fill_null(nwd.col("e").last()), - nwd.col("f").sort(), - nwd.nth(1).max(), + nwp.col("e").fill_null(nwp.col("e").last()), + nwp.col("f").sort(), + nwp.nth(1).max(), ], { "a": ["A", "B", "A"], @@ -453,11 +454,11 @@ def test_select( ), pytest.param( [ - nwd.col("a").alias("a?"), + nwp.col("a").alias("a?"), ndcs.by_name("a"), - nwd.col("b").cast(nw.Float64).name.suffix("_float"), - nwd.col("c").max() + 1, - nwd.sum_horizontal(1, "d", nwd.col("b"), nwd.lit(3)), + nwp.col("b").cast(nw.Float64).name.suffix("_float"), + nwp.col("c").max() + 1, + nwp.sum_horizontal(1, "d", nwp.col("b"), nwp.lit(3)), ], { "a": ["A", "B", "A"], @@ -475,20 +476,22 @@ def test_select( ], ) def test_with_columns( - expr: Expr | Sequence[Expr], expected: dict[str, Any], data_smaller: dict[str, Any] + expr: nwp.Expr | Sequence[nwp.Expr], + expected: dict[str, Any], + data_smaller: dict[str, Any], ) -> None: frame = pa.table(data_smaller) - df = DataFrame.from_native(frame) + df = nwp.DataFrame.from_native(frame) result = df.with_columns(expr).to_dict(as_series=False) assert_equal_data(result, expected) -def first(*names: str) -> Expr: - return nwd.col(*names).first() +def first(*names: str) -> nwp.Expr: + return nwp.col(*names).first() -def last(*names: str) -> Expr: - return nwd.col(*names).last() +def last(*names: str) -> nwp.Expr: + return nwp.col(*names).last() @pytest.mark.parametrize( @@ -503,12 +506,12 @@ def last(*names: str) -> Expr: ], ) def test_first_last_expr_with_columns( - data_indexed: dict[str, Any], agg: Expr, expected: PythonLiteral + data_indexed: dict[str, Any], agg: nwp.Expr, expected: PythonLiteral ) -> None: """Related https://github.com/narwhals-dev/narwhals/pull/2528#discussion_r2225930065.""" 
height = len(next(iter(data_indexed.values()))) expected_broadcast = height * [expected] - frame = DataFrame.from_native(pa.table(data_indexed)) + frame = nwp.DataFrame.from_native(pa.table(data_indexed)) expr = agg.over(order_by="idx").alias("result") result = frame.with_columns(expr).select("result").to_dict(as_series=False) assert_equal_data(result, {"result": expected_broadcast}) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 8396db6c74..a80724ff86 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -6,14 +6,14 @@ import pytest import narwhals as nw -from narwhals._plan import demo as nwd, selectors as ndcs -from narwhals._plan.expr import Alias, Columns -from narwhals._plan.expr_expansion import ( +from narwhals import _plan as nwp +from narwhals._plan import expressions as ir, selectors as ndcs +from narwhals._plan._expansion import ( prepare_projection, replace_selector, rewrite_special_aliases, ) -from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir +from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError from tests.plan.utils import assert_expr_ir_equal @@ -21,8 +21,6 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence - from narwhals._plan.common import ExprIR - from narwhals._plan.dummy import Expr, Selector from narwhals._plan.typing import IntoExpr, MapIR from narwhals.dtypes import DType @@ -54,9 +52,9 @@ def schema_1() -> dict[str, DType]: MULTI_OUTPUT_EXPRS = ( - pytest.param(nwd.col("a", "b", "c")), + pytest.param(nwp.col("a", "b", "c")), pytest.param(ndcs.numeric() - ndcs.matches("[d-j]")), - pytest.param(nwd.nth(0, 1, 2)), + pytest.param(nwp.nth(0, 1, 2)), pytest.param(ndcs.by_dtype(nw.Int64, nw.Int32, nw.Int16)), pytest.param(ndcs.by_name("a", "b", "c")), ) @@ -76,19 +74,19 @@ def udf_name_map(name: str) -> str: @pytest.mark.parametrize( ("expr", "expected"), [ - (nwd.col("a").name.to_uppercase(), "A"), - (nwd.col("B").name.to_lowercase(), "b"), - (nwd.col("c").name.suffix("_after"), "c_after"), - (nwd.col("d").name.prefix("before_"), "before_d"), + (nwp.col("a").name.to_uppercase(), "A"), + (nwp.col("B").name.to_lowercase(), "b"), + (nwp.col("c").name.suffix("_after"), "c_after"), + (nwp.col("d").name.prefix("before_"), "before_d"), ( - nwd.col("aBcD EFg hi").name.map(udf_name_map), + nwp.col("aBcD EFg hi").name.map(udf_name_map), "original='aBcD EFg hi' | upper='ABCD EFG HI' | lower='abcd efg hi' | title='Abcd Efg Hi'", ), - (nwd.col("a").min().alias("b").over("c").alias("d").max().name.keep(), "a"), + (nwp.col("a").min().alias("b").over("c").alias("d").max().name.keep(), "a"), ( ( - nwd.col("hello") - .sort_by(nwd.col("ignore me")) + nwp.col("hello") + .sort_by(nwp.col("ignore me")) .max() .over("ignore me as well") .first() @@ -98,7 +96,7 @@ def udf_name_map(name: str) -> str: ), ( ( - nwd.col("start") + nwp.col("start") .alias("next") .sort() .round() @@ -110,7 +108,7 @@ def udf_name_map(name: str) -> str: ), ], ) -def test_rewrite_special_aliases_single(expr: Expr, expected: str) -> None: +def test_rewrite_special_aliases_single(expr: nwp.Expr, expected: str) -> None: # NOTE: We can't use `output_name()` without resolving these rewrites # Once they're done, `output_name()` just peeks into `Alias(name=...)` ir_input = expr._ir @@ -126,10 +124,10 @@ def test_rewrite_special_aliases_single(expr: Expr, expected: str) -> None: def 
alias_replace_guarded(name: str) -> MapIR: # pragma: no cover """Guards against repeatedly creating the same alias.""" - def fn(ir: ExprIR) -> ExprIR: - if isinstance(ir, Alias) and ir.name != name: - return Alias(expr=ir.expr, name=name) - return ir + def fn(e_ir: ir.ExprIR) -> ir.ExprIR: + if isinstance(e_ir, ir.Alias) and e_ir.name != name: + return ir.Alias(expr=e_ir.expr, name=name) + return e_ir return fn @@ -143,10 +141,10 @@ def alias_replace_unguarded(name: str) -> MapIR: # pragma: no cover - *Pragmatically*, it might require an extra iteration to detect a cycle """ - def fn(ir: ExprIR) -> ExprIR: - if isinstance(ir, Alias): - return Alias(expr=ir.expr, name=name) - return ir + def fn(e_ir: ir.ExprIR) -> ir.ExprIR: + if isinstance(e_ir, ir.Alias): + return ir.Alias(expr=e_ir.expr, name=name) + return e_ir return fn @@ -154,33 +152,33 @@ def fn(ir: ExprIR) -> ExprIR: @pytest.mark.parametrize( ("expr", "function", "expected"), [ - (nwd.col("a"), alias_replace_guarded("never"), nwd.col("a")), - (nwd.col("a"), alias_replace_unguarded("never"), nwd.col("a")), - (nwd.col("a").alias("b"), alias_replace_guarded("c"), nwd.col("a").alias("c")), - (nwd.col("a").alias("b"), alias_replace_unguarded("c"), nwd.col("a").alias("c")), + (nwp.col("a"), alias_replace_guarded("never"), nwp.col("a")), + (nwp.col("a"), alias_replace_unguarded("never"), nwp.col("a")), + (nwp.col("a").alias("b"), alias_replace_guarded("c"), nwp.col("a").alias("c")), + (nwp.col("a").alias("b"), alias_replace_unguarded("c"), nwp.col("a").alias("c")), ( - nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), + nwp.col("a").alias("d").first().over("b", order_by="c").alias("e"), alias_replace_guarded("d"), - nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), + nwp.col("a").alias("d").first().over("b", order_by="c").alias("d"), ), ( - nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), + nwp.col("a").alias("d").first().over("b", order_by="c").alias("e"), alias_replace_unguarded("d"), - nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), + nwp.col("a").alias("d").first().over("b", order_by="c").alias("d"), ), ( - nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), + nwp.col("a").alias("e").abs().alias("f").sort().alias("g"), alias_replace_guarded("e"), - nwd.col("a").alias("e").abs().alias("e").sort().alias("e"), + nwp.col("a").alias("e").abs().alias("e").sort().alias("e"), ), ( - nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), + nwp.col("a").alias("e").abs().alias("f").sort().alias("g"), alias_replace_unguarded("e"), - nwd.col("a").alias("e").abs().alias("e").sort().alias("e"), + nwp.col("a").alias("e").abs().alias("e").sort().alias("e"), ), ], ) -def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: +def test_map_ir_recursive(expr: nwp.Expr, function: MapIR, expected: nwp.Expr) -> None: actual = expr._ir.map_ir(function) assert_expr_ir_equal(actual, expected) @@ -188,17 +186,17 @@ def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: @pytest.mark.parametrize( ("expr", "expected"), [ - (nwd.col("a"), nwd.col("a")), - (nwd.col("a").max().alias("z"), nwd.col("a").max().alias("z")), - (ndcs.string(), Columns(names=("k",))), + (nwp.col("a"), nwp.col("a")), + (nwp.col("a").max().alias("z"), nwp.col("a").max().alias("z")), + (ndcs.string(), ir.Columns(names=("k",))), ( ndcs.by_dtype(nw.Datetime("ms"), nw.Date, nw.List(nw.String)), - nwd.col("n", "s"), + nwp.col("n", "s"), ), - (ndcs.string() | 
ndcs.boolean(), nwd.col("k", "m")), + (ndcs.string() | ndcs.boolean(), nwp.col("k", "m")), ( ~(ndcs.numeric() | ndcs.string()), - nwd.col("l", "m", "n", "o", "p", "q", "r", "s", "u"), + nwp.col("l", "m", "n", "o", "p", "q", "r", "s", "u"), ), ( ( @@ -206,14 +204,14 @@ def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: - (ndcs.categorical() | ndcs.by_name("a", "b") | ndcs.matches("[fqohim]")) ^ ndcs.by_name("u", "a", "b", "d", "e", "f", "g") ).name.suffix("_after"), - nwd.col("a", "b", "c", "f", "j", "k", "l", "n", "r", "s").name.suffix( + nwp.col("a", "b", "c", "f", "j", "k", "l", "n", "r", "s").name.suffix( "_after" ), ), ( (ndcs.matches("[a-m]") & ~ndcs.numeric()).sort(nulls_last=True).first() - != nwd.lit(None), - nwd.col("k", "l", "m").sort(nulls_last=True).first() != nwd.lit(None), + != nwp.lit(None), + nwp.col("k", "l", "m").sort(nulls_last=True).first() != nwp.lit(None), ), ( ( @@ -222,9 +220,9 @@ def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: .over("k", order_by=ndcs.by_dtype(nw.Date()) | ndcs.boolean()) ), ( - nwd.col("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + nwp.col("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") .mean() - .over(nwd.col("k"), order_by=nwd.col("m", "n")) + .over(nwp.col("k"), order_by=nwp.col("m", "n")) ), ), ( @@ -237,10 +235,10 @@ def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: .name.to_uppercase() ), ( - nwd.col("l", "o") + nwp.col("l", "o") .dt.timestamp("us") .min() - .over(nwd.col("k", "m")) + .over(nwp.col("k", "m")) .last() .name.to_uppercase() ), @@ -248,7 +246,9 @@ def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: ], ) def test_replace_selector( - expr: Selector | Expr, expected: Expr | ExprIR, schema_1: dict[str, DType] + expr: nwp.Selector | nwp.Expr, + expected: nwp.Expr | ir.ExprIR, + schema_1: dict[str, DType], ) -> None: actual = replace_selector(expr._ir, schema=freeze_schema(**schema_1)) assert_expr_ir_equal(actual, expected) @@ -257,41 +257,41 @@ def test_replace_selector( @pytest.mark.parametrize( ("into_exprs", "expected"), [ - ("a", [nwd.col("a")]), - (nwd.col("b", "c", "d"), [nwd.col("b"), nwd.col("c"), nwd.col("d")]), - (nwd.nth(6), [nwd.col("g")]), - (nwd.nth(9, 8, -5), [nwd.col("j"), nwd.col("i"), nwd.col("p")]), + ("a", [nwp.col("a")]), + (nwp.col("b", "c", "d"), [nwp.col("b"), nwp.col("c"), nwp.col("d")]), + (nwp.nth(6), [nwp.col("g")]), + (nwp.nth(9, 8, -5), [nwp.col("j"), nwp.col("i"), nwp.col("p")]), ( - [nwd.nth(2).alias("c again"), nwd.nth(-1, -2).name.to_uppercase()], + [nwp.nth(2).alias("c again"), nwp.nth(-1, -2).name.to_uppercase()], [ - nwd.col("c").alias("c again"), - nwd.col("u").alias("U"), - nwd.col("s").alias("S"), + nwp.col("c").alias("c again"), + nwp.col("u").alias("U"), + nwp.col("s").alias("S"), ], ), ( - nwd.all(), + nwp.all(), [ - nwd.col("a"), - nwd.col("b"), - nwd.col("c"), - nwd.col("d"), - nwd.col("e"), - nwd.col("f"), - nwd.col("g"), - nwd.col("h"), - nwd.col("i"), - nwd.col("j"), - nwd.col("k"), - nwd.col("l"), - nwd.col("m"), - nwd.col("n"), - nwd.col("o"), - nwd.col("p"), - nwd.col("q"), - nwd.col("r"), - nwd.col("s"), - nwd.col("u"), + nwp.col("a"), + nwp.col("b"), + nwp.col("c"), + nwp.col("d"), + nwp.col("e"), + nwp.col("f"), + nwp.col("g"), + nwp.col("h"), + nwp.col("i"), + nwp.col("j"), + nwp.col("k"), + nwp.col("l"), + nwp.col("m"), + nwp.col("n"), + nwp.col("o"), + nwp.col("p"), + nwp.col("q"), + nwp.col("r"), + nwp.col("s"), + nwp.col("u"), ], ), ( @@ -300,21 
+300,21 @@ def test_replace_selector( .mean() .name.suffix("_mean"), [ - nwd.col("a").cast(nw.Int64()).mean().alias("a_mean"), - nwd.col("b").cast(nw.Int64()).mean().alias("b_mean"), - nwd.col("c").cast(nw.Int64()).mean().alias("c_mean"), - nwd.col("d").cast(nw.Int64()).mean().alias("d_mean"), - nwd.col("e").cast(nw.Int64()).mean().alias("e_mean"), - nwd.col("f").cast(nw.Int64()).mean().alias("f_mean"), - nwd.col("g").cast(nw.Int64()).mean().alias("g_mean"), - nwd.col("h").cast(nw.Int64()).mean().alias("h_mean"), + nwp.col("a").cast(nw.Int64()).mean().alias("a_mean"), + nwp.col("b").cast(nw.Int64()).mean().alias("b_mean"), + nwp.col("c").cast(nw.Int64()).mean().alias("c_mean"), + nwp.col("d").cast(nw.Int64()).mean().alias("d_mean"), + nwp.col("e").cast(nw.Int64()).mean().alias("e_mean"), + nwp.col("f").cast(nw.Int64()).mean().alias("f_mean"), + nwp.col("g").cast(nw.Int64()).mean().alias("g_mean"), + nwp.col("h").cast(nw.Int64()).mean().alias("h_mean"), ], ), ( - nwd.col("u").alias("1").alias("2").alias("3").alias("4").name.keep(), + nwp.col("u").alias("1").alias("2").alias("3").alias("4").name.keep(), # NOTE: Would be nice to rewrite with less intermediate steps # but retrieving the root name is enough for now - [nwd.col("u").alias("1").alias("2").alias("3").alias("4").alias("u")], + [nwp.col("u").alias("1").alias("2").alias("3").alias("4").alias("u")], ), ( ( @@ -322,30 +322,30 @@ def test_replace_selector( * 100 ).name.suffix("_mult_100"), [ - (nwd.col("e") * nwd.lit(100)).alias("e_mult_100"), - (nwd.col("h") * nwd.lit(100)).alias("h_mult_100"), - (nwd.col("j") * nwd.lit(100)).alias("j_mult_100"), + (nwp.col("e") * nwp.lit(100)).alias("e_mult_100"), + (nwp.col("h") * nwp.lit(100)).alias("h_mult_100"), + (nwp.col("j") * nwp.lit(100)).alias("j_mult_100"), ], ), ( ndcs.by_dtype(nw.Duration()) .dt.total_minutes() .name.map(lambda nm: f"total_mins: {nm!r} ?"), - [nwd.col("q").dt.total_minutes().alias("total_mins: 'q' ?")], + [nwp.col("q").dt.total_minutes().alias("total_mins: 'q' ?")], ), ( - nwd.col("f", "g") + nwp.col("f", "g") .cast(nw.String) .str.starts_with("1") .all() .name.suffix("_all_starts_with_1"), [ - nwd.col("f") + nwp.col("f") .cast(nw.String) .str.starts_with("1") .all() .alias("f_all_starts_with_1"), - nwd.col("g") + nwp.col("g") .cast(nw.String) .str.starts_with("1") .all() @@ -353,66 +353,66 @@ def test_replace_selector( ], ), ( - nwd.col("a", "b") + nwp.col("a", "b") .first() .over("c", "e", order_by="d") .name.suffix("_first_over_part_order_1"), [ - nwd.col("a") + nwp.col("a") .first() - .over(nwd.col("c"), nwd.col("e"), order_by=[nwd.col("d")]) + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) .alias("a_first_over_part_order_1"), - nwd.col("b") + nwp.col("b") .first() - .over(nwd.col("c"), nwd.col("e"), order_by=[nwd.col("d")]) + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) .alias("b_first_over_part_order_1"), ], ), ( - nwd.exclude(BIG_EXCLUDE), + nwp.exclude(BIG_EXCLUDE), [ - nwd.col("c"), - nwd.col("d"), - nwd.col("f"), - nwd.col("g"), - nwd.col("h"), - nwd.col("i"), - nwd.col("j"), + nwp.col("c"), + nwp.col("d"), + nwp.col("f"), + nwp.col("g"), + nwp.col("h"), + nwp.col("i"), + nwp.col("j"), ], ), ( - nwd.exclude(BIG_EXCLUDE).name.suffix("_2"), + nwp.exclude(BIG_EXCLUDE).name.suffix("_2"), [ - nwd.col("c").alias("c_2"), - nwd.col("d").alias("d_2"), - nwd.col("f").alias("f_2"), - nwd.col("g").alias("g_2"), - nwd.col("h").alias("h_2"), - nwd.col("i").alias("i_2"), - nwd.col("j").alias("j_2"), + nwp.col("c").alias("c_2"), + 
nwp.col("d").alias("d_2"), + nwp.col("f").alias("f_2"), + nwp.col("g").alias("g_2"), + nwp.col("h").alias("h_2"), + nwp.col("i").alias("i_2"), + nwp.col("j").alias("j_2"), ], ), ( - nwd.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), + nwp.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), [ - nwd.col("c") + nwp.col("c") .alias("c_min_over_order_by") .min() - .over(order_by=[nwd.col("k")]) + .over(order_by=[nwp.col("k")]) ], ), pytest.param( - (ndcs.by_name("a", "b", "c") / nwd.col("e").first()) + (ndcs.by_name("a", "b", "c") / nwp.col("e").first()) .over("g", "f", order_by="f") .name.prefix("hi_"), [ - (nwd.col("a") / nwd.col("e").first()) + (nwp.col("a") / nwp.col("e").first()) .over("g", "f", order_by="f") .alias("hi_a"), - (nwd.col("b") / nwd.col("e").first()) + (nwp.col("b") / nwp.col("e").first()) .over("g", "f", order_by="f") .alias("hi_b"), - (nwd.col("c") / nwd.col("e").first()) + (nwp.col("c") / nwp.col("e").first()) .over("g", "f", order_by="f") .alias("hi_c"), ], @@ -422,7 +422,7 @@ def test_replace_selector( ) def test_prepare_projection( into_exprs: IntoExpr | Sequence[IntoExpr], - expected: Sequence[Expr], + expected: Sequence[nwp.Expr], schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) @@ -435,19 +435,19 @@ def test_prepare_projection( @pytest.mark.parametrize( "expr", [ - nwd.all(), - nwd.nth(1, 2, 3), - nwd.col("a", "b", "c"), + nwp.all(), + nwp.nth(1, 2, 3), + nwp.col("a", "b", "c"), ndcs.boolean() | ndcs.categorical(), (ndcs.by_name("a", "b") | ndcs.string()), - (nwd.col("b", "c") & nwd.col("a")), - nwd.col("a", "b").min().over("c", order_by="e"), + (nwp.col("b", "c") & nwp.col("a")), + nwp.col("a", "b").min().over("c", order_by="e"), (~ndcs.by_dtype(nw.Int64()) - ndcs.datetime()), - nwd.nth(6, 2).abs().cast(nw.Int32()) + 10, + nwp.nth(6, 2).abs().cast(nw.Int32()) + 10, *MULTI_OUTPUT_EXPRS, ], ) -def test_prepare_projection_duplicate(expr: Expr, schema_1: dict[str, DType]) -> None: +def test_prepare_projection_duplicate(expr: nwp.Expr, schema_1: dict[str, DType]) -> None: irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) pattern = re.compile(r"\.alias\(.dupe.\)") with pytest.raises(DuplicateError, match=pattern): @@ -457,12 +457,12 @@ def test_prepare_projection_duplicate(expr: Expr, schema_1: dict[str, DType]) -> @pytest.mark.parametrize( ("into_exprs", "missing"), [ - ([nwd.col("y", "z")], ["y", "z"]), - ([nwd.col("a", "b", "z")], ["z"]), - ([nwd.col("x", "b", "a")], ["x"]), + ([nwp.col("y", "z")], ["y", "z"]), + ([nwp.col("a", "b", "z")], ["z"]), + ([nwp.col("x", "b", "a")], ["x"]), ( [ - nwd.col( + nwp.col( [ "a", "b", @@ -491,18 +491,18 @@ def test_prepare_projection_duplicate(expr: Expr, schema_1: dict[str, DType]) -> ["FIVE"], ), ( - [nwd.col("a").min().over("c").alias("y"), nwd.col("one").alias("b").last()], + [nwp.col("a").min().over("c").alias("y"), nwp.col("one").alias("b").last()], ["one"], ), - ([nwd.col("a").sort_by("b", "who").alias("f")], ["who"]), + ([nwp.col("a").sort_by("b", "who").alias("f")], ["who"]), ( [ - nwd.nth(0, 5) + nwp.nth(0, 5) .cast(nw.Int64()) .abs() .cum_sum() .over("X", "O", "h", "m", "r", "zee"), - nwd.col("d", "j"), + nwp.col("d", "j"), "n", ], ["O", "X", "zee"], @@ -525,29 +525,29 @@ def test_prepare_projection_column_not_found( [ ("a", "b", "c"), (["a", "b", "c"]), - ("a", "b", nwd.col("c")), - (nwd.col("a"), "b", "c"), - (nwd.col("a", "b"), "c"), - ("a", nwd.col("b", "c")), - ((nwd.nth(0), nwd.nth(1, 2))), + ("a", "b", nwp.col("c")), + 
(nwp.col("a"), "b", "c"), + (nwp.col("a", "b"), "c"), + ("a", nwp.col("b", "c")), + ((nwp.nth(0), nwp.nth(1, 2))), *MULTI_OUTPUT_EXPRS, ], ) @pytest.mark.parametrize( "function", [ - nwd.all_horizontal, - nwd.any_horizontal, - nwd.sum_horizontal, - nwd.min_horizontal, - nwd.max_horizontal, - nwd.mean_horizontal, - nwd.concat_str, + nwp.all_horizontal, + nwp.any_horizontal, + nwp.sum_horizontal, + nwp.min_horizontal, + nwp.max_horizontal, + nwp.mean_horizontal, + nwp.concat_str, ], ) def test_prepare_projection_horizontal_alias( into_exprs: IntoExpr | Iterable[IntoExpr], - function: Callable[..., Expr], + function: Callable[..., nwp.Expr], schema_1: dict[str, DType], ) -> None: # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139965411 @@ -566,7 +566,7 @@ def test_prepare_projection_horizontal_alias( @pytest.mark.parametrize( - "into_exprs", [nwd.nth(-21), nwd.nth(-1, 2, 54, 0), nwd.nth(20), nwd.nth([-10, -100])] + "into_exprs", [nwp.nth(-21), nwp.nth(-1, 2, 54, 0), nwp.nth(20), nwp.nth([-10, -100])] ) def test_prepare_projection_index_error( into_exprs: IntoExpr | Iterable[IntoExpr], schema_1: dict[str, DType] diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 52068da571..5b525001a9 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -3,20 +3,18 @@ import operator import re from collections import deque -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import pytest import narwhals as nw -import narwhals._plan.demo as nwd -from narwhals._plan import boolean, expr, functions as F, operators as ops -from narwhals._plan.common import ExprIR, Function -from narwhals._plan.dummy import Expr, Series -from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr -from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir -from narwhals._plan.literal import SeriesLiteral +from narwhals import _plan as nwp +from narwhals._plan import expressions as ir +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.expressions import functions as F, operators as ops +from narwhals._plan.expressions.literal import SeriesLiteral from narwhals.exceptions import ( InvalidIntoExprError, InvalidOperationError, @@ -31,6 +29,7 @@ from typing_extensions import TypeAlias + from narwhals._plan._function import Function from narwhals._plan.typing import IntoExpr, IntoExprColumn, OperatorFn, Seq @@ -40,15 +39,15 @@ @pytest.mark.parametrize( ("exprs", "named_exprs"), [ - ([nwd.col("a")], {}), + ([nwp.col("a")], {}), (["a"], {}), ([], {"a": "b"}), - ([], {"a": nwd.col("b")}), - (["a", "b", nwd.col("c", "d", "e")], {"g": nwd.lit(1)}), - ([["a", "b", "c"]], {"q": nwd.lit(5, nw.Int8())}), + ([], {"a": nwp.col("b")}), + (["a", "b", nwp.col("c", "d", "e")], {"g": nwp.lit(1)}), + ([["a", "b", "c"]], {"q": nwp.lit(5, nw.Int8())}), ( - [[nwd.nth(1), nwd.nth(2, 3, 4)]], - {"n": nwd.col("p").count(), "other n": nwd.len()}, + [[nwp.nth(1), nwp.nth(2, 3, 4)]], + {"n": nwp.col("p").count(), "other n": nwp.len()}, ), ], ) @@ -56,7 +55,7 @@ def test_parsing( exprs: Seq[IntoExpr | Iterable[IntoExpr]], named_exprs: dict[str, IntoExpr] ) -> None: assert all( - isinstance(node, ExprIR) + isinstance(node, ir.ExprIR) for node in parse_into_seq_of_expr_ir(*exprs, **named_exprs) ) @@ -64,12 +63,12 @@ def test_parsing( 
@pytest.mark.parametrize( ("function", "ir_node"), [ - (nwd.all_horizontal, boolean.AllHorizontal), - (nwd.any_horizontal, boolean.AnyHorizontal), - (nwd.sum_horizontal, F.SumHorizontal), - (nwd.min_horizontal, F.MinHorizontal), - (nwd.max_horizontal, F.MaxHorizontal), - (nwd.mean_horizontal, F.MeanHorizontal), + (nwp.all_horizontal, ir.boolean.AllHorizontal), + (nwp.any_horizontal, ir.boolean.AnyHorizontal), + (nwp.sum_horizontal, F.SumHorizontal), + (nwp.min_horizontal, F.MinHorizontal), + (nwp.max_horizontal, F.MaxHorizontal), + (nwp.mean_horizontal, F.MeanHorizontal), ], ) @pytest.mark.parametrize( @@ -77,24 +76,24 @@ def test_parsing( [ ("a", "b", "c"), (["a", "b", "c"]), - (nwd.col("d", "e", "f"), nwd.col("g"), "q", nwd.nth(9)), - ((nwd.lit(1),)), - ([nwd.lit(1), nwd.lit(2, nw.Int64), nwd.lit(3, nw.Int64())]), + (nwp.col("d", "e", "f"), nwp.col("g"), "q", nwp.nth(9)), + ((nwp.lit(1),)), + ([nwp.lit(1), nwp.lit(2, nw.Int64), nwp.lit(3, nw.Int64())]), ], ) def test_function_expr_horizontal( - function: Callable[..., Expr], + function: Callable[..., nwp.Expr], ir_node: type[Function], args: Seq[IntoExpr | Iterable[IntoExpr]], ) -> None: variadic = function(*args) sequence = function(args) - assert isinstance(variadic, Expr) - assert isinstance(sequence, Expr) + assert isinstance(variadic, nwp.Expr) + assert isinstance(sequence, nwp.Expr) variadic_node = variadic._ir sequence_node = sequence._ir - unrelated_node = nwd.lit(1)._ir - assert isinstance(variadic_node, FunctionExpr) + unrelated_node = nwp.lit(1)._ir + assert isinstance(variadic_node, ir.FunctionExpr) assert isinstance(variadic_node.function, ir_node) assert variadic_node == sequence_node assert sequence_node != unrelated_node @@ -106,7 +105,7 @@ def test_valid_windows() -> None: https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L10-L45 """ ELEMENTWISE_ERR = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) # noqa: N806 - a = nwd.col("a") + a = nwp.col("a") assert a.cum_sum() assert a.cum_sum().over(order_by="id") with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): @@ -115,32 +114,32 @@ def test_valid_windows() -> None: assert (a.cum_sum() + 1).over(order_by="id") assert a.cum_sum().cum_sum().over(order_by="id") assert a.cum_sum().cum_sum() - assert nwd.sum_horizontal(a, a.cum_sum()) + assert nwp.sum_horizontal(a, a.cum_sum()) with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): - assert nwd.sum_horizontal(a, a.cum_sum()).over(order_by="a") + assert nwp.sum_horizontal(a, a.cum_sum()).over(order_by="a") - assert nwd.sum_horizontal(a, a.cum_sum().over(order_by="i")) - assert nwd.sum_horizontal(a.diff(), a.cum_sum().over(order_by="i")) + assert nwp.sum_horizontal(a, a.cum_sum().over(order_by="i")) + assert nwp.sum_horizontal(a.diff(), a.cum_sum().over(order_by="i")) with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): - assert nwd.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i") + assert nwp.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i") with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): - assert nwd.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") + assert nwp.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") def test_invalid_repeat_agg() -> None: with pytest.raises(InvalidOperationError): - nwd.col("a").mean().mean() + nwp.col("a").mean().mean() with pytest.raises(InvalidOperationError): - nwd.col("a").first().max() + nwp.col("a").first().max() with 
pytest.raises(InvalidOperationError): - nwd.col("a").any().std() + nwp.col("a").any().std() with pytest.raises(InvalidOperationError): - nwd.col("a").all().quantile(0.5, "linear") + nwp.col("a").all().quantile(0.5, "linear") with pytest.raises(InvalidOperationError): - nwd.col("a").arg_max().min() + nwp.col("a").arg_max().min() with pytest.raises(InvalidOperationError): - nwd.col("a").arg_min().arg_max() + nwp.col("a").arg_min().arg_max() # NOTE: Previously multiple different errors, but they can be reduced to the same thing @@ -148,51 +147,51 @@ def test_invalid_repeat_agg() -> None: def test_invalid_agg_non_elementwise() -> None: pattern = re.compile(r"cannot use.+rank.+aggregated.+mean", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").mean().rank() + nwp.col("a").mean().rank() pattern = re.compile(r"cannot use.+drop_nulls.+aggregated.+max", re.IGNORECASE) with pytest.raises(InvalidOperationError): - nwd.col("a").max().drop_nulls() + nwp.col("a").max().drop_nulls() pattern = re.compile(r"cannot use.+diff.+aggregated.+min", re.IGNORECASE) with pytest.raises(InvalidOperationError): - nwd.col("a").min().diff() + nwp.col("a").min().diff() def test_agg_non_elementwise_range_special() -> None: - e = nwd.int_range(0, 100) - assert isinstance(e._ir, RangeExpr) - e = nwd.int_range(nwd.len(), dtype=nw.UInt32).alias("index") - ir = e._ir - assert isinstance(ir, expr.Alias) - assert isinstance(ir.expr, RangeExpr) - assert isinstance(ir.expr.input[0], expr.Literal) - assert isinstance(ir.expr.input[1], expr.Len) + e = nwp.int_range(0, 100) + assert isinstance(e._ir, ir.RangeExpr) + e = nwp.int_range(nwp.len(), dtype=nw.UInt32).alias("index") + e_ir = e._ir + assert isinstance(e_ir, ir.Alias) + assert isinstance(e_ir.expr, ir.RangeExpr) + assert isinstance(e_ir.expr.input[0], ir.Literal) + assert isinstance(e_ir.expr.input[1], ir.Len) def test_invalid_int_range() -> None: pattern = re.compile(r"scalar.+agg", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.int_range(nwd.col("a")) + nwp.int_range(nwp.col("a")) with pytest.raises(InvalidOperationError, match=pattern): - nwd.int_range(nwd.nth(1), 10) + nwp.int_range(nwp.nth(1), 10) with pytest.raises(InvalidOperationError, match=pattern): - nwd.int_range(0, nwd.col("a").abs()) + nwp.int_range(0, nwp.col("a").abs()) with pytest.raises(InvalidOperationError, match=pattern): - nwd.int_range(nwd.col("a") + 1) + nwp.int_range(nwp.col("a") + 1) # NOTE: Non-`polars`` rule def test_invalid_over() -> None: pattern = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").fill_null(3).over("b") + nwp.col("a").fill_null(3).over("b") def test_nested_over() -> None: pattern = re.compile(r"cannot nest.+over", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").mean().over("b").over("c") + nwp.col("a").mean().over("b").over("c") with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").mean().over("b").over("c", order_by="i") + nwp.col("a").mean().over("b").over("c", order_by="i") # NOTE: This *can* error in polars, but only if the length **actually changes** @@ -200,36 +199,36 @@ def test_nested_over() -> None: def test_filtration_over() -> None: pattern = re.compile(r"cannot use.+over.+change length", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").drop_nulls().over("b") + nwp.col("a").drop_nulls().over("b") with 
pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").drop_nulls().over("b", order_by="i") + nwp.col("a").drop_nulls().over("b", order_by="i") with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").diff().drop_nulls().over("b", order_by="i") + nwp.col("a").diff().drop_nulls().over("b", order_by="i") def test_invalid_binary_expr_multi() -> None: pattern = re.escape("all() + cols(['b', 'c'])\n ^^^^^^^^^^^^^^^^") with pytest.raises(MultiOutputExpressionError, match=pattern): - nwd.all() + nwd.col("b", "c") + nwp.all() + nwp.col("b", "c") pattern = re.escape( "index_columns((1, 2, 3)) * index_columns((4, 5, 6)).max()\n" " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" ) with pytest.raises(MultiOutputExpressionError, match=pattern): - nwd.nth(1, 2, 3) * nwd.nth(4, 5, 6).max() + nwp.nth(1, 2, 3) * nwp.nth(4, 5, 6).max() pattern = re.escape( "cols(['a', 'b', 'c']).abs().fill_null([lit(int: 0)]).round() * index_columns((9, 10)).cast(Int64).sort(asc)\n" " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" ) with pytest.raises(MultiOutputExpressionError, match=pattern): - nwd.col("a", "b", "c").abs().fill_null(0).round(2) * nwd.nth(9, 10).cast( + nwp.col("a", "b", "c").abs().fill_null(0).round(2) * nwp.nth(9, 10).cast( nw.Int64() ).sort() def test_invalid_binary_expr_length_changing() -> None: - a = nwd.col("a") - b = nwd.col("b") + a = nwp.col("a") + b = nwp.col("b") with pytest.raises(LengthChangingExprError): a.unique() + b.unique() @@ -247,13 +246,13 @@ def test_invalid_binary_expr_length_changing() -> None: a.map_batches(lambda x: x) / b.gather_every(1, 0) -def _is_expr_ir_binary_expr(expr: Expr) -> bool: - return isinstance(expr._ir, BinaryExpr) +def _is_expr_ir_binary_expr(expr: nwp.Expr) -> bool: + return isinstance(expr._ir, ir.BinaryExpr) def test_binary_expr_length_changing_agg() -> None: - a = nwd.col("a") - b = nwd.col("b") + a = nwp.col("a") + b = nwp.col("b") assert _is_expr_ir_binary_expr(a.unique().first() + b.unique()) assert _is_expr_ir_binary_expr(a.mode().last() * b.unique()) @@ -273,8 +272,8 @@ def test_invalid_binary_expr_shape() -> None: re.escape("Cannot combine length-changing expressions with length-preserving"), re.IGNORECASE, ) - a = nwd.col("a") - b = nwd.col("b") + a = nwp.col("a") + b = nwp.col("b") with pytest.raises(ShapeError, match=pattern): a.unique() + b @@ -288,11 +287,11 @@ def test_invalid_binary_expr_shape() -> None: def test_is_in_seq(into_iter: IntoIterable) -> None: expected = 1, 2, 3 other = into_iter(list(expected)) - expr = nwd.col("a").is_in(other) - ir = expr._ir - assert isinstance(ir, FunctionExpr) - assert isinstance(ir.function, boolean.IsInSeq) - assert ir.function.other == expected + expr = nwp.col("a").is_in(other) + e_ir = expr._ir + assert isinstance(e_ir, ir.FunctionExpr) + assert isinstance(e_ir.function, ir.boolean.IsInSeq) + assert e_ir.function.other == expected def test_is_in_series() -> None: @@ -300,12 +299,12 @@ def test_is_in_series() -> None: import pyarrow as pa native = pa.chunked_array([pa.array([1, 2, 3])]) - other = Series.from_native(native) - expr = nwd.col("a").is_in(other) - ir = expr._ir - assert isinstance(ir, FunctionExpr) - assert isinstance(ir.function, boolean.IsInSeries) - assert ir.function.other.unwrap().to_native() is native + other = nwp.Series.from_native(native) + expr = nwp.col("a").is_in(other) + e_ir = expr._ir + assert isinstance(e_ir, ir.FunctionExpr) + assert isinstance(e_ir.function, ir.boolean.IsInSeries) + assert e_ir.function.other.unwrap().to_native() is native @pytest.mark.parametrize( 
@@ -314,7 +313,7 @@ def test_is_in_series() -> None: ("words", pytest.raises(TypeError, match=r"str \| bytes.+str")), (b"words", pytest.raises(TypeError, match=r"str \| bytes.+bytes")), ( - nwd.col("b"), + nwp.col("b"), pytest.raises( NotImplementedError, match=re.compile(r"iterable instead", re.IGNORECASE) ), @@ -329,19 +328,19 @@ def test_is_in_series() -> None: ) def test_invalid_is_in(other: Any, context: AbstractContextManager[Any]) -> None: with context: - nwd.col("a").is_in(other) + nwp.col("a").is_in(other) def test_filter_full_spellings() -> None: - a = nwd.col("a") - b = nwd.col("b") - c = nwd.col("c") - d = nwd.col("d") - expected = a.filter(b != b.max(), c < nwd.lit(2), d == nwd.lit(5)) - expr_1 = a.filter([b != b.max(), c < nwd.lit(2), d == nwd.lit(5)]) - expr_2 = a.filter([b != b.max(), c < nwd.lit(2)], d=nwd.lit(5)) - expr_3 = a.filter([b != b.max(), c < nwd.lit(2)], d=5) - expr_4 = a.filter(b != b.max(), c < nwd.lit(2), d=5) + a = nwp.col("a") + b = nwp.col("b") + c = nwp.col("c") + d = nwp.col("d") + expected = a.filter(b != b.max(), c < nwp.lit(2), d == nwp.lit(5)) + expr_1 = a.filter([b != b.max(), c < nwp.lit(2), d == nwp.lit(5)]) + expr_2 = a.filter([b != b.max(), c < nwp.lit(2)], d=nwp.lit(5)) + expr_3 = a.filter([b != b.max(), c < nwp.lit(2)], d=5) + expr_4 = a.filter(b != b.max(), c < nwp.lit(2), d=5) expr_5 = a.filter(b != b.max(), c < 2, d=5) expr_6 = a.filter((b != b.max(), c < 2), d=5) assert_expr_ir_equal(expected, expr_1) @@ -355,9 +354,9 @@ def test_filter_full_spellings() -> None: @pytest.mark.parametrize( ("predicates", "constraints", "context"), [ - ([nwd.col("b").is_last_distinct()], {}, nullcontext()), + ([nwp.col("b").is_last_distinct()], {}, nullcontext()), ((), {"b": 10}, nullcontext()), - ((), {"b": nwd.lit(10)}, nullcontext()), + ((), {"b": nwp.lit(10)}, nullcontext()), ( (), {}, @@ -365,9 +364,9 @@ def test_filter_full_spellings() -> None: TypeError, match=re.compile(r"at least one predicate", re.IGNORECASE) ), ), - ((nwd.col("b") > 1, nwd.col("c").is_null()), {}, nullcontext()), + ((nwp.col("b") > 1, nwp.col("c").is_null()), {}, nullcontext()), ( - ([nwd.col("b") > 1], nwd.col("c").is_null()), + ([nwp.col("b") > 1], nwp.col("c").is_null()), {}, pytest.raises( InvalidIntoExprError, @@ -384,7 +383,7 @@ def test_filter_partial_spellings( context: AbstractContextManager[Any], ) -> None: with context: - assert nwd.col("a").filter(*predicates, **constraints) + assert nwp.col("a").filter(*predicates, **constraints) def test_lit_series_roundtrip() -> None: @@ -393,15 +392,15 @@ def test_lit_series_roundtrip() -> None: data = ["a", "b", "c"] native = pa.chunked_array([pa.array(data)]) - series = Series.from_native(native) - lit_series = nwd.lit(series) + series = nwp.Series.from_native(native) + lit_series = nwp.lit(series) assert lit_series.meta.is_literal() - ir = lit_series._ir - assert isinstance(ir, expr.Literal) - assert isinstance(ir.dtype, nw.String) - assert isinstance(ir.value, SeriesLiteral) - unwrapped = ir.unwrap() - assert isinstance(unwrapped, Series) + e_ir = lit_series._ir + assert isinstance(e_ir, ir.Literal) + assert isinstance(e_ir.dtype, nw.String) + assert isinstance(e_ir.value, SeriesLiteral) + unwrapped = e_ir.unwrap() + assert isinstance(unwrapped, nwp.Series) assert isinstance(unwrapped.to_native(), pa.ChunkedArray) assert unwrapped.to_list() == data @@ -409,24 +408,24 @@ def test_lit_series_roundtrip() -> None: @pytest.mark.parametrize( ("arg_1", "arg_2", "function", "op"), [ - (nwd.col("a"), 1, operator.eq, ops.Eq), - 
(nwd.col("a"), "b", operator.eq, ops.Eq), - (nwd.col("a"), 1, operator.ne, ops.NotEq), - (nwd.col("a"), "b", operator.ne, ops.NotEq), - (nwd.col("a"), "b", operator.ge, ops.GtEq), - (nwd.col("a"), "b", operator.gt, ops.Gt), - (nwd.col("a"), "b", operator.le, ops.LtEq), - (nwd.col("a"), "b", operator.lt, ops.Lt), - ((nwd.col("a") != 1), False, operator.and_, ops.And), - ((nwd.col("a") != 1), False, operator.or_, ops.Or), - ((nwd.col("a")), True, operator.xor, ops.ExclusiveOr), - (nwd.col("a"), 6, operator.add, ops.Add), - (nwd.col("a"), 2.1, operator.mul, ops.Multiply), - (nwd.col("a"), nwd.col("b"), operator.sub, ops.Sub), - (nwd.col("a"), 2, operator.pow, F.Pow), - (nwd.col("a"), 2, operator.mod, ops.Modulus), - (nwd.col("a"), 2, operator.floordiv, ops.FloorDivide), - (nwd.col("a"), 4, operator.truediv, ops.TrueDivide), + (nwp.col("a"), 1, operator.eq, ops.Eq), + (nwp.col("a"), "b", operator.eq, ops.Eq), + (nwp.col("a"), 1, operator.ne, ops.NotEq), + (nwp.col("a"), "b", operator.ne, ops.NotEq), + (nwp.col("a"), "b", operator.ge, ops.GtEq), + (nwp.col("a"), "b", operator.gt, ops.Gt), + (nwp.col("a"), "b", operator.le, ops.LtEq), + (nwp.col("a"), "b", operator.lt, ops.Lt), + ((nwp.col("a") != 1), False, operator.and_, ops.And), + ((nwp.col("a") != 1), False, operator.or_, ops.Or), + ((nwp.col("a")), True, operator.xor, ops.ExclusiveOr), + (nwp.col("a"), 6, operator.add, ops.Add), + (nwp.col("a"), 2.1, operator.mul, ops.Multiply), + (nwp.col("a"), nwp.col("b"), operator.sub, ops.Sub), + (nwp.col("a"), 2, operator.pow, F.Pow), + (nwp.col("a"), 2, operator.mod, ops.Modulus), + (nwp.col("a"), 2, operator.floordiv, ops.FloorDivide), + (nwp.col("a"), 4, operator.truediv, ops.TrueDivide), ], ) def test_operators_left_right( @@ -443,8 +442,8 @@ def test_operators_left_right( } result_1 = function(arg_1, arg_2) result_2 = function(arg_2, arg_1) - assert isinstance(result_1, Expr) - assert isinstance(result_2, Expr) + assert isinstance(result_1, nwp.Expr) + assert isinstance(result_2, nwp.Expr) ir_1 = result_1._ir ir_2 = result_2._ir if op in {ops.Eq, ops.NotEq}: @@ -452,9 +451,9 @@ def test_operators_left_right( else: assert ir_1 != ir_2 if issubclass(op, ops.Operator): - assert isinstance(ir_1, BinaryExpr) + assert isinstance(ir_1, ir.BinaryExpr) assert isinstance(ir_1.op, op) - assert isinstance(ir_2, BinaryExpr) + assert isinstance(ir_2, ir.BinaryExpr) op_inverse = inverse.get(op, op) assert isinstance(ir_2.op, op_inverse) if op in {ops.Eq, ops.NotEq, *inverse}: @@ -464,8 +463,8 @@ def test_operators_left_right( assert ir_1.left == ir_2.right assert ir_1.right == ir_2.left else: - assert isinstance(ir_1, FunctionExpr) + assert isinstance(ir_1, ir.FunctionExpr) assert isinstance(ir_1.function, op) - assert isinstance(ir_2, FunctionExpr) + assert isinstance(ir_2, ir.FunctionExpr) assert isinstance(ir_2.function, op) assert tuple(reversed(ir_2.input)) == ir_1.input diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 740d966818..bf810aa176 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -5,21 +5,18 @@ import pytest import narwhals as nw -from narwhals._plan import demo as nwd, expr_parsing as parse, selectors as ndcs -from narwhals._plan._guards import is_expr -from narwhals._plan.common import ExprIR, NamedIR -from narwhals._plan.expr import WindowExpr -from narwhals._plan.expr_rewrites import ( +from narwhals import _plan as nwp +from narwhals._plan import _parse, expressions as ir, selectors as ndcs +from 
narwhals._plan._rewrites import ( rewrite_all, rewrite_binary_agg_over, rewrite_elementwise_over, ) -from narwhals._plan.window import Over +from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: - from narwhals._plan.dummy import Expr from narwhals._plan.typing import IntoExpr from narwhals.dtypes import DType @@ -41,24 +38,24 @@ def schema_2() -> dict[str, DType]: } -def _to_window_expr(into_expr: IntoExpr, *partition_by: IntoExpr) -> WindowExpr: - return WindowExpr( - expr=parse.parse_into_expr_ir(into_expr), - partition_by=parse.parse_into_seq_of_expr_ir(*partition_by), +def _to_window_expr(into_expr: IntoExpr, *partition_by: IntoExpr) -> ir.WindowExpr: + return ir.WindowExpr( + expr=_parse.parse_into_expr_ir(into_expr), + partition_by=_parse.parse_into_seq_of_expr_ir(*partition_by), options=Over(), ) def test_rewrite_elementwise_over_simple(schema_2: dict[str, DType]) -> None: with pytest.raises(InvalidOperationError, match=r"over.+elementwise"): - nwd.col("a").sum().abs().over("b") + nwp.col("a").sum().abs().over("b") # NOTE: Since the requested "before" expression is currently an error (at definition time), # we need to manually build the IR - to sidestep the validation in `Over.to_window_expr`. # Later, that error might not be needed if we can do this rewrite. # If you're here because of a "Did not raise" - just replace everything with the (previously) erroring expr. - expected = nwd.col("a").sum().over("b").abs() - before = _to_window_expr(nwd.col("a").sum().abs(), "b").to_narwhals() + expected = nwp.col("a").sum().over("b").abs() + before = _to_window_expr(nwp.col("a").sum().abs(), "b").to_narwhals() assert_expr_ir_equal(before, "col('a').sum().abs().over([col('b')])") actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_elementwise_over]) assert len(actual) == 1 @@ -67,11 +64,11 @@ def test_rewrite_elementwise_over_simple(schema_2: dict[str, DType]) -> None: def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: expected = ( - nwd.col("b").last().over("d").replace_strict({1: 2}), - nwd.col("c").last().over("d").replace_strict({1: 2}), + nwp.col("b").last().over("d").replace_strict({1: 2}), + nwp.col("c").last().over("d").replace_strict({1: 2}), ) before = _to_window_expr( - nwd.col("b", "c").last().replace_strict({1: 2}), "d" + nwp.col("b", "c").last().replace_strict({1: 2}), "d" ).to_narwhals() assert_expr_ir_equal( before, "cols(['b', 'c']).last().replace_strict().over([col('d')])" @@ -82,37 +79,36 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: assert_expr_ir_equal(lhs, rhs) -def named_ir(name: str, expr: Expr | ExprIR, /) -> NamedIR[ExprIR]: +def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: """Helper constructor for test compare.""" - ir = expr._ir if is_expr(expr) else expr - return NamedIR(expr=ir, name=name) + return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: expected = ( - named_ir("a", nwd.col("a")), - named_ir("b", nwd.col("b").cast(nw.String)), - named_ir("x2", nwd.col("c").max().over("a").fill_null(50)), - named_ir("d**", ~nwd.col("d").is_duplicated().over("b")), - named_ir("f_some", nwd.col("f").str.contains("some")), - named_ir("g_some", nwd.col("g").str.contains("some")), - named_ir("h_some", nwd.col("h").str.contains("some")), - 
named_ir("D", nwd.col("d").null_count().over("f", "g", "j").sqrt()), - named_ir("E", nwd.col("e").null_count().over("f", "g", "j").sqrt()), - named_ir("B", nwd.col("b").null_count().over("f", "g", "j").sqrt()), + named_ir("a", nwp.col("a")), + named_ir("b", nwp.col("b").cast(nw.String)), + named_ir("x2", nwp.col("c").max().over("a").fill_null(50)), + named_ir("d**", ~nwp.col("d").is_duplicated().over("b")), + named_ir("f_some", nwp.col("f").str.contains("some")), + named_ir("g_some", nwp.col("g").str.contains("some")), + named_ir("h_some", nwp.col("h").str.contains("some")), + named_ir("D", nwp.col("d").null_count().over("f", "g", "j").sqrt()), + named_ir("E", nwp.col("e").null_count().over("f", "g", "j").sqrt()), + named_ir("B", nwp.col("b").null_count().over("f", "g", "j").sqrt()), ) before = ( - nwd.col("a"), - nwd.col("b").cast(nw.String), + nwp.col("a"), + nwp.col("b").cast(nw.String), ( - _to_window_expr(nwd.col("c").max().alias("x").fill_null(50), "a") + _to_window_expr(nwp.col("c").max().alias("x").fill_null(50), "a") .to_narwhals() .alias("x2") ), - ~(nwd.col("d").is_duplicated().alias("d*")).alias("d**").over("b"), + ~(nwp.col("d").is_duplicated().alias("d*")).alias("d**").over("b"), ndcs.string().str.contains("some").name.suffix("_some"), ( - _to_window_expr(nwd.nth(3, 4, 1).null_count().sqrt(), "f", "g", "j") + _to_window_expr(nwp.nth(3, 4, 1).null_count().sqrt(), "f", "g", "j") .to_narwhals() .name.to_uppercase() ), @@ -125,12 +121,12 @@ def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: def test_rewrite_binary_agg_over_simple(schema_2: dict[str, DType]) -> None: expected = ( - nwd.col("a") - nwd.col("a").mean().over("b"), - nwd.col("c") * nwd.col("c").abs().null_count().over("d"), + nwp.col("a") - nwp.col("a").mean().over("b"), + nwp.col("c") * nwp.col("c").abs().null_count().over("d"), ) before = ( - (nwd.col("a") - nwd.col("a").mean()).over("b"), - (nwd.col("c") * nwd.col("c").abs().null_count()).over("d"), + (nwp.col("a") - nwp.col("a").mean()).over("b"), + (nwp.col("c") * nwp.col("c").abs().null_count()).over("d"), ) actual = rewrite_all(*before, schema=schema_2, rewrites=[rewrite_binary_agg_over]) assert len(actual) == 2 @@ -140,13 +136,13 @@ def test_rewrite_binary_agg_over_simple(schema_2: dict[str, DType]) -> None: def test_rewrite_binary_agg_over_multiple(schema_2: dict[str, DType]) -> None: expected = ( - named_ir("hi_a", nwd.col("a") / nwd.col("e").drop_nulls().first().over("g")), - named_ir("hi_b", nwd.col("b") / nwd.col("e").drop_nulls().first().over("g")), - named_ir("hi_c", nwd.col("c") / nwd.col("e").drop_nulls().first().over("g")), - named_ir("hi_d", nwd.col("d") / nwd.col("e").drop_nulls().first().over("g")), + named_ir("hi_a", nwp.col("a") / nwp.col("e").drop_nulls().first().over("g")), + named_ir("hi_b", nwp.col("b") / nwp.col("e").drop_nulls().first().over("g")), + named_ir("hi_c", nwp.col("c") / nwp.col("e").drop_nulls().first().over("g")), + named_ir("hi_d", nwp.col("d") / nwp.col("e").drop_nulls().first().over("g")), ) before = ( - (nwd.col("a", "b", "c", "d") / nwd.col("e").drop_nulls().first()).over("g") + (nwp.col("a", "b", "c", "d") / nwp.col("e").drop_nulls().first()).over("g") ).name.prefix("hi_") actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_binary_agg_over]) assert len(actual) == 4 diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index e783e55c31..2b5ca80c35 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -1,57 +1,54 @@ from __future__ import annotations -from 
typing import TYPE_CHECKING - import pytest -import narwhals._plan.demo as nwd +from narwhals import _plan as nwp from tests.utils import POLARS_VERSION -if TYPE_CHECKING: - from narwhals._plan.dummy import Expr - pytest.importorskip("polars") import polars as pl if POLARS_VERSION >= (1, 0): # https://github.com/pola-rs/polars/pull/16743 OVER_CASE = ( - nwd.col("a").last().over("b", order_by="c"), + nwp.col("a").last().over("b", order_by="c"), pl.col("a").last().over("b", order_by="c"), ["a", "b"], ) else: # pragma: no cover - OVER_CASE = (nwd.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) + OVER_CASE = (nwp.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) if POLARS_VERSION >= (0, 20, 5): - LEN_CASE = (nwd.len(), pl.len(), "len") + LEN_CASE = (nwp.len(), pl.len(), "len") else: # pragma: no cover - LEN_CASE = (nwd.len().alias("count"), pl.count(), "count") + LEN_CASE = (nwp.len().alias("count"), pl.count(), "count") @pytest.mark.parametrize( ("nw_expr", "pl_expr", "expected"), [ ( - nwd.col("a").alias("b").min().alias("c").alias("d"), + nwp.col("a").alias("b").min().alias("c").alias("d"), pl.col("a").alias("b").min().alias("c").alias("d"), ["a"], ), ( - (nwd.col("a") + (nwd.col("a") - nwd.col("b"))).alias("c"), + (nwp.col("a") + (nwp.col("a") - nwp.col("b"))).alias("c"), (pl.col("a") + (pl.col("a") - pl.col("b"))).alias("c"), ["a", "a", "b"], ), OVER_CASE, ( - (nwd.col("a", "b", "c").sort().abs() * 20).max(), + (nwp.col("a", "b", "c").sort().abs() * 20).max(), (pl.col("a", "b", "c").sort().abs() * 20).max(), [], ), - (nwd.all().mean(), pl.all().mean(), []), - (nwd.all().mean().sort_by("d"), pl.all().mean().sort_by("d"), ["d"]), + (nwp.all().mean(), pl.all().mean(), []), + (nwp.all().mean().sort_by("d"), pl.all().mean().sort_by("d"), ["d"]), ], ) -def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) -> None: +def test_meta_root_names( + nw_expr: nwp.Expr, pl_expr: pl.Expr, expected: list[str] +) -> None: pl_result = pl_expr.meta.root_names() nw_result = nw_expr.meta.root_names() assert nw_result == expected @@ -61,17 +58,17 @@ def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) - @pytest.mark.parametrize( ("nw_expr", "pl_expr", "expected"), [ - (nwd.col("a"), pl.col("a"), "a"), - (nwd.lit(1), pl.lit(1), "literal"), + (nwp.col("a"), pl.col("a"), "a"), + (nwp.lit(1), pl.lit(1), "literal"), LEN_CASE, pytest.param( ( - nwd.col("a") + nwp.col("a") .alias("b") .min() .alias("c") .over("e", "f") - .sort_by(nwd.col("i"), nwd.col("g", "h")) + .sort_by(nwp.col("i"), nwp.col("g", "h")) ), ( pl.col("a") @@ -85,17 +82,17 @@ def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) - id="Kitchen-Sink", ), pytest.param( - nwd.col("c").alias("x").fill_null(50), + nwp.col("c").alias("x").fill_null(50), pl.col("c").alias("x").fill_null(50), "x", id="FunctionExpr-Literal", ), pytest.param( ( - nwd.col("ROOT") + nwp.col("ROOT") .alias("ROOT-ALIAS") - .filter(nwd.col("b") >= 30, nwd.col("c").alias("d") == 7) - + nwd.col("RHS").alias("RHS-ALIAS") + .filter(nwp.col("b") >= 30, nwp.col("c").alias("d") == 7) + + nwp.col("RHS").alias("RHS-ALIAS") ), ( pl.col("ROOT") @@ -107,35 +104,35 @@ def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) - id="BinaryExpr-Multiple", ), pytest.param( - nwd.col("ROOT").alias("ROOT-ALIAS").mean().over(nwd.col("a").alias("b")), + nwp.col("ROOT").alias("ROOT-ALIAS").mean().over(nwp.col("a").alias("b")), 
pl.col("ROOT").alias("ROOT-ALIAS").mean().over(pl.col("a").alias("b")), "ROOT-ALIAS", id="WindowExpr", ), pytest.param( - nwd.when(nwd.col("a").alias("a?")).then(10), + nwp.when(nwp.col("a").alias("a?")).then(10), pl.when(pl.col("a").alias("a?")).then(10), "literal", id="When-Literal", ), pytest.param( - nwd.when(nwd.col("a").alias("a?")).then(nwd.col("b")).otherwise(20), + nwp.when(nwp.col("a").alias("a?")).then(nwp.col("b")).otherwise(20), pl.when(pl.col("a").alias("a?")).then(pl.col("b")).otherwise(20), "b", id="When-Column-Literal", ), pytest.param( - nwd.when(a=1).then(10).otherwise(nwd.col("c").alias("c?")), + nwp.when(a=1).then(10).otherwise(nwp.col("c").alias("c?")), pl.when(a=1).then(10).otherwise(pl.col("c").alias("c?")), "literal", id="When-Literal-Alias", ), pytest.param( ( - nwd.when(nwd.col("a").alias("a?")) + nwp.when(nwp.col("a").alias("a?")) .then(1) - .when(nwd.col("b") == 1) - .then(nwd.col("c")) + .when(nwp.col("b") == 1) + .then(nwp.col("c")) ), ( pl.when(pl.col("a").alias("a?")) @@ -148,9 +145,9 @@ def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) - ), pytest.param( ( - nwd.when(nwd.col("foo") > 2, nwd.col("bar") < 3) - .then(nwd.lit("Yes")) - .otherwise(nwd.lit("No")) + nwp.when(nwp.col("foo") > 2, nwp.col("bar") < 3) + .then(nwp.lit("Yes")) + .otherwise(nwp.lit("No")) .alias("TARGET") ), ( @@ -163,23 +160,23 @@ def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) - id="When2-Literal-Literal-Alias", ), pytest.param( - (nwd.col("ROOT").alias("ROOT-ALIAS").filter(nwd.col("c") <= 1).mean()), + (nwp.col("ROOT").alias("ROOT-ALIAS").filter(nwp.col("c") <= 1).mean()), (pl.col("ROOT").alias("ROOT-ALIAS").filter(pl.col("c") <= 1).mean()), "ROOT-ALIAS", id="Filter", ), pytest.param( - nwd.int_range(0, 10), pl.int_range(0, 10), "literal", id="IntRange-Literal" + nwp.int_range(0, 10), pl.int_range(0, 10), "literal", id="IntRange-Literal" ), pytest.param( - nwd.int_range(nwd.col("b").first(), 10), + nwp.int_range(nwp.col("b").first(), 10), pl.int_range(pl.col("b").first(), 10), "b", id="IntRange-Column", ), ], ) -def test_meta_output_name(nw_expr: Expr, pl_expr: pl.Expr, expected: str) -> None: +def test_meta_output_name(nw_expr: nwp.Expr, pl_expr: pl.Expr, expected: str) -> None: pl_result = pl_expr.meta.output_name() nw_result = nw_expr.meta.output_name() assert nw_result == expected diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 4eaf98db9f..bf6135ee2f 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -2,27 +2,27 @@ from typing import TYPE_CHECKING -from narwhals._plan._guards import is_expr -from narwhals._plan.common import ExprIR, NamedIR +from narwhals import _plan as nwp +from narwhals._plan import expressions as ir if TYPE_CHECKING: from typing_extensions import LiteralString - from narwhals._plan.dummy import Expr - -def _unwrap_ir(obj: Expr | ExprIR | NamedIR) -> ExprIR: - if is_expr(obj): +def _unwrap_ir(obj: nwp.Expr | ir.ExprIR | ir.NamedIR) -> ir.ExprIR: + if isinstance(obj, nwp.Expr): return obj._ir - if isinstance(obj, ExprIR): + if isinstance(obj, ir.ExprIR): return obj - if isinstance(obj, NamedIR): + if isinstance(obj, ir.NamedIR): return obj.expr raise NotImplementedError(type(obj)) def assert_expr_ir_equal( - actual: Expr | ExprIR | NamedIR, expected: Expr | ExprIR | NamedIR | LiteralString, / + actual: nwp.Expr | ir.ExprIR | ir.NamedIR, + expected: nwp.Expr | ir.ExprIR | ir.NamedIR | LiteralString, + /, ) -> None: """Assert that `actual` is equivalent to `expected`. 
@@ -37,8 +37,8 @@ def assert_expr_ir_equal( lhs = _unwrap_ir(actual) if isinstance(expected, str): assert repr(lhs) == expected - elif isinstance(actual, NamedIR) and isinstance(expected, NamedIR): + elif isinstance(actual, ir.NamedIR) and isinstance(expected, ir.NamedIR): assert actual == expected else: - rhs = expected._ir if is_expr(expected) else expected + rhs = expected._ir if isinstance(expected, nwp.Expr) else expected assert lhs == rhs From 7599fc41478cdc4a3754921dcdc715cfca2d80ff Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Sep 2025 08:49:24 +0000 Subject: [PATCH 360/368] revert(ruff): ignore (`RUF043`) --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c1c5ebc4b6..a25510ec2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -208,7 +208,6 @@ extend-ignore-names = [ "C901", # complex-structure "PLR0912", # too-many-branches "PLR0916", # too-many-boolean-expressions - "RUF043", # temp ignore until sync ] "tpch/tests/*" = ["S101"] "utils/*" = ["S311"] From 8208d32803e79b07f02b60e59451abdadf1a67b2 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:46:45 +0000 Subject: [PATCH 361/368] feat(expr-ir): Support `group_by`, utilize `pyarrow.acero` (#3143) --- narwhals/_plan/_expansion.py | 37 +- narwhals/_plan/_expr_ir.py | 14 + narwhals/_plan/_rewrites.py | 5 +- narwhals/_plan/arrow/acero.py | 253 ++++++++ narwhals/_plan/arrow/dataframe.py | 44 +- narwhals/_plan/arrow/expr.py | 13 +- narwhals/_plan/arrow/functions.py | 71 ++- narwhals/_plan/arrow/group_by.py | 183 ++++++ narwhals/_plan/arrow/options.py | 105 ++++ narwhals/_plan/arrow/typing.py | 10 +- narwhals/_plan/common.py | 167 ++++- narwhals/_plan/dataframe.py | 74 ++- narwhals/_plan/expr.py | 3 + narwhals/_plan/expressions/__init__.py | 2 + narwhals/_plan/expressions/aggregation.py | 5 +- narwhals/_plan/group_by.py | 69 ++ narwhals/_plan/options.py | 31 +- narwhals/_plan/protocols.py | 243 +++++++- narwhals/_plan/schema.py | 75 ++- narwhals/_plan/typing.py | 7 +- tests/plan/compliant_test.py | 22 + tests/plan/expr_expansion_test.py | 182 +++--- tests/plan/expr_rewrites_test.py | 7 +- tests/plan/group_by_test.py | 726 ++++++++++++++++++++++ tests/plan/temp_test.py | 126 ++++ tests/plan/utils.py | 15 +- 26 files changed, 2283 insertions(+), 206 deletions(-) create mode 100644 narwhals/_plan/arrow/acero.py create mode 100644 narwhals/_plan/arrow/group_by.py create mode 100644 narwhals/_plan/arrow/options.py create mode 100644 narwhals/_plan/group_by.py create mode 100644 tests/plan/group_by_test.py create mode 100644 tests/plan/temp_test.py diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index fb2dd390a8..6cbf061a98 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -87,11 +87,10 @@ Excluded: TypeAlias = "frozenset[str]" """Internally use a `set`, then freeze before returning.""" -GroupByKeys: TypeAlias = "Seq[ExprIR]" -"""Represents group_by keys. +GroupByKeys: TypeAlias = "Seq[str]" +"""Represents `group_by` keys. -- Originates from `polars_plan::plans::conversion::dsl_to_ir::resolve_group_by` -- Not fully utilized in `narwhals` version yet +They need to be excluded from expansion. 
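+
+For example, when aggregating within `group_by("a")`, a wildcard-style
+selection should expand to every column except the key `"a"`
+(see `prepare_excluded`).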
""" OutputNames: TypeAlias = "Seq[str]" @@ -154,24 +153,23 @@ def with_multiple_columns(self) -> ExpansionFlags: def prepare_projection( - exprs: Sequence[ExprIR], schema: IntoFrozenSchema -) -> tuple[Seq[ExprIR], FrozenSchema, OutputNames]: + exprs: Sequence[ExprIR], /, keys: GroupByKeys = (), *, schema: IntoFrozenSchema +) -> tuple[Seq[NamedIR], FrozenSchema]: """Expand IRs into named column selections. - **Primary entry-point**, will be used by `select`, `with_columns`, + **Primary entry-point**, for `select`, `with_columns`, and any other context that requires resolving expression names. Arguments: exprs: IRs that *may* contain things like `Columns`, `SelectorIR`, `Exclude`, etc. + keys: Names of `group_by` columns. schema: Scope to expand multi-column selectors in. - - Returns: - `exprs`, rewritten using `Column(name)` only. """ frozen_schema = freeze_schema(schema) - rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) + rewritten = rewrite_projections(tuple(exprs), keys=keys, schema=frozen_schema) output_names = ensure_valid_exprs(rewritten, frozen_schema) - return rewritten, frozen_schema, output_names + named_irs = into_named_irs(rewritten, output_names) + return named_irs, frozen_schema def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]: @@ -202,7 +200,7 @@ def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if is_horizontal_reduction(child): - rewrites = rewrite_projections(child.input, keys=(), schema=schema) + rewrites = rewrite_projections(child.input, schema=schema) return common.replace(child, input=rewrites) return child @@ -275,7 +273,7 @@ def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns: def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, - keys: GroupByKeys, + keys: GroupByKeys = (), *, schema: FrozenSchema, ) -> Seq[ExprIR]: @@ -323,13 +321,10 @@ def prepare_excluded( origin: ExprIR, keys: GroupByKeys, flags: ExpansionFlags, / ) -> Excluded: """Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555.""" - exclude: set[str] = set() - if flags.has_exclude: - exclude.update(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) - for group_by_key in keys: - if name := group_by_key.meta.output_name(raise_if_undetermined=False): - exclude.add(name) - return frozenset(exclude) + gb_keys = frozenset(keys) + if not flags.has_exclude: + return gb_keys + return gb_keys.union(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool: diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 0646520102..d163134c80 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -290,3 +290,17 @@ def is_elementwise_top_level(self) -> bool: if is_literal(ir): return ir.is_scalar return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) + + def is_column(self, *, allow_aliasing: bool = False) -> bool: + """Return True if wrapping a single `Column` node. + + Note: + Multi-output (including selectors) expressions have been expanded at this stage. + + Arguments: + allow_aliasing: If False (default), any aliasing is not considered to be column selection. 
+ """ + from narwhals._plan.expressions import Column + + ir = self.expr + return isinstance(ir, Column) and ((self.name == ir.name) or allow_aliasing) diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py index ae23fa4b9b..fd26364e66 100644 --- a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from narwhals._plan._expansion import into_named_irs, prepare_projection +from narwhals._plan._expansion import prepare_projection from narwhals._plan._guards import ( is_aggregation, is_binary_expr, @@ -31,8 +31,7 @@ def rewrite_all( - Currently we do a full traversal of each tree per-rewrite function - There's no caching *after* `prepare_projection` yet """ - out_irs, _, names = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema) - named_irs = into_named_irs(out_irs, names) + named_irs, _ = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema=schema) return tuple(map_ir(ir, *rewrites) for ir in named_irs) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py new file mode 100644 index 0000000000..768248e312 --- /dev/null +++ b/narwhals/_plan/arrow/acero.py @@ -0,0 +1,253 @@ +"""Sugar for working with [Acero]. + +[`pyarrow.acero`] has some building blocks for constructing queries, but is +quite verbose when used directly. + +This module aligns some apis to look more like `polars`. + +Notes: + - Functions suffixed with `_table` all handle composition and collection internally + +[Acero]: https://arrow.apache.org/docs/cpp/acero/overview.html +[`pyarrow.acero`]: https://arrow.apache.org/docs/python/api/acero.html +""" + +from __future__ import annotations + +import functools +import operator +from functools import reduce +from itertools import chain +from typing import TYPE_CHECKING, Any, Final, Union, cast + +import pyarrow as pa # ignore-banned-import +import pyarrow.acero as pac +import pyarrow.compute as pc # ignore-banned-import +from pyarrow.acero import Declaration as Decl + +from narwhals._plan.typing import OneOrSeq +from narwhals.typing import SingleColSelector + +if TYPE_CHECKING: + from collections.abc import Callable, Collection, Iterable, Iterator + + from typing_extensions import TypeAlias + + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + AggregateOptions as _AggregateOptions, + Aggregation as _Aggregation, + ) + from narwhals._plan.arrow.group_by import AggSpec + from narwhals._plan.arrow.typing import NullPlacement + from narwhals._plan.typing import OneOrIterable, Order, Seq + from narwhals.typing import NonNestedLiteral + +Incomplete: TypeAlias = Any +Expr: TypeAlias = pc.Expression +IntoExpr: TypeAlias = "Expr | NonNestedLiteral" +Field: TypeAlias = Union[Expr, SingleColSelector] +"""Anything that passes as a single item in [`_compute._ensure_field_ref`]. 
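+
+For example, `col("a")` or a plain column name such as `"a"`.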
+ +[`_compute._ensure_field_ref`]: https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L1507-L1531 +""" + +Target: TypeAlias = OneOrSeq[Field] +Aggregation: TypeAlias = "_Aggregation" +AggregateOptions: TypeAlias = "_AggregateOptions" +Opts: TypeAlias = "AggregateOptions | None" +OutputName: TypeAlias = str + +_THREAD_UNSAFE: Final = frozenset[Aggregation]( + ("hash_first", "hash_last", "first", "last") +) +col = pc.field +lit = cast("Callable[[NonNestedLiteral], Expr]", pc.scalar) +"""Alias for `pyarrow.compute.scalar`.""" + + +# NOTE: ATOW there are 304 valid function names, 46 can be used for some kind of agg +# Due to expr expansion, it is very likely that we have repeat runs +@functools.lru_cache(maxsize=128) +def can_thread(function_name: str, /) -> bool: + return function_name not in _THREAD_UNSAFE + + +def _parse_into_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr: + if isinstance(into, pc.Expression): + return into + if isinstance(into, str) and not str_as_lit: + return col(into) + return lit(into) + + +def _parse_into_iter_expr(inputs: Iterable[IntoExpr], /) -> Iterator[Expr]: + for into_expr in inputs: + yield _parse_into_expr(into_expr) + + +def _parse_into_seq_of_expr(inputs: Iterable[IntoExpr], /) -> Seq[Expr]: + return tuple(_parse_into_iter_expr(inputs)) + + +def _parse_all_horizontal(predicates: Seq[Expr], constraints: dict[str, Any], /) -> Expr: + if not constraints and len(predicates) == 1: + return predicates[0] + it = ( + col(name) == _parse_into_expr(v, str_as_lit=True) + for name, v in constraints.items() + ) + return reduce(operator.and_, chain(predicates, it)) + + +def table_source(native: pa.Table, /) -> Decl: + """Start building a logical plan, using `native` as the source table. + + All calls to `collect` must use this as the first `Declaration`. + """ + return Decl("table_source", options=pac.TableSourceNodeOptions(native)) + + +def _aggregate(aggs: Iterable[AggSpec], /, keys: Iterable[Field] | None = None) -> Decl: + # NOTE: See https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_acero.pyx#L167-L192 + aggs_: Incomplete = aggs + keys_: Incomplete = keys + return Decl("aggregate", pac.AggregateNodeOptions(aggs_, keys=keys_)) + + +def aggregate(aggs: Iterable[AggSpec], /) -> Decl: + """May only use [Scalar aggregate] functions. + + [Scalar aggregate]: https://arrow.apache.org/docs/cpp/compute.html#aggregations + """ + return _aggregate(aggs) + + +def group_by(keys: Iterable[Field], aggs: Iterable[AggSpec], /) -> Decl: + """May only use [Hash aggregate] functions, requires grouping. 
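+
+    A rough sketch (assumes a pre-built `AggSpec`, here called `spec`, and an
+    illustrative key column `"g"`):
+
+        plan = group_by(["g"], [spec])
+        table = collect(table_source(native), plan)
+
+    `group_by_table` wraps this composition and also chooses `use_threads`
+    based on the supplied specs.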
+ + [Hash aggregate]: https://arrow.apache.org/docs/cpp/compute.html#grouped-aggregations-group-by + """ + return _aggregate(aggs, keys=keys) + + +def filter(*predicates: Expr, **constraints: IntoExpr) -> Decl: + expr = _parse_all_horizontal(predicates, constraints) + return Decl("filter", options=pac.FilterNodeOptions(expr)) + + +def select_names(column_names: OneOrIterable[str], *more_names: str) -> Decl: + """`select` where all args are column names.""" + if not more_names: + if isinstance(column_names, str): + return _project((col(column_names),), (column_names,)) + more_names = tuple(column_names) + elif isinstance(column_names, str): + more_names = column_names, *more_names + else: + msg = f"Passing both iterable and positional inputs is not supported.\n{column_names=}\n{more_names=}" + raise NotImplementedError(msg) + return _project([col(name) for name in more_names], more_names) + + +def _project(exprs: Collection[Expr], names: Collection[str]) -> Decl: + # NOTE: Both just need to be `Sized` and `Iterable` + exprs_: Incomplete = exprs + names_: Incomplete = names + return Decl("project", options=pac.ProjectNodeOptions(exprs_, names_)) + + +def project(**named_exprs: IntoExpr) -> Decl: + """Similar to `select`, but more rigid. + + Arguments: + **named_exprs: Inputs composed of any combination of + + - Column names or `pc.field` + - Python literals or `pc.scalar` (for `str` literals) + - [Scalar functions] applied to the above + + Notes: + - [`Expression`]s have no concept of aliasing, therefore, all inputs must be `**named_exprs`. + - Always returns a table with the same length, scalar literals are broadcast unconditionally. + + [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + [Scalar functions]: https://arrow.apache.org/docs/cpp/compute.html#element-wise-scalar-functions + """ + exprs = _parse_into_seq_of_expr(named_exprs.values()) + return _project(names=named_exprs.keys(), exprs=exprs) + + +def _order_by( + sort_keys: Iterable[tuple[str, Order]] = (), + *, + null_placement: NullPlacement = "at_end", +) -> Decl: + # NOTE: There's no runtime type checking of `sort_keys` wrt shape + # Just need to be `Iterable`and unpack like a 2-tuple + # https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L77-L88 + keys: Incomplete = sort_keys + return Decl("order_by", pac.OrderByNodeOptions(keys, null_placement=null_placement)) + + +# TODO @dangotbanned: Utilize `SortMultipleOptions.to_arrow_acero` +def sort_by(*args: Any, **kwds: Any) -> Decl: + msg = "Should convert from polars args -> use `_order_by" + raise NotImplementedError(msg) + + +def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table: + """Compose and evaluate a logical plan. + + Arguments: + *declarations: One or more `Declaration` nodes to execute as a pipeline. + **The first node must be a `table_source`**. + use_threads: Pass `False` if `declarations` contains any order-dependent aggregation(s). + """ + # NOTE: stubs + docs say `list`, but impl allows any iterable + decls: Incomplete = declarations + return Decl.from_sequence(decls).to_table(use_threads=use_threads) + + +def group_by_table( + native: pa.Table, keys: Iterable[Field], aggs: Iterable[AggSpec] +) -> pa.Table: + """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`]. + + - Backport of [apache/arrow#36768]. + - `first` and `last` were [broken in `pyarrow==13`]. + - Also allows us to specify our own aliases for aggregate output columns. 
+ - Fixes [narwhals-dev/narwhals#1612] + + [`pa.TableGroupBy.aggregate`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626 + [`pa.acero._group_by`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/acero.py#L412-L418 + [apache/arrow#36768]: https://github.com/apache/arrow/pull/36768 + [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709 + [narwhals-dev/narwhals#1612]: https://github.com/narwhals-dev/narwhals/issues/1612 + """ + aggs = tuple(aggs) + use_threads = all(spec.use_threads for spec in aggs) + return collect(table_source(native), group_by(keys, aggs), use_threads=use_threads) + + +def filter_table(native: pa.Table, *predicates: Expr, **constraints: Any) -> pa.Table: + """Selects rows where all expressions evaluate to True. + + Arguments: + native: source table + predicates: [`Expression`]s which must all have a return type of boolean. + constraints: Column filters; use `name = value` to filter columns by the supplied value. + + Notes: + - Uses logic similar to [`polars`] for an AND-reduction + - Elements where the filter does not evaluate to True are discarded, **including nulls** + + [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + [`polars`]: https://github.com/pola-rs/polars/blob/d0914d416ce4e1dfcb5f946875ffd1181e31c493/py-polars/polars/_utils/parse/expr.py#L199-L242 + """ + return collect(table_source(native), filter(*predicates, **constraints)) + + +def select_names_table( + native: pa.Table, column_names: OneOrIterable[str], *more_names: str +) -> pa.Table: + return collect(table_source(native), select_names(column_names, *more_names)) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 27a02bc2ed..b588b59180 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -1,15 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, overload +import operator +from functools import reduce +from itertools import chain +from typing import TYPE_CHECKING, Any, Literal, cast, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy from narwhals._plan.arrow.series import ArrowSeries as Series +from narwhals._plan.expressions import NamedIR from narwhals._plan.protocols import EagerDataFrame, namespace -from narwhals._utils import Version +from narwhals._plan.typing import Seq +from narwhals._utils import Version, parse_columns_to_drop from narwhals.schema import Schema if TYPE_CHECKING: @@ -29,11 +35,18 @@ class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]): + _native: pa.Table + _version: Version + def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace import ArrowNamespace return ArrowNamespace(self._version) + @property + def _group_by(self) -> type[GroupBy]: + return GroupBy + @property def columns(self) -> list[str]: return self.native.column_names @@ -95,10 +108,26 @@ def get_column(self, name: str) -> Series: chunked = self.native.column(name) return Series.from_native(chunked, name, version=self.version) - def drop(self, columns: Sequence[str]) -> Self: - to_drop = list(columns) + def drop(self, columns: Sequence[str], *, 
strict: bool = True) -> Self: + to_drop = parse_columns_to_drop(self, columns, strict=strict) return self._with_native(self.native.drop(to_drop)) + def drop_nulls(self, subset: Sequence[str] | None) -> Self: + if subset is None: + native = self.native.drop_null() + else: + to_drop = reduce(operator.or_, (pc.field(name).is_null() for name in subset)) + native = self.native.filter(~to_drop) + return self._with_native(native) + + def rename(self, mapping: Mapping[str, str]) -> Self: + names: dict[str, str] | list[str] + if fn.BACKEND_VERSION >= (17,): + names = cast("dict[str, str]", mapping) + else: # pragma: no cover + names = [mapping.get(c, c) for c in self.columns] + return self._with_native(self.native.rename_columns(names)) + # NOTE: Use instead of `with_columns` for trivial cases def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: native = self.native @@ -113,3 +142,10 @@ def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: else: native = native.append_column(name, chunked) return self._with_native(native) + + def select_names(self, *column_names: str) -> Self: + return self._with_native(self.native.select(list(column_names))) + + def row(self, index: int) -> tuple[Any, ...]: + row = self.native.slice(index, 1) + return tuple(chain.from_iterable(row.to_pydict().values())) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 57ec5196d6..b547ed57fa 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -35,6 +35,7 @@ Count, First, Last, + Len, Max, Mean, Median, @@ -54,7 +55,7 @@ Not, ) from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr - from narwhals._plan.expressions.functions import FillNull, Pow + from narwhals._plan.expressions.functions import Abs, FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" @@ -111,6 +112,9 @@ def func(node: FunctionExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: return func + def abs(self, node: FunctionExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(pc.abs)(node, frame, name) + def not_(self, node: FunctionExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(pc.invert)(node, frame, name) @@ -296,6 +300,10 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar: result = fn.count(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) + def len(self, node: Len, frame: Frame, name: str) -> Scalar: + result = fn.count(self._dispatch_expr(node.expr, frame, name).native, mode="all") + return self._with_native(result, name) + def max(self, node: Max, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.max_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) @@ -460,6 +468,9 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar: native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) + def len(self, node: Len, frame: Frame, name: str) -> Scalar: + return self._with_native(pa.scalar(1), name) + filter = not_implemented() over = not_implemented() over_ordered = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 7a16404d3d..1fd1942b2c 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -13,11 +13,12 @@ chunked_array as _chunked_array, floordiv_compat as floordiv, 
) +from narwhals._plan.arrow import options from narwhals._plan.expressions import operators as ops from narwhals._utils import Implementation if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Iterable, Mapping from typing_extensions import TypeIs @@ -38,17 +39,20 @@ ChunkedOrScalar, ChunkedOrScalarAny, DataType, + DataTypeRemap, DataTypeT, IntegerScalar, IntegerType, + LargeStringType, NativeScalar, Scalar, ScalarAny, ScalarT, StringScalar, + StringType, UnaryFunction, ) - from narwhals.typing import ClosedInterval + from narwhals.typing import ClosedInterval, IntoArrowSchema BACKEND_VERSION = Implementation.PYARROW._backend_version() @@ -133,6 +137,41 @@ def cast( return pc.cast(native, target_type, safe=safe) +def cast_schema( + native: pa.Schema, target_types: DataType | Mapping[str, DataType] | DataTypeRemap +) -> pa.Schema: + if isinstance(target_types, pa.DataType): + return pa.schema((name, target_types) for name in native.names) + if _is_into_pyarrow_schema(target_types): + new_schema = native + for name, dtype in target_types.items(): + index = native.get_field_index(name) + new_schema.set(index, native.field(index).with_type(dtype)) + return new_schema + return pa.schema((fld.name, target_types.get(fld.type, fld.type)) for fld in native) + + +def cast_table( + native: pa.Table, target: DataType | IntoArrowSchema | DataTypeRemap +) -> pa.Table: + s = target if isinstance(target, pa.Schema) else cast_schema(native.schema, target) + return native.cast(s) + + +def has_large_string(data_types: Iterable[DataType], /) -> bool: + return any(pa.types.is_large_string(tp) for tp in data_types) + + +def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStringType: + """Return a native string type, compatible with `data_types`. + + Until [apache/arrow#45717] is resolved, we need to upcast `string` to `large_string` when joining. + + [apache/arrow#45717]: https://github.com/apache/arrow/issues/45717 + """ + return pa.large_string() if has_large_string(data_types) else pa.string() + + def any_(native: Any) -> pa.BooleanScalar: return pc.any(native, min_count=0) @@ -180,21 +219,11 @@ def binary( def concat_str( *arrays: ChunkedArrayAny, separator: str = "", ignore_nulls: bool = False ) -> ChunkedArray[StringScalar]: - fn: Incomplete = pc.binary_join_element_wise - it, sep = _cast_to_comparable_string_types(arrays, separator) - return fn(*it, sep, null_handling="skip" if ignore_nulls else "emit_null") # type: ignore[no-any-return] - - -def _cast_to_comparable_string_types( - arrays: Sequence[ChunkedArrayAny], /, separator: str -) -> tuple[Iterator[ChunkedArray[StringScalar]], StringScalar]: - # Ensure `chunked_arrays` are either all `string` or all `large_string`. 
- dtype = ( - pa.string() - if not any(pa.types.is_large_string(obj.type) for obj in arrays) - else pa.large_string() - ) - return (obj.cast(dtype) for obj in arrays), pa.scalar(separator, dtype) + dtype = string_type(obj.type for obj in arrays) + it = (obj.cast(dtype) for obj in arrays) + concat: Incomplete = pc.binary_join_element_wise + join = options.join(ignore_nulls=ignore_nulls) + return concat(*it, lit(separator, dtype), options=join) # type: ignore[no-any-return] def int_range( @@ -260,3 +289,11 @@ def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: from narwhals._plan.arrow.series import ArrowSeries return isinstance(obj, ArrowSeries) + + +def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: + return ( + (first := next(iter(obj.items())), None) + and isinstance(first[0], str) + and isinstance(first[1], pa.DataType) + ) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py new file mode 100644 index 0000000000..c878f344ed --- /dev/null +++ b/narwhals/_plan/arrow/group_by.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan import expressions as ir +from narwhals._plan._guards import is_agg_expr, is_function_expr +from narwhals._plan.arrow import acero, functions as fn, options +from narwhals._plan.common import dispatch_method_name, temp +from narwhals._plan.expressions import aggregation as agg +from narwhals._plan.protocols import EagerDataFrameGroupBy +from narwhals._utils import Implementation +from narwhals.exceptions import InvalidOperationError + +if TYPE_CHECKING: + from collections.abc import Iterator, Mapping + + from typing_extensions import Self, TypeAlias + + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame + from narwhals._plan.arrow.typing import ChunkedArray + from narwhals._plan.expressions import NamedIR + from narwhals._plan.typing import Seq + +Incomplete: TypeAlias = Any + +# NOTE: Unless stated otherwise, all aggregations have 2 variants: +# - `` (pc.Function.kind == "scalar_aggregate") +# - `hash_` (pc.Function.kind == "hash_aggregate") +SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { + agg.Sum: "hash_sum", + agg.Mean: "hash_mean", + agg.Median: "hash_approximate_median", + agg.Max: "hash_max", + agg.Min: "hash_min", + agg.Std: "hash_stddev", + agg.Var: "hash_variance", + agg.Count: "hash_count", + agg.Len: "hash_count", + agg.NUnique: "hash_count_distinct", + agg.First: "hash_first", + agg.Last: "hash_last", +} +SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { + ir.Len: "hash_count_all", + ir.Column: "hash_list", # `hash_aggregate` only +} +SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { + ir.boolean.All: "hash_all", + ir.boolean.Any: "hash_any", + ir.functions.Unique: "hash_distinct", # `hash_aggregate` only +} + +REQUIRES_PYARROW_20: tuple[Literal["kurtosis"], Literal["skew"]] = ("kurtosis", "skew") +"""They don't show in [our version of the stubs], but are possible in [`pyarrow>=20`]. 
+ +[our version of the stubs]: https://github.com/narwhals-dev/narwhals/issues/2124#issuecomment-3191374210 +[`pyarrow>=20`]: https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations +""" + + +class AggSpec: + __slots__ = ("agg", "name", "option", "target") + + def __init__( + self, + target: acero.Target, + agg: acero.Aggregation, + option: acero.Opts = None, + name: acero.OutputName = "", + ) -> None: + self.target = target + self.agg = agg + self.option = option + self.name = name or str(target) + + @property + def use_threads(self) -> bool: + """See https://github.com/apache/arrow/issues/36709.""" + return acero.can_thread(self.agg) + + def __iter__(self) -> Iterator[acero.Target | acero.Aggregation | acero.Opts]: + """Let's us duck-type as a 4-tuple.""" + yield from (self.target, self.agg, self.option, self.name) + + @classmethod + def from_named_ir(cls, named_ir: NamedIR) -> Self: + return cls.from_expr_ir(named_ir.expr, named_ir.name) + + @classmethod + def from_agg_expr(cls, expr: agg.AggExpr, name: acero.OutputName) -> Self: + tp = type(expr) + if not (agg_name := SUPPORTED_AGG.get(tp)): + raise group_by_error(name, expr) + if not isinstance(expr.expr, ir.Column): + raise group_by_error(name, expr, "too complex") + option = ( + options.variance(expr.ddof) + if isinstance(expr, (agg.Std, agg.Var)) + else options.AGG.get(tp) + ) + return cls(expr.expr.name, agg_name, option, name) + + @classmethod + def from_function_expr(cls, expr: ir.FunctionExpr, name: acero.OutputName) -> Self: + tp = type(expr.function) + if not (fn_name := SUPPORTED_FUNCTION.get(tp)): + raise group_by_error(name, expr) + args = expr.input + if not (len(args) == 1 and isinstance(args[0], ir.Column)): + raise group_by_error(name, expr, "too complex") + return cls(args[0].name, fn_name, options.FUNCTION.get(tp), name) + + @classmethod + def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: + if is_agg_expr(expr): + return cls.from_agg_expr(expr, name) + if is_function_expr(expr): + return cls.from_function_expr(expr, name) + if not isinstance(expr, (ir.Len, ir.Column)): + raise group_by_error(name, expr) + fn_name = SUPPORTED_IR[type(expr)] + return cls(expr.name if isinstance(expr, ir.Column) else (), fn_name, name=name) + + +def group_by_error( + column_name: str, expr: ir.ExprIR, reason: Literal["too complex"] | None = None +) -> InvalidOperationError: + backend = Implementation.PYARROW + if reason == "too complex": + msg = "Non-trivial complex aggregation found, which" + else: + if is_function_expr(expr): + func_name = repr(expr.function) + else: + func_name = dispatch_method_name(type(expr)) + msg = f"`{func_name}()`" + msg = f"{msg} is not supported in a `group_by` context for {backend!r}:\n{column_name}={expr!r}" + return InvalidOperationError(msg) + + +def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray: + dtype = fn.string_type(native.schema.types) + it = fn.cast_table(native, dtype).itercolumns() + concat: Incomplete = pc.binary_join_element_wise + join = options.join_replace_nulls() + return concat(*it, fn.lit(separator, dtype), options=join) # type: ignore[no-any-return] + + +class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): + _df: Frame + _keys: Seq[NamedIR] + _key_names: Seq[str] + _key_names_original: Seq[str] + + @property + def compliant(self) -> Frame: + return self._df + + def __iter__(self) -> Iterator[tuple[Any, Frame]]: + temp_name = temp.column_name(self.compliant) + native = self.compliant.native + composite_values = 
concat_str(acero.select_names_table(native, self.key_names)) + re_keyed = native.add_column(0, temp_name, composite_values) + from_native = self.compliant._with_native + for v in composite_values.unique(): + t = from_native(acero.filter_table(re_keyed, pc.field(temp_name) == v)) + yield ( + t.select_names(*self.key_names).row(0), + t.select_names(*self._column_names_original), + ) + + def agg(self, irs: Seq[NamedIR]) -> Frame: + compliant = self.compliant + native = compliant.native + key_names = self.key_names + specs = (AggSpec.from_named_ir(e) for e in irs) + result = compliant._with_native(acero.group_by_table(native, key_names, specs)) + if original := self._key_names_original: + return result.rename(dict(zip(key_names, original))) + return result diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py new file mode 100644 index 0000000000..8998b288a2 --- /dev/null +++ b/narwhals/_plan/arrow/options.py @@ -0,0 +1,105 @@ +"""Cached `pyarrow.compute` options classes, using `polars` defaults. + +Important: + `AGG` and `FUNCTION` mappings are constructed on first `__getattr__` access. +""" + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Literal + +import pyarrow.compute as pc # ignore-banned-import + +if TYPE_CHECKING: + from collections.abc import Mapping + + from narwhals._plan import expressions as ir + from narwhals._plan.arrow import acero + from narwhals._plan.expressions import aggregation as agg + + +__all__ = [ + "AGG", + "FUNCTION", + "count", + "join", + "join_replace_nulls", + "scalar_aggregate", + "variance", +] + + +AGG: Mapping[type[agg.AggExpr], acero.AggregateOptions] +FUNCTION: Mapping[type[ir.Function], acero.AggregateOptions] + + +@functools.cache +def count( + mode: Literal["only_valid", "only_null", "all"] = "only_valid", +) -> pc.CountOptions: + return pc.CountOptions(mode) + + +# pyarrow defaults to ignore_nulls +# polars doesn't mention +@functools.cache +def variance( + ddof: int = 1, *, ignore_nulls: bool = True, min_count: int = 0 +) -> pc.VarianceOptions: + return pc.VarianceOptions(ddof=ddof, skip_nulls=ignore_nulls, min_count=min_count) + + +@functools.cache +def scalar_aggregate( + *, ignore_nulls: bool = False, min_count: int = 0 +) -> pc.ScalarAggregateOptions: + return pc.ScalarAggregateOptions(skip_nulls=ignore_nulls, min_count=min_count) + + +@functools.cache +def join(*, ignore_nulls: bool = False) -> pc.JoinOptions: + return pc.JoinOptions(null_handling="skip" if ignore_nulls else "emit_null") + + +@functools.cache +def join_replace_nulls(*, replacement: str = "__nw_null_value__") -> pc.JoinOptions: + return pc.JoinOptions(null_handling="replace", null_replacement=replacement) + + +def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]: + from narwhals._plan.expressions import aggregation as agg + + return { + agg.NUnique: count("all"), + agg.Len: count("all"), + agg.Count: count("only_valid"), + agg.First: scalar_aggregate(), + agg.Last: scalar_aggregate(), + } + + +def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]: + from narwhals._plan.expressions import boolean + + return { + boolean.All: scalar_aggregate(ignore_nulls=True), + boolean.Any: scalar_aggregate(ignore_nulls=True), + } + + +# ruff: noqa: PLW0603 +# NOTE: Using globals for lazy-loading cache +if not TYPE_CHECKING: + + def __getattr__(name: str) -> Any: + if name == "AGG": + global AGG + AGG = _generate_agg() + return AGG + if name == "FUNCTION": + global FUNCTION + 
FUNCTION = _generate_function() + return FUNCTION + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index e633e6560e..e11e9d45c1 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Protocol, overload +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload from narwhals._typing_compat import TypeVar from narwhals._utils import _StoresNative as StoresNative @@ -14,8 +14,8 @@ Int16Type, Int32Type, Int64Type, - LargeStringType, - StringType, + LargeStringType as LargeStringType, # noqa: PLC0414 + StringType as StringType, # noqa: PLC0414 Uint8Type, Uint16Type, Uint32Type, @@ -117,3 +117,5 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] StoresNativeT_co = TypeVar("StoresNativeT_co", bound=StoresNative[Any], covariant=True) +DataTypeRemap: TypeAlias = Mapping[DataType, DataType] +NullPlacement: TypeAlias = Literal["at_start", "at_end"] diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 0b4267f214..defe398f95 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -6,15 +6,21 @@ from collections.abc import Iterable from decimal import Decimal from operator import attrgetter +from secrets import token_hex from typing import TYPE_CHECKING, cast, overload from narwhals._plan._guards import is_iterable_reject +from narwhals._utils import _hasattr_static from narwhals.dtypes import DType +from narwhals.exceptions import NarwhalsError from narwhals.utils import Version if TYPE_CHECKING: + import reprlib from collections.abc import Iterator - from typing import Any, Callable, TypeVar + from typing import Any, Callable, ClassVar, TypeVar + + from typing_extensions import TypeIs from narwhals._plan.typing import ( DTypeT, @@ -23,6 +29,7 @@ NonNestedDTypeT, OneOrIterable, ) + from narwhals._utils import _StoresColumns from narwhals.typing import NonNestedDType, NonNestedLiteral T = TypeVar("T") @@ -115,3 +122,161 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: yield from flatten_hash_safe(element) else: yield element # type: ignore[misc] + + +def _has_columns(obj: Any) -> TypeIs[_StoresColumns]: + return _hasattr_static(obj, "columns") + + +def _reprlib_repr_backport() -> reprlib.Repr: + # 3.12 added `indent` https://github.com/python/cpython/issues/92734 + # but also a useful constructor https://github.com/python/cpython/issues/94343 + import reprlib + + if sys.version_info >= (3, 12): + return reprlib.Repr(indent=4, maxlist=10) + else: # pragma: no cover # noqa: RET505 + obj = reprlib.Repr() + obj.maxlist = 10 + return obj + + +class temp: # noqa: N801 + """Temporary mini namespace for temporary utils.""" + + _MAX_ITERATIONS: ClassVar[int] = 100 + _MIN_RANDOM_CHARS: ClassVar[int] = 4 + + @classmethod + def column_name( + cls, + source: _StoresColumns | Iterable[str], + /, + *, + prefix: str = "nw", + n_chars: int = 16, + ) -> str: + """Generate a single, unique column name that is not present in `source`. + + Arguments: + source: Source of columns to check for uniqueness. + prefix: Prepend the name with this string. 
+ n_chars: Total number of characters used by the name (including `prefix`). + + Examples: + >>> import narwhals as nw + >>> from narwhals._plan.common import temp + >>> columns = "abc", "xyz" + >>> temp.column_name(columns) # doctest: +SKIP + 'nwf65daf7ceb3c2f' + + Limit the number of characters that the name uses + + >>> temp.column_name(columns, n_chars=8) # doctest: +SKIP + 'nw388b5d' + + Make the name easier to trace back + + >>> temp.column_name(columns, prefix="_its_a_me_") # doctest: +SKIP + '_its_a_me_0ea2b0' + + Pass in a `DataFrame` directly, and let us get the columns for you + + >>> df = nw.from_dict({"foo": [1, 2], "bar": [6.0, 7.0]}, backend="polars") + >>> df.with_row_index(temp.column_name(df, prefix="idx_")) # doctest: +SKIP + ┌────────────────────────────────┐ + | Narwhals DataFrame | + |--------------------------------| + |shape: (2, 3) | + |┌──────────────────┬─────┬─────┐| + |│ idx_bae5e1b22963 ┆ foo ┆ bar │| + |│ --- ┆ --- ┆ --- │| + |│ u32 ┆ i64 ┆ f64 │| + |╞══════════════════╪═════╪═════╡| + |│ 0 ┆ 1 ┆ 6.0 │| + |│ 1 ┆ 2 ┆ 7.0 │| + |└──────────────────┴─────┴─────┘| + └────────────────────────────────┘ + """ + columns = cls._into_columns(source) + prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars) + for _ in range(cls._MAX_ITERATIONS): + token = f"{prefix}{token_hex(n_bytes)}" + if token not in columns: + return token + raise cls._failed_generation_error(columns, n_chars) + + # TODO @dangotbanned: Write examples + @classmethod + def column_names( + cls, + source: _StoresColumns | Iterable[str], + /, + *, + prefix: str = "nw", + n_chars: int = 16, + ) -> Iterator[str]: + """Yields unique column names that are not present in `source`. + + Any column name returned will be unique among those that preceded it. + + Arguments: + source: Source of columns to check for uniqueness. + prefix: Prepend the name with this string. + n_chars: Total number of characters used by the name (including `prefix`). + """ + columns = cls._into_columns(source) + prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars) + n_failed: int = 0 + while n_failed <= cls._MAX_ITERATIONS: + token = f"{prefix}{token_hex(n_bytes)}" + if token not in columns: + columns.add(token) + n_failed = 0 + yield token + else: + n_failed += 1 + raise cls._failed_generation_error(columns, n_chars) + + @staticmethod + def _into_columns(source: _StoresColumns | Iterable[str], /) -> set[str]: + return set(source.columns if _has_columns(source) else source) + + @classmethod + def _parse_prefix_n_bytes(cls, prefix: str, n_chars: int, /) -> tuple[str, int]: + prefix = prefix or "nw" + if not (available := n_chars - len(prefix)) or available < cls._MIN_RANDOM_CHARS: + raise cls._not_enough_room_error(prefix, n_chars) + return prefix, available // 2 + + @classmethod + def _not_enough_room_error(cls, prefix: str, n_chars: int, /) -> NarwhalsError: + len_prefix = len(prefix) + available_chars = n_chars - len_prefix + if available_chars < 0: + visualize = "" + else: + okay = "✔" * available_chars + bad = "✖" * (cls._MIN_RANDOM_CHARS - available_chars) + visualize = f"\n Preview: '{prefix}{okay}{bad}'" + msg = ( + f"Temporary column name generation requires {len_prefix} characters for the prefix " + f"and at least {cls._MIN_RANDOM_CHARS} more to store random bytes:{visualize}\n\n" + f"Hint: Maybe try\n" + f"- a shorter `prefix` than {prefix!r}?\n" + f"- a higher `n_chars` than {n_chars!r}?" 
+ ) + return NarwhalsError(msg) + + @classmethod + def _failed_generation_error( + cls, columns: Iterable[str], n_chars: int, / + ) -> NarwhalsError: + current = sorted(columns) + truncated = _reprlib_repr_backport().repr(current) + msg = ( + "Was unable to generate a column name with " + f"`{n_chars=}` within {cls._MAX_ITERATIONS} iterations, \n" + f"that was not present in existing ({len(current)}) columns:\n{truncated}" + ) + return NarwhalsError(msg) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 8f06f1e5c9..8956c33457 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload -from narwhals._plan import _expansion, _parse -from narwhals._plan.contexts import ExprContext +from narwhals._plan import _parse +from narwhals._plan._expansion import prepare_projection from narwhals._plan.expr import _parse_sort_by +from narwhals._plan.group_by import GroupBy, Grouped from narwhals._plan.series import Series from narwhals._plan.typing import ( IntoExpr, @@ -18,13 +19,12 @@ from narwhals.schema import Schema if TYPE_CHECKING: + from collections.abc import Sequence + import pyarrow as pa from typing_extensions import Self - from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame - from narwhals._plan.schema import FrozenSchema - from narwhals._plan.typing import Seq from narwhals.typing import NativeFrame @@ -60,27 +60,19 @@ def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> def to_native(self) -> NativeFrameT: return self._compliant.native - def _project( - self, - exprs: tuple[OneOrIterable[IntoExpr], ...], - named_exprs: dict[str, Any], - context: ExprContext, - /, - ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: - """Temp, while these parts aren't connected, this is easier for testing.""" - irs, schema_frozen, output_names = _expansion.prepare_projection( - _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema - ) - named_irs = _expansion.into_named_irs(irs, output_names) - return schema_frozen.project(named_irs, context) - def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, _ = self._project(exprs, named_exprs, ExprContext.SELECT) - return self._from_compliant(self._compliant.select(named_irs)) + named_irs, schema = prepare_projection( + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self + ) + return self._from_compliant(self._compliant.select(schema.select_irs(named_irs))) def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, _ = self._project(exprs, named_exprs, ExprContext.WITH_COLUMNS) - return self._from_compliant(self._compliant.with_columns(named_irs)) + named_irs, schema = prepare_projection( + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self + ) + return self._from_compliant( + self._compliant.with_columns(schema.with_columns_irs(named_irs)) + ) def sort( self, @@ -92,10 +84,16 @@ def sort( sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last ) - irs, _, output_names = _expansion.prepare_projection(sort, self.schema) - named_irs = _expansion.into_named_irs(irs, output_names) + named_irs, _ = prepare_projection(sort, schema=self) return self._from_compliant(self._compliant.sort(named_irs, opts)) + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: + 
return self._from_compliant(self._compliant.drop(columns, strict=strict)) + + def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: + subset = [subset] if isinstance(subset, str) else subset + return self._from_compliant(self._compliant.drop_nulls(subset)) + class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] @@ -138,3 +136,29 @@ def to_dict( def __len__(self) -> int: return len(self._compliant) + + @overload + def group_by( + self, + *by: OneOrIterable[IntoExpr], + drop_null_keys: Literal[False] = ..., + **named_by: IntoExpr, + ) -> GroupBy[Self]: ... + + @overload + def group_by( + self, *by: OneOrIterable[str], drop_null_keys: Literal[True] + ) -> GroupBy[Self]: ... + + def group_by( + self, + *by: OneOrIterable[IntoExpr], + drop_null_keys: bool = False, + **named_by: IntoExpr, + ) -> GroupBy[Self]: + return Grouped.by(*by, drop_null_keys=drop_null_keys, **named_by).to_group_by( + self + ) + + def row(self, index: int) -> tuple[Any, ...]: + return self._compliant.row(index) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index b0f369bd77..7695c1d92f 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -103,6 +103,9 @@ def exclude(self, *names: OneOrIterable[str]) -> Self: def count(self) -> Self: return self._from_ir(agg.Count(expr=self._ir)) + def len(self) -> Self: + return self._from_ir(agg.Len(expr=self._ir)) + def max(self) -> Self: return self._from_ir(agg.Max(expr=self._ir)) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 237ee36e81..4444bbd6be 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -5,6 +5,7 @@ NamedIR, SelectorIR, ) +from narwhals._plan._function import Function from narwhals._plan.expressions import ( aggregation, boolean, @@ -60,6 +61,7 @@ "Exclude", "ExprIR", "Filter", + "Function", "FunctionExpr", "IndexColumns", "InvertSelector", diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 263ca300e5..0f26a82c10 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -33,7 +33,10 @@ def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: # fmt: off -class Count(AggExpr): ... +class Count(AggExpr): + """Non-null count.""" +class Len(AggExpr): + """Null-inclusive count.""" class Max(AggExpr): ... class Mean(AggExpr): ... class Median(AggExpr): ... 
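A minimal sketch of the `Count` vs `Len` distinction documented just above, using only the public `pyarrow.compute.count` API and matching the `mode="all"` used by the `len` visitor in `narwhals/_plan/arrow/expr.py` earlier in this patch:

    import pyarrow as pa
    import pyarrow.compute as pc

    arr = pa.chunked_array([[1, None, 2, 3]])
    pc.count(arr, mode="only_valid")  # <pyarrow.Int64Scalar: 3> -> `Count`, nulls excluded
    pc.count(arr, mode="all")         # <pyarrow.Int64Scalar: 4> -> `Len`, null-inclusive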
diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py new file mode 100644 index 0000000000..5e95bd484e --- /dev/null +++ b/narwhals/_plan/group_by.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic + +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.protocols import GroupByResolver as Resolved, Grouper +from narwhals._plan.typing import DataFrameT + +if TYPE_CHECKING: + from collections.abc import Iterator + + from typing_extensions import Self + + from narwhals._plan.expressions import ExprIR + from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq + + +class GroupBy(Generic[DataFrameT]): + _frame: DataFrameT + _grouper: Grouped + + def __init__(self, frame: DataFrameT, grouper: Grouped, /) -> None: + self._frame = frame + self._grouper = grouper + + def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: + frame = self._frame + return frame._from_compliant( + self._grouper.agg(*aggs, **named_aggs) + .resolve(frame) + .evaluate(frame._compliant) + ) + + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: + frame = self._frame + resolver = self._grouper.agg().resolve(frame) + for key, df in frame._compliant.group_by_resolver(resolver): + yield key, frame._from_compliant(df) + + +class Grouped(Grouper["Resolved"]): + """Narwhals-level `GroupBy` builder.""" + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool + + @classmethod + def by( + cls, + *by: OneOrIterable[IntoExpr], + drop_null_keys: bool = False, + **named_by: IntoExpr, + ) -> Self: + obj = cls.__new__(cls) + obj._keys = parse_into_seq_of_expr_ir(*by, **named_by) + obj._drop_null_keys = drop_null_keys + return obj + + def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> Self: + self._aggs = parse_into_seq_of_expr_ir(*aggs, **named_aggs) + return self + + @property + def _resolver(self) -> type[Resolved]: + return Resolved + + def to_group_by(self, frame: DataFrameT, /) -> GroupBy[DataFrameT]: + return GroupBy(frame, self) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 6f77674dff..303d07a097 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -9,10 +9,12 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence + import pyarrow.acero import pyarrow.compute as pc from typing_extensions import Self, TypeAlias - from narwhals._plan.typing import Accessor, OneOrIterable, Seq + from narwhals._plan.arrow.typing import NullPlacement + from narwhals._plan.typing import Accessor, OneOrIterable, Order, Seq from narwhals.typing import RankMethod DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] @@ -170,7 +172,7 @@ def to_multiple(self, n_repeat: int = 1, /) -> SortMultipleOptions: nulls: Seq[bool] = (self.nulls_last,) else: desc = tuple(repeat(self.descending, n_repeat)) - nulls = tuple(repeat(self.nulls_last)) + nulls = tuple(repeat(self.nulls_last, n_repeat)) return SortMultipleOptions(descending=desc, nulls_last=nulls) @@ -193,9 +195,9 @@ def parse( nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) return SortMultipleOptions(descending=desc, nulls_last=nulls) - def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: - import pyarrow.compute as pc - + def _to_arrow_args( + self, by: Sequence[str] + ) -> tuple[Sequence[tuple[str, Order]], NullPlacement]: first = self.nulls_last[0] if len(self.nulls_last) != 1 and any(x != first for x in self.nulls_last[1:]): msg = 
f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" @@ -204,12 +206,23 @@ def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: descending: Iterable[bool] = repeat(self.descending[0], len(by)) else: descending = self.descending - sorting: list[tuple[str, Literal["ascending", "descending"]]] = [ + sorting = tuple[tuple[str, "Order"]]( (key, "descending" if desc else "ascending") for key, desc in zip(by, descending) - ] - placement: Literal["at_start", "at_end"] = "at_end" if first else "at_start" - return pc.SortOptions(sort_keys=sorting, null_placement=placement) + ) + return sorting, "at_end" if first else "at_start" + + def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: + import pyarrow.compute as pc + + sort_keys, placement = self._to_arrow_args(by) + return pc.SortOptions(sort_keys=sort_keys, null_placement=placement) + + def to_arrow_acero(self, by: Sequence[str]) -> pyarrow.acero.Declaration: + from narwhals._plan.arrow import acero + + sort_keys, placement = self._to_arrow_args(by) + return acero._order_by(sort_keys, null_placement=placement) class RankOptions(Immutable): diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 11a17eb081..cff5e790e8 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,12 +1,24 @@ +"""TODO: Split this module up into `narwhals._plan.compliant.*`.""" + from __future__ import annotations from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from itertools import chain from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan.common import flatten_hash_safe -from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq +from narwhals._plan._expansion import prepare_projection +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.common import flatten_hash_safe, replace, temp +from narwhals._plan.typing import ( + IntoExpr, + NativeDataFrameT, + NativeFrameT, + NativeSeriesT, + Seq, +) from narwhals._typing_compat import TypeVar from narwhals._utils import Version +from narwhals.exceptions import ComputeError if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs @@ -15,6 +27,7 @@ from narwhals._plan.dataframe import BaseFrame, DataFrame from narwhals._plan.expressions import ( BinaryExpr, + ExprIR, FunctionExpr, NamedIR, aggregation as agg, @@ -25,6 +38,7 @@ from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema from narwhals._plan.series import Series from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType @@ -49,6 +63,8 @@ ColumnT = TypeVar("ColumnT") ColumnT_co = TypeVar("ColumnT_co", covariant=True) +ResolverT_co = TypeVar("ResolverT_co", bound="GroupByResolver", covariant=True) + ExprAny: TypeAlias = "CompliantExpr[Any, Any]" ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" SeriesAny: TypeAlias = "CompliantSeries[Any]" @@ -69,7 +85,9 @@ SeriesT = TypeVar("SeriesT", bound=SeriesAny) SeriesT_co = TypeVar("SeriesT_co", bound=SeriesAny, covariant=True) FrameT = TypeVar("FrameT", bound=FrameAny) +FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) +DataFrameT = TypeVar("DataFrameT", bound=DataFrameAny) NamespaceT_co = TypeVar("NamespaceT_co", 
bound="NamespaceAny", covariant=True) EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) @@ -199,6 +217,7 @@ def _with_native(self, native: Any, name: str, /) -> Self: return self.from_native(native, name or self.name, self.version) # series & scalar + def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ... def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... @@ -268,6 +287,9 @@ def quantile( def count( self, node: agg.Count, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def len( + self, node: agg.Len, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... def max( self, node: agg.Max, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... @@ -374,6 +396,10 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" ... + def len(self, node: agg.Len, frame: FrameT_contra, name: str) -> Self: + """Returns 1.""" + ... + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) @@ -531,6 +557,8 @@ class CompliantBaseFrame(StoresVersion, Protocol[ColumnT_co, NativeFrameT]): def __narwhals_namespace__(self) -> Any: ... @property + def _group_by(self) -> type[CompliantGroupBy[Self]]: ... + @property def native(self) -> NativeFrameT: return self._native @@ -553,18 +581,44 @@ def _evaluate_irs( self, nodes: Iterable[NamedIR[ir.ExprIR]], / ) -> Iterator[ColumnT_co]: ... def select(self, irs: Seq[NamedIR]) -> Self: ... + def select_names(self, *column_names: str) -> Self: ... def with_columns(self, irs: Seq[NamedIR]) -> Self: ... def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ... + def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... class CompliantDataFrame( CompliantBaseFrame[SeriesT, NativeDataFrameT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): + @property + def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... + @property + def _grouper(self) -> type[Grouped]: + return Grouped + @classmethod def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... + def group_by_agg( + self, by: OneOrIterable[IntoExpr], aggs: OneOrIterable[IntoExpr], / + ) -> Self: + """Compliant-level `group_by(by).agg(agg)`, allows `Expr`.""" + return self._grouper.by(by).agg(aggs).resolve(self).evaluate(self) + + def group_by_names(self, names: Seq[str], /) -> DataFrameGroupBy[Self]: + """Compliant-level `group_by`, allowing only `str` keys.""" + return self._group_by.by_names(self, names) + + def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Self]: + """Narwhals-level resolved `group_by`. + + `keys`, `aggs` are already parsed and projections planned. + """ + return self._group_by.from_resolver(self, resolver) + def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ... @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... @@ -579,12 +633,15 @@ def to_dict( ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... def __len__(self) -> int: ... def with_row_index(self, name: str) -> Self: ... 
+ def row(self, index: int) -> tuple[Any, ...]: ... class EagerDataFrame( CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): + @property + def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... def select(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) @@ -650,3 +707,185 @@ def __len__(self) -> int: def to_list(self) -> list[Any]: ... def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... + + +class CompliantGroupBy(Protocol[FrameT_co]): + @property + def compliant(self) -> FrameT_co: ... + def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... + + +class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): + _keys: Seq[NamedIR] + _key_names: Seq[str] + + @classmethod + def from_resolver( + cls, df: DataFrameT, resolver: GroupByResolver, / + ) -> DataFrameGroupBy[DataFrameT]: ... + @classmethod + def by_names( + cls, df: DataFrameT, names: Seq[str], / + ) -> DataFrameGroupBy[DataFrameT]: ... + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def key_names(self) -> Seq[str]: + if names := self._key_names: + return names + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + +class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]): + _df: EagerDataFrameT + _key_names: Seq[str] + _key_names_original: Seq[str] + _column_names_original: Seq[str] + + @classmethod + def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: + obj = cls.__new__(cls) + obj._df = df + obj._keys = () + obj._key_names = names + obj._key_names_original = () + obj._column_names_original = tuple(df.columns) + return obj + + @classmethod + def from_resolver( + cls, df: EagerDataFrameT, resolver: GroupByResolver, / + ) -> EagerDataFrameGroupBy[EagerDataFrameT]: + key_names = resolver.key_names + if not resolver.requires_projection(): + df = df.drop_nulls(key_names) if resolver._drop_null_keys else df + return cls.by_names(df, key_names) + obj = cls.__new__(cls) + unique_names = temp.column_names(chain(key_names, df.columns)) + safe_keys = tuple( + replace(key, name=name) for key, name in zip(resolver.keys, unique_names) + ) + obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys)) + obj._keys = safe_keys + obj._key_names = tuple(e.name for e in safe_keys) + obj._key_names_original = key_names + obj._column_names_original = resolver._schema_in.names + return obj + + +class Grouper(Protocol[ResolverT_co]): + """`GroupBy` helper for collecting and forwarding `Expr`s for projection. + + - Uses `Expr` everywhere (no need to duplicate layers) + - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) + """ + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool + + @classmethod + def by(cls, *by: OneOrIterable[IntoExpr]) -> Self: + obj = cls.__new__(cls) + obj._keys = parse_into_seq_of_expr_ir(*by) + return obj + + def agg(self, *aggs: OneOrIterable[IntoExpr]) -> Self: + self._aggs = parse_into_seq_of_expr_ir(*aggs) + return self + + @property + def _resolver(self) -> type[ResolverT_co]: ... 
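# Illustrative builder flow, assuming the names introduced in this diff:
#   resolver = Grouped.by(nwp.col("a")).agg(nwp.col("b").sum()).resolve(frame)
#   result = resolver.evaluate(frame)   # dispatches to `frame.group_by_resolver(resolver)`
# `resolve` only needs a schema (any `IntoFrozenSchema`), so the same builder works at
# both the narwhals and compliant levels.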
+ + def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: + """Project keys and aggs in `context`, expanding all `Expr` -> `NamedIR`.""" + return self._resolver.from_grouper(self, context) + + +class GroupByResolver: + """Narwhals-level `GroupBy` resolver.""" + + _schema_in: FrozenSchema + _keys: Seq[NamedIR] + _aggs: Seq[NamedIR] + _key_names: Seq[str] + _schema: FrozenSchema + _drop_null_keys: bool + + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def aggs(self) -> Seq[NamedIR]: + return self._aggs + + @property + def key_names(self) -> Seq[str]: + if names := self._key_names: + return names + if keys := self.keys: + return tuple(e.name for e in keys) + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + @property + def schema(self) -> FrozenSchema: + return self._schema + + def evaluate(self, frame: DataFrameT) -> DataFrameT: + """Perform the `group_by` on `frame`.""" + return frame.group_by_resolver(self).agg(self.aggs) + + @classmethod + def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: + """Loosely based on [`resolve_group_by`]. + + [`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 + """ + obj = cls.__new__(cls) + keys, schema_in = prepare_projection(grouper._keys, schema=context) + obj._keys, obj._schema_in = keys, schema_in + obj._key_names = tuple(e.name for e in keys) + obj._aggs, _ = prepare_projection(grouper._aggs, obj.key_names, schema=schema_in) + obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs)) + obj._drop_null_keys = grouper._drop_null_keys + return obj + + def requires_projection(self, *, allow_aliasing: bool = False) -> bool: + """Return True is group keys contain anything that is not a column selection. + + Notes: + If False is returned, we can just use the resolved key names as a fast-path to group. + + Arguments: + allow_aliasing: If False (default), any aliasing is not considered to be column selection. 
+ """ + if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys): + if self._drop_null_keys: + msg = "drop_null_keys cannot be True when keys contains Expr or Series" + raise NotImplementedError(msg) + return True + return False + + +class Resolved(GroupByResolver): + """Compliant-level `GroupBy` resolver.""" + + _drop_null_keys: bool = False + + +class Grouped(Grouper[Resolved]): + """Compliant-level `GroupBy` builder.""" + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool = False + + @property + def _resolver(self) -> type[Resolved]: + return Resolved diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 4dbf5e6ef3..67433db06b 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -1,28 +1,28 @@ from __future__ import annotations -from collections import deque from collections.abc import Mapping from functools import lru_cache -from itertools import chain, repeat +from itertools import chain from types import MappingProxyType -from typing import TYPE_CHECKING, Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload from narwhals._plan._expr_ir import NamedIR from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable +from narwhals._utils import _hasattr_static from narwhals.dtypes import Unknown if TYPE_CHECKING: from collections.abc import ItemsView, Iterator, KeysView, ValuesView - from typing_extensions import TypeAlias + from typing_extensions import Never, TypeAlias, TypeIs - from narwhals._plan.contexts import ExprContext from narwhals._plan.typing import Seq from narwhals.dtypes import DType + from narwhals.typing import IntoSchema IntoFrozenSchema: TypeAlias = ( - "Mapping[str, DType] | Iterator[tuple[str, DType]] | FrozenSchema" + "IntoSchema | Iterator[tuple[str, DType]] | FrozenSchema | HasSchema" ) """A schema to freeze, or an already frozen one. @@ -41,16 +41,18 @@ class FrozenSchema(Immutable): __slots__ = ("_mapping",) _mapping: MappingProxyType[str, DType] - def project( - self, exprs: Seq[NamedIR], context: ExprContext - ) -> tuple[Seq[NamedIR], FrozenSchema]: - if context.is_select(): - return exprs, self._select(exprs) - if context.is_with_columns(): - return self._with_columns(exprs) - raise TypeError(context) + def __init_subclass__(cls, *_: Never, **__: Never) -> Never: + msg = f"Cannot subclass {cls.__name__!r}" + raise TypeError(msg) - def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: + def merge(self, other: FrozenSchema, /) -> FrozenSchema: + """Return a new schema, merging `other` with `self` (see [upstream]). + + [upstream]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-schema/src/schema.rs#L265-L274. + """ + return freeze_schema(self._mapping | other._mapping) + + def select(self, exprs: Seq[NamedIR]) -> FrozenSchema: """Return a new schema, equivalent to performing `df.select(*exprs)`. 
Arguments: @@ -64,18 +66,24 @@ def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: default = Unknown() return freeze_schema((name, self.get(name, default)) for name in names) - def _with_columns(self, exprs: Seq[NamedIR]) -> tuple[Seq[NamedIR], FrozenSchema]: - exprs_out = deque[NamedIR]() + def select_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: + return exprs + + def with_columns(self, exprs: Seq[NamedIR]) -> FrozenSchema: + # similar to `merge`, but preserving known `DType`s + names = (e.name for e in exprs) + default = Unknown() + miss = {name: default for name in names if name not in self} + return freeze_schema(self._mapping | miss) + + def with_columns_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: + """Required for `_concat_horizontal`-based `with_columns`. + + Fills in any unreferenced columns present in `self`, but not in `exprs` as selections. + """ named: dict[str, NamedIR[Any]] = {e.name: e for e in exprs} - items: IntoFrozenSchema - for name in self: - exprs_out.append(named.pop(name, NamedIR.from_name(name))) - if named: - items = chain(self.items(), zip(named, repeat(Unknown(), len(named)))) - exprs_out.extend(named.values()) - else: - items = self - return tuple(exprs_out), freeze_schema(items) + it = (named.pop(name, NamedIR.from_name(name)) for name in self) + return tuple(chain(it, named.values())) @property def __immutable_hash__(self) -> int: @@ -92,7 +100,9 @@ def names(self) -> FrozenColumns: @staticmethod def _from_mapping(mapping: MappingProxyType[str, DType], /) -> FrozenSchema: - return FrozenSchema(_mapping=mapping) + obj = FrozenSchema.__new__(FrozenSchema) + object.__setattr__(obj, "_mapping", mapping) + return obj @staticmethod def _from_hash_safe(items: _FrozenSchemaHash, /) -> FrozenSchema: @@ -134,6 +144,15 @@ def __repr__(self) -> str: return f"{type(self).__name__}([{nl}{indent}{items}{sep}{nl}])" +class HasSchema(Protocol): + @property + def schema(self) -> IntoSchema: ... + + +def has_schema(obj: Any) -> TypeIs[HasSchema]: + return _hasattr_static(obj, "schema") + + @overload def freeze_schema(mapping: IntoFrozenSchema, /) -> FrozenSchema: ... @overload @@ -143,7 +162,7 @@ def freeze_schema( ) -> FrozenSchema: if isinstance(iterable, FrozenSchema): return iterable - into = iterable or schema + into = iterable.schema if has_schema(iterable) else (iterable or schema) hashable = tuple(into.items() if isinstance(into, Mapping) else into) return _freeze_schema_cache(hashable) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 0efb81ea81..2a734488a6 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -10,6 +10,7 @@ from narwhals import dtypes from narwhals._plan._expr_ir import ExprIR, NamedIR, SelectorIR from narwhals._plan._function import Function + from narwhals._plan.dataframe import DataFrame from narwhals._plan.expr import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow @@ -25,6 +26,7 @@ ) __all__ = [ + "DataFrameT", "FunctionT", "IntoExpr", "IntoExprColumn", @@ -95,7 +97,7 @@ T = TypeVar("T") -Seq: TypeAlias = "tuple[T,...]" +Seq: TypeAlias = tuple[T, ...] """Immutable Sequence. Using instead of `Sequence`, as a `list` can be passed there (can't break immutability promise). 
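# Illustrative example for `FrozenSchema.with_columns_irs` above: given a schema
# {"a": Int64, "b": Int64} and exprs [NamedIR(name="b", ...), NamedIR(name="c", ...)],
# the result is (NamedIR.from_name("a"), <b expr>, <c expr>): columns not referenced
# become plain selections, updated columns keep their position, and new names are
# appended -- which is what `_concat_horizontal` needs to rebuild the full frame.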
@@ -107,3 +109,6 @@ IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" OneOrIterable: TypeAlias = "T | t.Iterable[T]" +OneOrSeq: TypeAlias = t.Union[T, Seq[T]] +DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") +Order: TypeAlias = t.Literal["ascending", "descending"] diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index ffada70747..7b7113e450 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -398,6 +398,11 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: raises=NotImplementedError, ), ), + pytest.param( + [nwp.col("g").len(), nwp.col("m").last(), nwp.col("h").count()], + {"g": [3], "m": [2], "h": [1]}, + id="len-count-with-nulls", + ), ], ids=_ids_ir, ) @@ -517,6 +522,23 @@ def test_first_last_expr_with_columns( assert_equal_data(result, {"result": expected_broadcast}) +@pytest.mark.parametrize( + ("index", "expected"), [(3, (None, 12, 0.9, 3, 3)), (1, (2, 5, 1.0, 1, 1))] +) +def test_row_is_py_literal( + data_indexed: dict[str, Any], index: int, expected: tuple[PythonLiteral, ...] +) -> None: + frame = nwp.DataFrame.from_native(pa.table(data_indexed)) + result = frame.row(index) + assert all(v is None or isinstance(v, (int, float)) for v in result) + assert result == expected + pytest.importorskip("polars") + import polars as pl + + polars_result = pl.DataFrame(data_indexed).row(index) + assert result == polars_result + + if TYPE_CHECKING: def test_protocol_expr() -> None: diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index a80724ff86..203c39911b 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -16,7 +16,7 @@ from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError -from tests.plan.utils import assert_expr_ir_equal +from tests.plan.utils import assert_expr_ir_equal, named_ir if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -257,19 +257,28 @@ def test_replace_selector( @pytest.mark.parametrize( ("into_exprs", "expected"), [ - ("a", [nwp.col("a")]), - (nwp.col("b", "c", "d"), [nwp.col("b"), nwp.col("c"), nwp.col("d")]), - (nwp.nth(6), [nwp.col("g")]), - (nwp.nth(9, 8, -5), [nwp.col("j"), nwp.col("i"), nwp.col("p")]), - ( + pytest.param("a", [nwp.col("a")], id="Col"), + pytest.param( + nwp.col("b", "c", "d"), + [nwp.col("b"), nwp.col("c"), nwp.col("d")], + id="Columns", + ), + pytest.param(nwp.nth(6), [nwp.col("g")], id="Nth"), + pytest.param( + nwp.nth(9, 8, -5), + [nwp.col("j"), nwp.col("i"), nwp.col("p")], + id="IndexColumns", + ), + pytest.param( [nwp.nth(2).alias("c again"), nwp.nth(-1, -2).name.to_uppercase()], [ - nwp.col("c").alias("c again"), - nwp.col("u").alias("U"), - nwp.col("s").alias("S"), + named_ir("c again", nwp.col("c")), + named_ir("U", nwp.col("u")), + named_ir("S", nwp.col("s")), ], + id="Nth-Alias-IndexColumns-Uppercase", ), - ( + pytest.param( nwp.all(), [ nwp.col("a"), @@ -293,82 +302,89 @@ def test_replace_selector( nwp.col("s"), nwp.col("u"), ], + id="All", ), - ( + pytest.param( (ndcs.numeric() - ndcs.by_dtype(nw.Float32(), nw.Float64())) .cast(nw.Int64) .mean() .name.suffix("_mean"), [ - nwp.col("a").cast(nw.Int64()).mean().alias("a_mean"), - nwp.col("b").cast(nw.Int64()).mean().alias("b_mean"), - nwp.col("c").cast(nw.Int64()).mean().alias("c_mean"), - nwp.col("d").cast(nw.Int64()).mean().alias("d_mean"), - 
nwp.col("e").cast(nw.Int64()).mean().alias("e_mean"), - nwp.col("f").cast(nw.Int64()).mean().alias("f_mean"), - nwp.col("g").cast(nw.Int64()).mean().alias("g_mean"), - nwp.col("h").cast(nw.Int64()).mean().alias("h_mean"), + named_ir("a_mean", nwp.col("a").cast(nw.Int64()).mean()), + named_ir("b_mean", nwp.col("b").cast(nw.Int64()).mean()), + named_ir("c_mean", nwp.col("c").cast(nw.Int64()).mean()), + named_ir("d_mean", nwp.col("d").cast(nw.Int64()).mean()), + named_ir("e_mean", nwp.col("e").cast(nw.Int64()).mean()), + named_ir("f_mean", nwp.col("f").cast(nw.Int64()).mean()), + named_ir("g_mean", nwp.col("g").cast(nw.Int64()).mean()), + named_ir("h_mean", nwp.col("h").cast(nw.Int64()).mean()), ], + id="Selector-SUB-Cast-Mean-Suffix", ), - ( + pytest.param( nwp.col("u").alias("1").alias("2").alias("3").alias("4").name.keep(), - # NOTE: Would be nice to rewrite with less intermediate steps - # but retrieving the root name is enough for now - [nwp.col("u").alias("1").alias("2").alias("3").alias("4").alias("u")], + [named_ir("u", nwp.col("u"))], + id="Alias-Etc-Keep", ), - ( + pytest.param( ( (ndcs.numeric() ^ (ndcs.matches(r"[abcdg]") | ndcs.by_name("i", "f"))) * 100 ).name.suffix("_mult_100"), [ - (nwp.col("e") * nwp.lit(100)).alias("e_mult_100"), - (nwp.col("h") * nwp.lit(100)).alias("h_mult_100"), - (nwp.col("j") * nwp.lit(100)).alias("j_mult_100"), + named_ir("e_mult_100", (nwp.col("e") * nwp.lit(100))), + named_ir("h_mult_100", (nwp.col("h") * nwp.lit(100))), + named_ir("j_mult_100", (nwp.col("j") * nwp.lit(100))), ], + id="Selector-XOR-OR-BinaryExpr-Suffix", ), - ( + pytest.param( ndcs.by_dtype(nw.Duration()) .dt.total_minutes() .name.map(lambda nm: f"total_mins: {nm!r} ?"), - [nwp.col("q").dt.total_minutes().alias("total_mins: 'q' ?")], + [named_ir("total_mins: 'q' ?", nwp.col("q").dt.total_minutes())], + id="ByDType-TotalMins-Name-Map", ), - ( + pytest.param( nwp.col("f", "g") .cast(nw.String) .str.starts_with("1") .all() .name.suffix("_all_starts_with_1"), [ - nwp.col("f") - .cast(nw.String) - .str.starts_with("1") - .all() - .alias("f_all_starts_with_1"), - nwp.col("g") - .cast(nw.String) - .str.starts_with("1") - .all() - .alias("g_all_starts_with_1"), + named_ir( + "f_all_starts_with_1", + nwp.col("f").cast(nw.String).str.starts_with("1").all(), + ), + named_ir( + "g_all_starts_with_1", + nwp.col("g").cast(nw.String).str.starts_with("1").all(), + ), ], + id="Cast-StartsWith-All-Suffix", ), - ( + pytest.param( nwp.col("a", "b") .first() .over("c", "e", order_by="d") .name.suffix("_first_over_part_order_1"), [ - nwp.col("a") - .first() - .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) - .alias("a_first_over_part_order_1"), - nwp.col("b") - .first() - .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) - .alias("b_first_over_part_order_1"), + named_ir( + "a_first_over_part_order_1", + nwp.col("a") + .first() + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]), + ), + named_ir( + "b_first_over_part_order_1", + nwp.col("b") + .first() + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]), + ), ], + id="First-Over-Partitioned-Ordered-Suffix", ), - ( + pytest.param( nwp.exclude(BIG_EXCLUDE), [ nwp.col("c"), @@ -379,42 +395,48 @@ def test_replace_selector( nwp.col("i"), nwp.col("j"), ], + id="Exclude", ), - ( + pytest.param( nwp.exclude(BIG_EXCLUDE).name.suffix("_2"), [ - nwp.col("c").alias("c_2"), - nwp.col("d").alias("d_2"), - nwp.col("f").alias("f_2"), - nwp.col("g").alias("g_2"), - nwp.col("h").alias("h_2"), - nwp.col("i").alias("i_2"), - 
nwp.col("j").alias("j_2"), + named_ir("c_2", nwp.col("c")), + named_ir("d_2", nwp.col("d")), + named_ir("f_2", nwp.col("f")), + named_ir("g_2", nwp.col("g")), + named_ir("h_2", nwp.col("h")), + named_ir("i_2", nwp.col("i")), + named_ir("j_2", nwp.col("j")), ], + id="Exclude-Suffix", ), - ( + pytest.param( nwp.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), [ - nwp.col("c") - .alias("c_min_over_order_by") - .min() - .over(order_by=[nwp.col("k")]) + named_ir( + "c_min_over_order_by", + nwp.col("c").min().over(order_by=[nwp.col("k")]), + ) ], + id="Alias-Min-Over-Order-By-Selector", ), pytest.param( (ndcs.by_name("a", "b", "c") / nwp.col("e").first()) .over("g", "f", order_by="f") .name.prefix("hi_"), [ - (nwp.col("a") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_a"), - (nwp.col("b") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_b"), - (nwp.col("c") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_c"), + named_ir( + "hi_a", + (nwp.col("a") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), + named_ir( + "hi_b", + (nwp.col("b") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), + named_ir( + "hi_c", + (nwp.col("c") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), ], id="Selector-BinaryExpr-Over-Prefix", ), @@ -426,7 +448,7 @@ def test_prepare_projection( schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) - actual, _, _ = prepare_projection(irs_in, schema_1) + actual, _ = prepare_projection(irs_in, schema=schema_1) assert len(actual) == len(expected) for lhs, rhs in zip(actual, expected): assert_expr_ir_equal(lhs, rhs) @@ -451,7 +473,7 @@ def test_prepare_projection_duplicate(expr: nwp.Expr, schema_1: dict[str, DType] irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) pattern = re.compile(r"\.alias\(.dupe.\)") with pytest.raises(DuplicateError, match=pattern): - prepare_projection(irs, schema_1) + prepare_projection(irs, schema=schema_1) @pytest.mark.parametrize( @@ -517,7 +539,7 @@ def test_prepare_projection_column_not_found( pattern = re.compile(rf"not found: {re.escape(repr(missing))}") irs = parse_into_seq_of_expr_ir(into_exprs) with pytest.raises(ColumnNotFoundError, match=pattern): - prepare_projection(irs, schema_1) + prepare_projection(irs, schema=schema_1) @pytest.mark.parametrize( @@ -554,15 +576,15 @@ def test_prepare_projection_horizontal_alias( expr = function(into_exprs) alias_1 = expr.alias("alias(x1)") irs = parse_into_seq_of_expr_ir(alias_1) - out_irs, _, _ = prepare_projection(irs, schema_1) + out_irs, _ = prepare_projection(irs, schema=schema_1) assert len(out_irs) == 1 - assert out_irs[0] == function("a", "b", "c").alias("alias(x1)")._ir + assert out_irs[0] == named_ir("alias(x1)", function("a", "b", "c")) alias_2 = alias_1.alias("alias(x2)") irs = parse_into_seq_of_expr_ir(alias_2) - out_irs, _, _ = prepare_projection(irs, schema_1) + out_irs, _ = prepare_projection(irs, schema=schema_1) assert len(out_irs) == 1 - assert out_irs[0] == function("a", "b", "c").alias("alias(x1)").alias("alias(x2)")._ir + assert out_irs[0] == named_ir("alias(x2)", function("a", "b", "c")) @pytest.mark.parametrize( @@ -574,4 +596,4 @@ def test_prepare_projection_index_error( irs = parse_into_seq_of_expr_ir(into_exprs) pattern = re.compile(r"invalid.+index.+nth", re.DOTALL | re.IGNORECASE) with pytest.raises(ComputeError, match=pattern): - prepare_projection(irs, schema_1) + prepare_projection(irs, schema=schema_1) diff --git 
a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index bf810aa176..455fecd114 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -14,7 +14,7 @@ ) from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError -from tests.plan.utils import assert_expr_ir_equal +from tests.plan.utils import assert_expr_ir_equal, named_ir if TYPE_CHECKING: from narwhals._plan.typing import IntoExpr @@ -79,11 +79,6 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: assert_expr_ir_equal(lhs, rhs) -def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: - """Helper constructor for test compare.""" - return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) - - def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: expected = ( named_ir("a", nwp.col("a")), diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py new file mode 100644 index 0000000000..2b60c118db --- /dev/null +++ b/tests/plan/group_by_test.py @@ -0,0 +1,726 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +from narwhals import _plan as nwp +from narwhals._plan import selectors as npcs +from narwhals.exceptions import InvalidOperationError +from tests.utils import PYARROW_VERSION, assert_equal_data as _assert_equal_data + +pytest.importorskip("pyarrow") + + +import pyarrow as pa + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from narwhals._plan.typing import IntoExpr + + +def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[Any, Any]: + return nwp.DataFrame.from_native(pa.table(data)) + + +def assert_equal_data(result: nwp.DataFrame, expected: Mapping[str, Any]) -> None: + _assert_equal_data(result.to_dict(as_series=False), expected) + + +def test_group_by_iter() -> None: + data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} + df = dataframe(data) + expected_keys: list[tuple[int, ...]] = [(1,), (3,)] + keys = [] + for key, sub_df in df.group_by("a"): + if key == (1,): + expected = {"a": [1, 1], "b": [4, 4], "c": [7.0, 8.0]} + assert_equal_data(sub_df, expected) + assert isinstance(sub_df, nwp.DataFrame) + keys.append(key) + assert sorted(keys) == sorted(expected_keys) + expected_keys = [(1, 4), (3, 6)] + keys = [key for key, _ in df.group_by("a", "b")] + assert sorted(keys) == sorted(expected_keys) + keys = [key for key, _ in df.group_by("a", "b")] + assert sorted(keys) == sorted(expected_keys) + + +def test_group_by_nw_all() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}) + result = df.group_by("a").agg(nwp.all().sum()).sort("a") + expected = {"a": [1, 2], "b": [9, 6], "c": [15, 9]} + assert_equal_data(result, expected) + result = df.group_by("a").agg(nwp.all().sum().name.suffix("_sum")).sort("a") + expected = {"a": [1, 2], "b_sum": [9, 6], "c_sum": [15, 9]} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("attr", "expected"), + [ + ("sum", {"a": [1, 2], "b": [3, 3]}), + ("mean", {"a": [1, 2], "b": [1.5, 3]}), + ("max", {"a": [1, 2], "b": [2, 3]}), + ("min", {"a": [1, 2], "b": [1, 3]}), + ("std", {"a": [1, 2], "b": [0.707107, None]}), + ("var", {"a": [1, 2], "b": [0.5, None]}), + ("len", {"a": [1, 2], "b": [3, 1]}), + ("n_unique", {"a": [1, 2], "b": [3, 1]}), + ("count", {"a": [1, 2], "b": [2, 1]}), + ], +) +def test_group_by_depth_1_agg(attr: str, expected: 
dict[str, list[Any]]) -> None: + data = {"a": [1, 1, 1, 2], "b": [1, None, 2, 3]} + expr = getattr(nwp.col("b"), attr)() + result = dataframe(data).group_by("a").agg(expr).sort("a") + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ( + {"x": [True, True, True, False, False, False]}, + {"all": [True, False, False], "any": [True, True, False]}, + ), + ( + {"x": [True, None, False, None, None, None]}, + {"all": [True, False, True], "any": [True, False, False]}, + ), + ], + ids=["not-nullable", "nullable"], +) +def test_group_by_depth_1_agg_bool_ops( + values: dict[str, list[bool]], expected: dict[str, list[bool]] +) -> None: + data = {"a": [1, 1, 2, 2, 3, 3], **values} + result = ( + dataframe(data) + .group_by("a") + .agg(nwp.col("x").all().alias("all"), nwp.col("x").any().alias("any")) + .sort("a") + ) + assert_equal_data(result, {"a": [1, 2, 3], **expected}) + + +@pytest.mark.parametrize( + ("attr", "ddof"), [("std", 0), ("var", 0), ("std", 2), ("var", 2)] +) +def test_group_by_depth_1_std_var(attr: str, ddof: int) -> None: + data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} + _pow = 0.5 if attr == "std" else 1 + expected = { + "a": [1, 2], + "b": [ + (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow, + (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow, + ], + } + expr = getattr(nwp.col("b"), attr)(ddof=ddof) + result = dataframe(data).group_by("a").agg(expr).sort("a") + assert_equal_data(result, expected) + + +def test_group_by_median() -> None: + data = {"a": [1, 1, 1, 2, 2, 2], "b": [5, 4, 6, 7, 3, 2]} + result = dataframe(data).group_by("a").agg(nwp.col("b").median()).sort("a") + expected = {"a": [1, 2], "b": [5, 3]} + assert_equal_data(result, expected) + + +def test_group_by_n_unique_w_missing() -> None: + data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} + result = ( + dataframe(data) + .group_by("a") + .agg( + nwp.col("b").n_unique(), + c_n_unique=nwp.col("c").n_unique(), + c_n_min=nwp.col("b").min(), + d_n_unique=nwp.col("d").n_unique(), + ) + .sort("a") + ) + expected = { + "a": [1, 2], + "b": [2, 1], + "c_n_unique": [1, 1], + "c_n_min": [4, 5], + "d_n_unique": [1, 1], + } + assert_equal_data(result, expected) + + +def test_group_by_simple_named() -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = ( + df.group_by("a").agg(b_min=nwp.col("b").min(), b_max=nwp.col("b").max()).sort("a") + ) + expected = {"a": [1, 2], "b_min": [4, 6], "b_max": [5, 6]} + assert_equal_data(result, expected) + + +def test_group_by_simple_unnamed() -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = df.group_by("a").agg(nwp.col("b").min(), nwp.col("c").max()).sort("a") + expected = {"a": [1, 2], "b": [4, 6], "c": [7, 1]} + assert_equal_data(result, expected) + + +def test_group_by_multiple_keys() -> None: + data = {"a": [1, 1, 2], "b": [4, 4, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = ( + df.group_by("a", "b") + .agg(c_min=nwp.col("c").min(), c_max=nwp.col("c").max()) + .sort("a") + ) + expected = {"a": [1, 2], "b": [4, 6], "c_min": [2, 1], "c_max": [7, 1]} + assert_equal_data(result, expected) + + +def test_key_with_nulls() -> None: + data = {"b": [4, 5, None], "a": [1, 2, 3]} + result = ( + dataframe(data) + .group_by("b") + .agg(nwp.len(), nwp.col("a").min()) + .sort("a") + .with_columns(nwp.col("b").cast(nw.Float64)) + ) + expected = {"b": [4.0, 5, None], "len": [1, 1, 1], "a": 
[1, 2, 3]} + assert_equal_data(result, expected) + + +def test_key_with_nulls_ignored() -> None: + data = {"b": [4, 5, None], "a": [1, 2, 3]} + result = ( + dataframe(data) + .group_by("b", drop_null_keys=True) + .agg(nwp.len(), nwp.col("a").min()) + .sort("a") + .with_columns(nwp.col("b").cast(nw.Float64)) + ) + expected = {"b": [4.0, 5], "len": [1, 1], "a": [1, 2]} + assert_equal_data(result, expected) + + +def test_key_with_nulls_iter() -> None: + data = { + "b": [None, "4", "5", None, "7"], + "a": [None, 1, 2, 3, 4], + "c": [None, "4", "3", None, None], + } + result = dict(dataframe(data).group_by("b", "c", drop_null_keys=True).__iter__()) + + assert len(result) == 2 + assert_equal_data(result[("4", "4")], {"b": ["4"], "a": [1], "c": ["4"]}) + assert_equal_data(result[("5", "3")], {"b": ["5"], "a": [2], "c": ["3"]}) + + result = dict(dataframe(data).group_by("b", "c", drop_null_keys=False).__iter__()) + assert_equal_data(result[("4", "4")], {"b": ["4"], "a": [1], "c": ["4"]}) + assert_equal_data(result[("5", "3")], {"b": ["5"], "a": [2], "c": ["3"]}) + assert len(result) == 4 + + +def test_group_by_expr_iter() -> None: + data = { + "b": [None, "4", "5", None, "7"], + "a": [None, 1, 2, 3, 4], + "c": ["1", "4", "3", "1", "1"], + } + + expected = { + ("1",): {"b": [None, None, "7"], "a": [None, 3, 4], "c": ["1", "1", "1"]}, + ("3",): {"b": ["5"], "a": [2], "c": ["3"]}, + ("4",): {"b": ["4"], "a": [1], "c": ["4"]}, + } + grouped = dataframe(data).group_by(nwp.col("c").alias("d")) + result = dict(sorted((k, df.sort("c").to_dict(as_series=False)) for k, df in grouped)) + assert len(result) == len(expected) + assert result.keys() == expected.keys() + # NOTE: The bug this is trying to avoid regressing on would break zipping, as one side has more columns + result_p1 = next(iter(result.values())) + expected_p1 = next(iter(expected.values())) + assert result_p1 == expected_p1 + _assert_equal_data(result, expected) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "keys", [[nwp.col("a").abs()], ["a", nwp.col("a").abs().alias("a_test")]] +) +def test_group_by_raise_drop_null_keys_with_exprs(keys: list[nwp.Expr | str]) -> None: + data = {"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 4], "y": [0.5, -0.5, 1.0, -1.0, 1.5]} + df = dataframe(data) + with pytest.raises( + NotImplementedError, match="drop_null_keys cannot be True when keys contains Expr" + ): + df.group_by(*keys, drop_null_keys=True).agg(nwp.sum("y")) # type: ignore[call-overload] + + +def test_no_agg() -> None: + data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} + result = dataframe(data).group_by(["a", "b"]).agg().sort("a", "b") + expected = {"a": [1, 3], "b": [4, 6]} + assert_equal_data(result, expected) + + +@pytest.mark.xfail( + PYARROW_VERSION < (15,), + reason=( + "The defaults for grouping by categories in pandas are different.\n\n" + "https://github.com/narwhals-dev/narwhals/issues/1078" + ), +) +def test_group_by_categorical() -> None: + data = {"g1": ["a", "a", "b", "b"], "g2": ["x", "y", "x", "z"], "x": [1, 2, 3, 4]} + df = dataframe(data) + result = ( + df.with_columns( + g1=nwp.col("g1").cast(nw.Categorical()), + g2=nwp.col("g2").cast(nw.Categorical()), + ) + .group_by(["g1", "g2"]) + .agg(nwp.col("x").sum()) + .sort("x") + ) + assert_equal_data(result, data) + + +@pytest.mark.parametrize( + ("agg", "message_body", "expected_repr"), + [ + (nwp.col("a").shift(1), r"shift.+not.+group_by.+pyarrow.+", "col('a').shift("), + ( + nwp.col("a").arg_max(), + r"arg_max.+not.+group_by.+pyarrow.+", + "col('a').arg_max(", + 
), + ( + nwp.col("a").max().over("b"), + r"over.+not.+group_by.+pyarrow.+", + "col('a').max().over([col('b')])", + ), + ( + nwp.col("a").drop_nulls().abs().mean(), + r"complex aggregation found.+not.+group_by.+pyarrow.+", + "col('a').drop_nulls().abs().mean()", + ), + ], +) +def test_group_by_unsupported_raises( + agg: nwp.Expr, message_body: str, expected_repr: str +) -> None: + df = dataframe({"a": [1, 2, 3], "b": [1, 1, 2]}) + pat = re.compile(rf"{message_body}{re.escape(expected_repr)}", re.DOTALL) + with pytest.raises(InvalidOperationError, match=pat): + df.group_by("b").agg(agg) + + +def test_double_same_aggregation() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + result = df.group_by("a").agg(c=nwp.col("b").mean(), d=nwp.col("b").mean()).sort("a") + expected = {"a": [1, 2], "c": [4.5, 6], "d": [4.5, 6]} + assert_equal_data(result, expected) + + +def test_all_kind_of_aggs() -> None: + df = dataframe({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}) + result = ( + df.group_by("a") + .agg( + c=nwp.col("b").mean(), + d=nwp.col("b").mean(), + e=nwp.col("b").std(ddof=1), + f=nwp.col("b").std(ddof=2), + g=nwp.col("b").var(ddof=2), + h=nwp.col("b").var(ddof=2), + i=nwp.col("b").n_unique(), + ) + .sort("a") + ) + + variance_num = sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) + expected = { + "a": [1, 2], + "c": [5, 10 / 3], + "d": [5, 10 / 3], + "e": [1, (variance_num / (3 - 1)) ** 0.5], + "f": [2**0.5, (variance_num) ** 0.5], # denominator is 1 (=3-2) + "g": [2.0, variance_num], # denominator is 1 (=3-2) + "h": [2.0, variance_num], # denominator is 1 (=3-2) + "i": [3, 2], + } + assert_equal_data(result, expected) + + +def test_fancy_functions() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + result = df.group_by("a").agg(nwp.all().std(ddof=0)).sort("a") + expected = {"a": [1, 2], "b": [0.5, 0.0]} + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.numeric().std(ddof=0)).sort("a") + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.matches("b").std(ddof=0)).sort("a") + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.matches("b").std(ddof=0).alias("c")).sort("a") + expected = {"a": [1, 2], "c": [0.5, 0.0]} + assert_equal_data(result, expected) + result = ( + df.group_by("a") + .agg(npcs.matches("b").std(ddof=0).name.map(lambda _x: "c")) + .sort("a") + ) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "sort_by"), + [ + ( + [nwp.col("a").abs(), nwp.col("a").abs().alias("a_with_alias")], + [nwp.col("x").sum()], + {"a": [1, 2], "a_with_alias": [1, 2], "x": [5, 5]}, + ["a"], + ), + ( + [nwp.col("a").alias("x")], + [nwp.col("x").mean().alias("y")], + {"x": [-1, 1, 2], "y": [4.0, 0.5, 2.5]}, + ["x"], + ), + ( + [nwp.col("a")], + [nwp.col("a").count().alias("foo-bar"), nwp.all().sum()], + {"a": [-1, 1, 2], "foo-bar": [1, 2, 2], "x": [4, 1, 5], "y": [1.5, 0, 0]}, + ["a"], + ), + ( + [nwp.col("a", "y").abs()], + [nwp.col("x").sum()], + {"a": [1, 1, 2], "y": [0.5, 1.5, 1], "x": [1, 4, 5]}, + ["a", "y"], + ), + ( + [nwp.col("a").abs().alias("y")], + [nwp.all().sum().name.suffix("c")], + {"y": [1, 2], "ac": [1, 4], "xc": [5, 5]}, + ["y"], + ), + ( + [npcs.by_dtype(nw.Float64()).abs()], + [npcs.numeric().sum()], + {"y": [0.5, 1.0, 1.5], "a": [2, 4, -1], "x": [1, 5, 4]}, + ["y"], + ), + ], +) +def test_group_by_expr( + keys: list[nwp.Expr], + aggs: list[nwp.Expr], + expected: dict[str, list[Any]], + sort_by: list[str], +) -> None: + data = {"a": [1, 1, 2, 2, 
-1], "x": [0, 1, 2, 3, 4], "y": [0.5, -0.5, 1.0, -1.0, 1.5]} + df = dataframe(data) + result = df.group_by(*keys).agg(*aggs).sort(*sort_by) + assert_equal_data(result, expected) + + +def test_group_by_expr_2757684799() -> None: + """From [narwhals-dev/narwhals#2325-2757684799]. + + The **incorrect** result is: + + {'b': [2, 1], 'a': [2, 1], 'c': [2.0, 1.0]} + + [narwhals-dev/narwhals#2325-2757684799]: https://github.com/narwhals-dev/narwhals/pull/2325#pullrequestreview-2757684799 + """ + data: dict[str, Any] = {"a": [1, 1, 2], "b": [4, 5, 6], "unrelated": [10, -1, -9]} + df = dataframe(data) + keys = nwp.col("a").alias("b"), "a" + aggs = nwp.col("b").mean().alias("c") + expected = {"b": [2, 1], "a": [2, 1], "c": [6.0, 4.5]} + + result = df.group_by(keys).agg(aggs).sort("b", descending=True) + assert_equal_data(result, expected) + + +def test_group_by_selector() -> None: + data = { + "a": [1, 1, 1], + "b": [4, 4, 6], + "c": ["foo", "foo", "bar"], + "x": [7.5, 8.5, 9.0], + } + result = ( + dataframe(data) + .group_by(npcs.by_dtype(nw.Int64), "c") + .agg(nwp.col("x").mean()) + .sort("a", "b") + ) + expected = {"a": [1, 1], "b": [4, 6], "c": ["foo", "bar"], "x": [8.0, 9.0]} + assert_equal_data(result, expected) + + +def test_renaming_edge_case() -> None: + data = {"a": [0, 0, 0], "_a_tmp": [1, 2, 3], "b": [4, 5, 6]} + result = dataframe(data).group_by(nwp.col("a")).agg(nwp.all().min()) + expected = {"a": [0], "_a_tmp": [1], "b": [4]} + assert_equal_data(result, expected) + + +def test_group_by_len_1_column() -> None: + """Based on a failure from marimo. + + - https://github.com/marimo-team/marimo/blob/036fd3ff89ef3a0e598bebb166637028024f98bc/tests/_plugins/ui/_impl/tables/test_narwhals.py#L1098-L1108 + - https://github.com/marimo-team/marimo/blob/036fd3ff89ef3a0e598bebb166637028024f98bc/marimo/_plugins/ui/_impl/tables/narwhals_table.py#L163-L188 + """ + data = {"a": [1, 2, 1, 2, 3, 4]} + expected = {"a": [1, 2, 3, 4], "len": [2, 2, 1, 1], "len_a": [2, 2, 1, 1]} + result = ( + dataframe(data).group_by("a").agg(nwp.len(), nwp.len().alias("len_a")).sort("a") + ) + assert_equal_data(result, expected) + + +def test_top_level_len() -> None: + # https://github.com/holoviz/holoviews/pull/6567#issuecomment-3178743331 + df = dataframe({"gender": ["m", "f", "f"], "weight": [4, 5, 6], "age": [None, 8, 9]}) + result = df.group_by(["gender"]).agg(nwp.all().len()).sort("gender") + expected = {"gender": ["f", "m"], "weight": [2, 1], "age": [2, 1]} + assert_equal_data(result, expected) + result = ( + df.group_by("gender") + .agg(nwp.col("weight").len(), nwp.col("age").len()) + .sort("gender") + ) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "pre_sort"), + [ + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 2, 4, 6]}, None), + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 3, 5, 6]}, {"descending": True}), + (["a"], ["c"], {"a": [1, 2, 3, 4], "c": [None, "A", None, "B"]}, None), + ( + ["a"], + ["c"], + {"a": [1, 2, 3, 4], "c": [None, "A", "B", "B"]}, + {"nulls_last": True}, + ), + ], + ids=["no-sort", "sort-descending", "NA-order-nulls-first", "NA-order-nulls-last"], +) +def test_group_by_agg_first( + keys: Sequence[str], + aggs: Sequence[str], + expected: Mapping[str, Any], + pre_sort: Mapping[str, Any] | None, +) -> None: + data = { + "a": [1, 2, 2, 3, 3, 4], + "b": [1, 2, 3, 4, 5, 6], + "c": [None, "A", "A", None, "B", "B"], + } + df = dataframe(data) + if pre_sort: + df = df.sort(aggs, **pre_sort) + result = 
df.group_by(keys).agg(nwp.col(aggs).first()).sort(keys) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "pre_sort"), + [ + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 3, 5, 6]}, None), + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 2, 4, 6]}, {"descending": True}), + (["a"], ["c"], {"a": [1, 2, 3, 4], "c": [None, "A", "B", "B"]}, None), + ( + ["a"], + ["c"], + {"a": [1, 2, 3, 4], "c": [None, "A", None, "B"]}, + {"nulls_last": True}, + ), + ], + ids=["no-sort", "sort-descending", "NA-order-nulls-first", "NA-order-nulls-last"], +) +def test_group_by_agg_last( + keys: Sequence[str], + aggs: Sequence[str], + expected: Mapping[str, Any], + pre_sort: Mapping[str, Any] | None, +) -> None: + data = { + "a": [1, 2, 2, 3, 3, 4], + "b": [1, 2, 3, 4, 5, 6], + "c": [None, "A", "A", None, "B", "B"], + } + df = dataframe(data) + if pre_sort: + df = df.sort(aggs, **pre_sort) + result = df.group_by(keys).agg(nwp.col(aggs).last()).sort(keys) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected"), + [ + (["a"], [nwp.col("b").unique()], {"a": ["a", "b", "c"], "b": [[1], [2, 3], [3]]}), + ( + ["a"], + [nwp.col("b", "d").unique()], + { + "a": ["a", "b", "c"], + "b": [[1], [2, 3], [3]], + "d": [["three", "one"], ["three"], ["one"]], + }, + ), + ( + ["d", "c"], + [npcs.string().unique(), nwp.col("b").first().alias("b_first")], + { + "d": ["one", "one", "three", "three", "three"], + "c": [1, 3, 2, 4, 5], + "a": [["c"], ["a"], ["b"], ["b"], ["a"]], + "b_first": [3, 1, 3, 2, 1], + }, + ), + ], + ids=["Unique-Single", "Unique-Multi", "Unique-Selector-Fancy"], +) +def test_group_by_agg_unique( + keys: Sequence[str], aggs: Sequence[IntoExpr], expected: Mapping[str, Any] +) -> None: + data = { + "a": ["a", "b", "a", "b", "c"], + "b": [1, 2, 1, 3, 3], + "c": [5, 4, 3, 2, 1], + "d": ["three", "three", "one", "three", "one"], + } + df = dataframe(data) + result = df.group_by(keys).agg(aggs).sort(keys) + assert_equal_data(result, expected) + + +def test_group_by_args() -> None: + """Adapted from [upstream]. + + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L302-L325 + """ + data = { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + df = dataframe(data) + + # Single column name + assert df.group_by("a").agg("b").columns == ["a", "b"] + # Column names as list + expected = ["a", "b", "c"] + assert df.group_by(["a", "b"]).agg("c").columns == expected + # Column names as positional arguments + assert df.group_by("a", "b").agg("c").columns == expected + # With keyword argument + assert df.group_by("a", "b", drop_null_keys=True).agg("c").columns == expected + # Multiple aggregations as list + assert df.group_by("a").agg(["b", "c"]).columns == expected + # Multiple aggregations as positional arguments + assert df.group_by("a").agg("b", "c").columns == expected + # Multiple aggregations as keyword arguments + assert df.group_by("a").agg(q="b", r="c").columns == ["a", "q", "r"] + + +def test_group_by_all() -> None: + """Adapted from [upstream]. 
+ + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L568-L577 + """ + data = {"a": [1, 2], "b": [1, 2]} + df = dataframe(data) + expected = {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]} + result = df.group_by(nwp.all()).agg(nwp.col("a").max().name.suffix("_agg")).sort("a") + assert_equal_data(result, expected) + + +def test_group_by_input_independent_with_len_23868() -> None: + """Adapted from [upstream]. + + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L1476-L1484 + """ + data = {"a": ["A", "B", "C"]} + expected = {"literal": ["G"], "len": [3]} + result = dataframe(data).group_by(nwp.lit("G")).agg(nwp.len()) + assert_equal_data(result, expected) + + +def test_group_by_series_lit_22103() -> None: + """Adapted from [upstream], but rejecting for now. + + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L1406-L1424 + """ + data = {"g": [0, 1]} + series = nwp.Series.from_native(pa.chunked_array([[42, 2, 3]])) + df = dataframe(data) + with pytest.raises(InvalidOperationError, match=re.escape("foo=lit(Series)")): + df.group_by("g").agg(foo=series) + + +def test_group_by_named() -> None: + """Adapted from [upstream]. + + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L878-884 + """ + data = {"a": [1, 1, 2, 2, 3, 3], "b": range(6)} + df = dataframe(data) + result = df.group_by(z=nwp.col("a") * 2).agg(nwp.col("b").min()).sort("b") + expected = ( + df.group_by((nwp.col("a") * 2).alias("z")).agg(nwp.col("b").min()).sort("b") + ) + assert_equal_data(result, expected.to_dict(as_series=False)) + + +def test_group_by_exclude_keys() -> None: + # `group_by(keys)` and `exclude` share some logic + data = { + "a": ["A", "B", "A"], + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [None, 9, 7], + "f": [True, False, None], + "g": [False, None, False], + "h": [None, None, True], + "j": [12.1, None, 4.0], + "k": [42, 10, None], + "l": [4, 5, 6], + "m": [0, 1, 2], + } + df = dataframe(data).with_columns( + npcs.boolean().fill_null(False), npcs.numeric().fill_null(0) + ) + exclude = "b", "c", "d", "e", "f", "g", "j", "k", "l", "m" + result = df.group_by(nwp.exclude(exclude)).agg(npcs.all().sum()).sort("a", "h") + expected = { + "a": ["A", "A", "B"], + "h": [False, True, False], + "b": [1, 3, 2], + "c": [9, 4, 2], + "d": [8, 8, 7], + "e": [0, 7, 9], + "f": [1, 0, 0], + "g": [0, 0, 0], + "j": [12.1, 4.0, 0.0], + "k": [42, 0, 10], + "l": [4, 6, 5], + "m": [0, 2, 1], + } + assert_equal_data(result, expected) diff --git a/tests/plan/temp_test.py b/tests/plan/temp_test.py new file mode 100644 index 0000000000..9dd7a0e42f --- /dev/null +++ b/tests/plan/temp_test.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import random +import re +import string + +# ruff: noqa: S311 +from collections import deque +from itertools import islice, product, repeat +from typing import TYPE_CHECKING, NamedTuple + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +import narwhals as nw +from narwhals._plan.common import temp +from narwhals._utils import qualified_type_name +from narwhals.exceptions import NarwhalsError + +pytest.importorskip("pyarrow") +pytest.importorskip("polars") + + +if TYPE_CHECKING: + from 
collections.abc import Iterable, Sequence + + from narwhals._utils import _StoresColumns + + +class MockStoresColumns(NamedTuple): + columns: Sequence[str] + + +_COLUMNS = ("abc", "XYZ", "nw2929023", "column", string.hexdigits) +_EMPTY_SCHEMA = nw.Schema((name, nw.Int64()) for name in _COLUMNS) + + +sources = pytest.mark.parametrize( + "source", + [ + _COLUMNS, + MockStoresColumns(columns=_COLUMNS), + deque(_COLUMNS), + nw.from_dict({}, _EMPTY_SCHEMA, backend="pyarrow"), + dict.fromkeys(_COLUMNS), + set(_COLUMNS), + nw.from_dict({}, _EMPTY_SCHEMA, backend="polars").to_native(), + ], + ids=qualified_type_name, +) + + +@sources +def test_temp_column_name_sources(source: _StoresColumns | Iterable[str]) -> None: + name = temp.column_name(source) + assert name not in _COLUMNS + + +@sources +def test_temp_column_names_sources(source: _StoresColumns | Iterable[str]) -> None: + it = temp.column_names(source) + name = next(it) + assert name not in _COLUMNS + + +@given(n_chars=st.integers(6, 106)) +@pytest.mark.slow +def test_temp_column_name_n_chars(n_chars: int) -> None: + name = temp.column_name(_COLUMNS, n_chars=n_chars) + assert name not in _COLUMNS + + +@given(n_new_names=st.integers(10_000, 100_000)) +@pytest.mark.slow +def test_temp_column_names_always_new_names(n_new_names: int) -> None: + it = temp.column_names(_COLUMNS) + new_names = set(islice(it, n_new_names)) + assert len(new_names) == n_new_names + assert new_names.isdisjoint(_COLUMNS) + + +@pytest.mark.parametrize( + ("prefix", "n_chars"), + [ + ("nw", random.randint(0, 5)), + ("col", random.randint(0, 4)), + ("NW_", random.randint(0, 3)), + ("join", random.randint(0, 2)), + ("__tmp", random.randint(0, 1)), + ("longer", random.randint(-5, 0)), + ("", random.randint(0, 5)), + ], +) +def test_temp_column_name_requires_more_characters(prefix: str, n_chars: int) -> None: + pattern = re.compile( + rf"temp.+column.+name.+requires.+try.+shorter.+{prefix}.+higher.+{n_chars}", + re.IGNORECASE | re.DOTALL, + ) + with pytest.raises(NarwhalsError, match=pattern): + temp.column_name(_COLUMNS, prefix=prefix, n_chars=n_chars) + + +def test_temp_column_name_failed_unique() -> None: + hex_lower = string.hexdigits.strip(string.ascii_uppercase) + every_possible_name_65k = [ + f"nw{e1}{e2}{e3}{e4}" for e1, e2, e3, e4 in product(*repeat(hex_lower, 4)) + ] + n_many_columns = len(every_possible_name_65k) + + pattern = re.compile( + rf"unable.+generate.+name.+n_chars=6.+within.+existing.+{n_many_columns}.+columns", + re.DOTALL, + ) + with pytest.raises(NarwhalsError, match=pattern): + temp.column_name(every_possible_name_65k, prefix="nw", n_chars=6) + + +def test_temp_column_names_failed_unique() -> None: + it = temp.column_names(["a", "b", "c"], prefix="long_prefix", n_chars=16) + pattern = re.compile( + r"unable.+generate.+name.+n_chars=16.+within.+existing.+.+columns.+\.\.\.", + re.DOTALL, + ) + with pytest.raises(NarwhalsError, match=pattern): + list(islice(it, 100_000)) diff --git a/tests/plan/utils.py b/tests/plan/utils.py index bf6135ee2f..d1ae2ce95e 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -36,9 +36,18 @@ def assert_expr_ir_equal( """ lhs = _unwrap_ir(actual) if isinstance(expected, str): - assert repr(lhs) == expected + assert repr(lhs) == expected, ( + f"\nlhs:\n {lhs!r}\n\nexpected:\n {expected!r}" + ) elif isinstance(actual, ir.NamedIR) and isinstance(expected, ir.NamedIR): - assert actual == expected + assert actual == expected, ( + f"\nactual:\n {actual!r}\n\nexpected:\n {expected!r}" + ) else: rhs = expected._ir if 
isinstance(expected, nwp.Expr) else expected - assert lhs == rhs + assert lhs == rhs, f"\nlhs:\n {lhs!r}\n\nrhs:\n {rhs!r}" + + +def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: + """Helper constructor for test compare.""" + return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) From 2b9dbf0b2c433a9728e203af7f10087446cd306c Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 1 Oct 2025 20:13:59 +0000 Subject: [PATCH 362/368] refactor(expr-ir): Split up and refine `protocols.py` (#3166) --- narwhals/_plan/_expr_ir.py | 2 +- narwhals/_plan/_guards.py | 2 +- narwhals/_plan/arrow/dataframe.py | 11 +- narwhals/_plan/arrow/expr.py | 23 +- narwhals/_plan/arrow/group_by.py | 2 +- narwhals/_plan/arrow/namespace.py | 2 +- narwhals/_plan/arrow/series.py | 2 +- narwhals/_plan/arrow/typing.py | 13 +- narwhals/_plan/compliant/__init__.py | 1 + narwhals/_plan/compliant/column.py | 99 +++ narwhals/_plan/compliant/dataframe.py | 143 +++++ narwhals/_plan/compliant/expr.py | 150 +++++ narwhals/_plan/compliant/group_by.py | 205 ++++++ narwhals/_plan/compliant/namespace.py | 142 ++++ narwhals/_plan/compliant/scalar.py | 120 ++++ narwhals/_plan/compliant/series.py | 75 +++ narwhals/_plan/compliant/typing.py | 92 +++ narwhals/_plan/dataframe.py | 66 +- narwhals/_plan/expressions/expr.py | 2 +- narwhals/_plan/group_by.py | 6 +- narwhals/_plan/protocols.py | 891 -------------------------- narwhals/_plan/series.py | 34 +- narwhals/_plan/typing.py | 12 + tests/plan/compliant_test.py | 12 + 24 files changed, 1129 insertions(+), 978 deletions(-) create mode 100644 narwhals/_plan/compliant/__init__.py create mode 100644 narwhals/_plan/compliant/column.py create mode 100644 narwhals/_plan/compliant/dataframe.py create mode 100644 narwhals/_plan/compliant/expr.py create mode 100644 narwhals/_plan/compliant/group_by.py create mode 100644 narwhals/_plan/compliant/namespace.py create mode 100644 narwhals/_plan/compliant/scalar.py create mode 100644 narwhals/_plan/compliant/series.py create mode 100644 narwhals/_plan/compliant/typing.py delete mode 100644 narwhals/_plan/protocols.py diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index d163134c80..862fb0c44a 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -15,10 +15,10 @@ from typing_extensions import Self, TypeAlias + from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co from narwhals._plan.expr import Expr, Selector from narwhals._plan.expressions.expr import Alias, Cast, Column from narwhals._plan.meta import MetaNamespace - from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals._plan.typing import ExprIRT2, MapIR, Seq from narwhals.dtypes import DType diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 0f62942ab8..ed2762d993 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -12,8 +12,8 @@ from typing_extensions import TypeIs from narwhals._plan import expressions as ir + from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.expr import Expr - from narwhals._plan.protocols import CompliantSeries from narwhals._plan.series import Series from narwhals._plan.typing import NativeSeriesT, Seq from narwhals.typing import NonNestedLiteral diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index b588b59180..668fd5330c 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ 
b/narwhals/_plan/arrow/dataframe.py @@ -12,8 +12,9 @@ from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy from narwhals._plan.arrow.series import ArrowSeries as Series +from narwhals._plan.compliant.dataframe import EagerDataFrame +from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import NamedIR -from narwhals._plan.protocols import EagerDataFrame, namespace from narwhals._plan.typing import Seq from narwhals._utils import Version, parse_columns_to_drop from narwhals.schema import Schema @@ -23,10 +24,9 @@ from typing_extensions import Self - from narwhals._arrow.typing import ChunkedArrayAny + from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401 from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.dataframe import DataFrame as NwDataFrame from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import Seq @@ -62,11 +62,6 @@ def schema(self) -> dict[str, DType]: def __len__(self) -> int: return self.native.num_rows - def to_narwhals(self) -> NwDataFrame[pa.Table, ChunkedArrayAny]: - from narwhals._plan.dataframe import DataFrame - - return DataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) - @classmethod def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index b547ed57fa..9dc25f05cd 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -9,8 +9,11 @@ from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co +from narwhals._plan.compliant.column import ExprDispatch +from narwhals._plan.compliant.expr import EagerExpr +from narwhals._plan.compliant.scalar import EagerScalar +from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import NamedIR -from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace from narwhals._utils import ( Implementation, Version, @@ -449,28 +452,10 @@ def broadcast(self, length: int) -> Series: chunked = fn.chunked_array(pa_repeat(scalar, length)) return Series.from_native(chunked, self.name, version=self.version) - def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar: - return self._with_native(pa.scalar(0), name) - - def arg_max(self, node: ArgMax, frame: Frame, name: str) -> Scalar: - return self._with_native(pa.scalar(0), name) - - def n_unique(self, node: NUnique, frame: Frame, name: str) -> Scalar: - return self._with_native(pa.scalar(1), name) - - def std(self, node: Std, frame: Frame, name: str) -> Scalar: - return self._with_native(pa.scalar(None, pa.null()), name) - - def var(self, node: Var, frame: Frame, name: str) -> Scalar: - return self._with_native(pa.scalar(None, pa.null()), name) - def count(self, node: Count, frame: Frame, name: str) -> Scalar: native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) - def len(self, node: Len, frame: Frame, name: str) -> Scalar: - return self._with_native(pa.scalar(1), name) - filter = not_implemented() over = not_implemented() over_ordered = not_implemented() diff --git a/narwhals/_plan/arrow/group_by.py 
b/narwhals/_plan/arrow/group_by.py index c878f344ed..b519261c53 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -9,8 +9,8 @@ from narwhals._plan._guards import is_agg_expr, is_function_expr from narwhals._plan.arrow import acero, functions as fn, options from narwhals._plan.common import dispatch_method_name, temp +from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy from narwhals._plan.expressions import aggregation as agg -from narwhals._plan.protocols import EagerDataFrameGroupBy from narwhals._utils import Implementation from narwhals.exceptions import InvalidOperationError diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index e4f68f27db..b532f77883 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -9,8 +9,8 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan._guards import is_tuple_of from narwhals._plan.arrow import functions as fn +from narwhals._plan.compliant.namespace import EagerNamespace from narwhals._plan.expressions.literal import is_literal_scalar -from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version from narwhals.exceptions import InvalidOperationError diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index d7941fa681..c068cf43ed 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -4,7 +4,7 @@ from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn -from narwhals._plan.protocols import CompliantSeries +from narwhals._plan.compliant.series import CompliantSeries from narwhals._utils import Version from narwhals.dependencies import is_numpy_array_1d diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index e11e9d45c1..dc0795c95b 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable, Mapping +from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, Protocol, overload from narwhals._typing_compat import TypeVar @@ -23,10 +23,21 @@ ) from typing_extensions import TypeAlias + from narwhals.typing import NativeDataFrame, NativeSeries + StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]" IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type" IntegerScalar: TypeAlias = "Scalar[IntegerType]" + class NativeArrowSeries(NativeSeries, Protocol): + @property + def chunks(self) -> list[Any]: ... + + class NativeArrowDataFrame(NativeDataFrame, Protocol): + def column(self, *args: Any, **kwds: Any) -> NativeArrowSeries: ... + @property + def columns(self) -> Sequence[NativeArrowSeries]: ... 
+ ScalarT = TypeVar("ScalarT", bound="pa.Scalar[Any]", default="pa.Scalar[Any]") ScalarPT_contra = TypeVar( diff --git a/narwhals/_plan/compliant/__init__.py b/narwhals/_plan/compliant/__init__.py new file mode 100644 index 0000000000..9d48db4f9f --- /dev/null +++ b/narwhals/_plan/compliant/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/narwhals/_plan/compliant/column.py b/narwhals/_plan/compliant/column.py new file mode 100644 index 0000000000..f2bb799409 --- /dev/null +++ b/narwhals/_plan/compliant/column.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from collections.abc import Sized +from typing import TYPE_CHECKING, Protocol + +from narwhals._plan.common import flatten_hash_safe +from narwhals._plan.compliant.typing import ( + FrameT_contra, + HasVersion, + LengthT, + NamespaceT_co, + R_co, + SeriesT, +) + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from typing_extensions import Self + + from narwhals._plan import expressions as ir + from narwhals._plan.typing import OneOrIterable + + +class SupportsBroadcast(Protocol[SeriesT, LengthT]): + """Minimal broadcasting for `Expr` results.""" + + def _length(self) -> LengthT: + """Return the length of the current expression.""" + ... + + @classmethod + def _length_all( + cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / + ) -> Sequence[LengthT]: + return [e._length() for e in exprs] + + @classmethod + def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT: + """Return the maximum length among `exprs`.""" + ... + + @classmethod + def _length_required( + cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / + ) -> LengthT | None: + """Return the broadcast length, if all lengths do not equal the maximum.""" + + @classmethod + def align( + cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]] + ) -> Iterator[SeriesT]: + exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs)) + length = cls._length_required(exprs) + if length is None: + for e in exprs: + yield e.to_series() + else: + for e in exprs: + yield e.broadcast(length) + + def broadcast(self, length: LengthT, /) -> SeriesT: ... + @classmethod + def from_series(cls, series: SeriesT, /) -> Self: ... + def to_series(self) -> SeriesT: ... + + +class EagerBroadcast(Sized, SupportsBroadcast[SeriesT, int], Protocol[SeriesT]): + """Determines expression length via the size of the container.""" + + def _length(self) -> int: + return len(self) + + @classmethod + def _length_max(cls, lengths: Sequence[int], /) -> int: + return max(lengths) + + @classmethod + def _length_required( + cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], / + ) -> int | None: + lengths = cls._length_all(exprs) + max_length = cls._length_max(lengths) + required = any(len_ != max_length for len_ in lengths) + return max_length if required else None + + +class ExprDispatch(HasVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): + # NOTE: Needs to stay `covariant` and never be used as a parameter + def __narwhals_namespace__(self) -> NamespaceT_co: ... 
+ @classmethod + def from_ir(cls, node: ir.ExprIR, frame: FrameT_contra, name: str) -> R_co: + obj = cls.__new__(cls) + obj._version = frame.version + return node.dispatch(obj, frame, name) + + @classmethod + def from_named_ir(cls, named_ir: ir.NamedIR[ir.ExprIR], frame: FrameT_contra) -> R_co: + return cls.from_ir(named_ir.expr, frame, named_ir.name) diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py new file mode 100644 index 0000000000..cc7bab7501 --- /dev/null +++ b/narwhals/_plan/compliant/dataframe.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload + +from narwhals._plan.compliant.group_by import Grouped +from narwhals._plan.compliant.typing import ColumnT_co, HasVersion, SeriesT +from narwhals._plan.typing import ( + IntoExpr, + NativeDataFrameT, + NativeFrameT_co, + NativeSeriesT, + OneOrIterable, +) + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping, Sequence + + from typing_extensions import Self, TypeAlias + + from narwhals._plan import expressions as ir + from narwhals._plan.compliant.group_by import ( + CompliantGroupBy, + DataFrameGroupBy, + EagerDataFrameGroupBy, + GroupByResolver, + ) + from narwhals._plan.compliant.namespace import EagerNamespace + from narwhals._plan.dataframe import BaseFrame, DataFrame + from narwhals._plan.expressions import NamedIR + from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.typing import Seq + from narwhals._utils import Version + from narwhals.dtypes import DType + from narwhals.typing import IntoSchema + +Incomplete: TypeAlias = Any + + +class CompliantFrame(HasVersion, Protocol[ColumnT_co, NativeFrameT_co]): + def __narwhals_namespace__(self) -> Any: ... + def _evaluate_irs( + self, nodes: Iterable[NamedIR[ir.ExprIR]], / + ) -> Iterator[ColumnT_co]: ... + @property + def _group_by(self) -> type[CompliantGroupBy[Self]]: ... + def _with_native(self, native: Incomplete) -> Self: ... + @classmethod + def from_native(cls, native: Incomplete, /, version: Version) -> Self: ... + @property + def native(self) -> NativeFrameT_co: ... + def to_narwhals(self) -> BaseFrame[NativeFrameT_co]: ... + @property + def columns(self) -> list[str]: ... + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ... + def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... + @property + def schema(self) -> Mapping[str, DType]: ... + def select(self, irs: Seq[NamedIR]) -> Self: ... + def select_names(self, *column_names: str) -> Self: ... + def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... + def with_columns(self, irs: Seq[NamedIR]) -> Self: ... + + +class CompliantDataFrame( + CompliantFrame[SeriesT, NativeDataFrameT], + Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], +): + _native: NativeDataFrameT + + def __len__(self) -> int: ... + @property + def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... 
+ @property + def _grouper(self) -> type[Grouped]: + return Grouped + + def _with_native(self, native: NativeDataFrameT) -> Self: + return self.from_native(native, self.version) + + @classmethod + def from_native(cls, native: NativeDataFrameT, /, version: Version) -> Self: + obj = cls.__new__(cls) + obj._native = native + obj._version = version + return obj + + @property + def native(self) -> NativeDataFrameT: + return self._native + + @classmethod + def from_dict( + cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None + ) -> Self: ... + def group_by_agg( + self, by: OneOrIterable[IntoExpr], aggs: OneOrIterable[IntoExpr], / + ) -> Self: + """Compliant-level `group_by(by).agg(agg)`, allows `Expr`.""" + return self._grouper.by(by).agg(aggs).resolve(self).evaluate(self) + + def group_by_names(self, names: Seq[str], /) -> DataFrameGroupBy[Self]: + """Compliant-level `group_by`, allowing only `str` keys.""" + return self._group_by.by_names(self, names) + + def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Self]: + """Narwhals-level resolved `group_by`. + + `keys`, `aggs` are already parsed and projections planned. + """ + return self._group_by.from_resolver(self, resolver) + + def row(self, index: int) -> tuple[Any, ...]: ... + @overload + def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload + def to_dict( + self, *, as_series: bool + ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... + def to_dict( + self, *, as_series: bool + ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... + def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: + from narwhals._plan.dataframe import DataFrame + + return DataFrame[NativeDataFrameT, NativeSeriesT](self) + + def with_row_index(self, name: str) -> Self: ... + + +class EagerDataFrame( + CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], + Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], +): + def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... + @property + def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... 
+ def select(self, irs: Seq[NamedIR]) -> Self: + return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) + + def with_columns(self, irs: Seq[NamedIR]) -> Self: + return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py new file mode 100644 index 0000000000..229151284f --- /dev/null +++ b/narwhals/_plan/compliant/expr.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +from narwhals._plan.compliant.column import EagerBroadcast, SupportsBroadcast +from narwhals._plan.compliant.typing import ( + FrameT_contra, + HasVersion, + LengthT, + SeriesT, + SeriesT_co, +) +from narwhals._utils import Version + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan import expressions as ir + from narwhals._plan.compliant.scalar import CompliantScalar + from narwhals._plan.expressions import ( + BinaryExpr, + FunctionExpr, + aggregation as agg, + boolean, + functions as F, + ) + from narwhals._plan.expressions.boolean import IsBetween, IsFinite, IsNan, IsNull, Not + + +class CompliantExpr(HasVersion, Protocol[FrameT_contra, SeriesT_co]): + """Everything common to `Expr`/`Series` and `Scalar` literal values.""" + + _evaluated: Any + """Compliant or native value.""" + + def _with_native(self, native: Any, name: str, /) -> Self: + return self.from_native(native, name or self.name, self.version) + + @classmethod + def from_native( + cls, native: Any, name: str = "", /, version: Version = Version.MAIN + ) -> Self: ... + @property + def name(self) -> str: ... + # series & scalar + def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ... + def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... + def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... + def fill_null( + self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str + ) -> Self: ... + def is_between( + self, node: FunctionExpr[IsBetween], frame: FrameT_contra, name: str + ) -> Self: ... + def is_finite( + self, node: FunctionExpr[IsFinite], frame: FrameT_contra, name: str + ) -> Self: ... + def is_nan( + self, node: FunctionExpr[IsNan], frame: FrameT_contra, name: str + ) -> Self: ... + def is_null( + self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str + ) -> Self: ... + def map_batches( + self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str + ) -> Self: ... + def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... + def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... + # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` + # e.g. `nw.col("a").first().over(order_by="b")` + def over_ordered( + self, node: ir.OrderedWindowExpr, frame: FrameT_contra, name: str + ) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ... + def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... + def rolling_expr( + self, node: ir.RollingExpr, frame: FrameT_contra, name: str + ) -> Self: ... + def ternary_expr( + self, node: ir.TernaryExpr, frame: FrameT_contra, name: str + ) -> Self: ... + # series only + def filter(self, node: ir.Filter, frame: FrameT_contra, name: str) -> Self: ... + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: ... 
+ def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: ... + # series -> scalar + def all( + self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def any( + self, node: FunctionExpr[boolean.Any], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def arg_max( + self, node: agg.ArgMax, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def arg_min( + self, node: agg.ArgMin, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def count( + self, node: agg.Count, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def first( + self, node: agg.First, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def last( + self, node: agg.Last, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def len( + self, node: agg.Len, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def max( + self, node: agg.Max, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def mean( + self, node: agg.Mean, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def median( + self, node: agg.Median, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def min( + self, node: agg.Min, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def n_unique( + self, node: agg.NUnique, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def quantile( + self, node: agg.Quantile, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def sum( + self, node: agg.Sum, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def std( + self, node: agg.Std, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def var( + self, node: agg.Var, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + + +class EagerExpr( + EagerBroadcast[SeriesT], + CompliantExpr[FrameT_contra, SeriesT], + Protocol[FrameT_contra, SeriesT], +): ... + + +class LazyExpr( + SupportsBroadcast[SeriesT, LengthT], + CompliantExpr[FrameT_contra, SeriesT], + Protocol[FrameT_contra, SeriesT, LengthT], +): ... 
diff --git a/narwhals/_plan/compliant/group_by.py b/narwhals/_plan/compliant/group_by.py new file mode 100644 index 0000000000..7ae5f3e966 --- /dev/null +++ b/narwhals/_plan/compliant/group_by.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +from itertools import chain +from typing import TYPE_CHECKING, Any, Protocol + +from narwhals._plan._expansion import prepare_projection +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.common import replace, temp +from narwhals._plan.compliant.typing import ( + DataFrameT, + EagerDataFrameT, + FrameT_co, + ResolverT_co, +) +from narwhals.exceptions import ComputeError + +if TYPE_CHECKING: + from collections.abc import Iterator + + from typing_extensions import Self + + from narwhals._plan.expressions import ExprIR, NamedIR + from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema + from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq + + +class CompliantGroupBy(Protocol[FrameT_co]): + def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... + @property + def compliant(self) -> FrameT_co: ... + + +class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): + _keys: Seq[NamedIR] + _key_names: Seq[str] + + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... + @classmethod + def by_names( + cls, df: DataFrameT, names: Seq[str], / + ) -> DataFrameGroupBy[DataFrameT]: ... + @classmethod + def from_resolver( + cls, df: DataFrameT, resolver: GroupByResolver, / + ) -> DataFrameGroupBy[DataFrameT]: ... + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def key_names(self) -> Seq[str]: + if names := self._key_names: + return names + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + +class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]): + _df: EagerDataFrameT + _key_names: Seq[str] + _key_names_original: Seq[str] + _column_names_original: Seq[str] + + @classmethod + def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: + obj = cls.__new__(cls) + obj._df = df + obj._keys = () + obj._key_names = names + obj._key_names_original = () + obj._column_names_original = tuple(df.columns) + return obj + + @classmethod + def from_resolver( + cls, df: EagerDataFrameT, resolver: GroupByResolver, / + ) -> EagerDataFrameGroupBy[EagerDataFrameT]: + key_names = resolver.key_names + if not resolver.requires_projection(): + df = df.drop_nulls(key_names) if resolver._drop_null_keys else df + return cls.by_names(df, key_names) + obj = cls.__new__(cls) + unique_names = temp.column_names(chain(key_names, df.columns)) + safe_keys = tuple( + replace(key, name=name) for key, name in zip(resolver.keys, unique_names) + ) + obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys)) + obj._keys = safe_keys + obj._key_names = tuple(e.name for e in safe_keys) + obj._key_names_original = key_names + obj._column_names_original = resolver._schema_in.names + return obj + + +class Grouper(Protocol[ResolverT_co]): + """`GroupBy` helper for collecting and forwarding `Expr`s for projection. + + - Uses `Expr` everywhere (no need to duplicate layers) + - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) + """ + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool + + @property + def _resolver(self) -> type[ResolverT_co]: ... 
+ def agg(self, *aggs: OneOrIterable[IntoExpr]) -> Self: + self._aggs = parse_into_seq_of_expr_ir(*aggs) + return self + + @classmethod + def by(cls, *by: OneOrIterable[IntoExpr]) -> Self: + obj = cls.__new__(cls) + obj._keys = parse_into_seq_of_expr_ir(*by) + return obj + + def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: + """Project keys and aggs in `context`, expanding all `Expr` -> `NamedIR`.""" + return self._resolver.from_grouper(self, context) + + +class GroupByResolver: + """Narwhals-level `GroupBy` resolver.""" + + _schema_in: FrozenSchema + _keys: Seq[NamedIR] + _aggs: Seq[NamedIR] + _key_names: Seq[str] + _schema: FrozenSchema + _drop_null_keys: bool + + @classmethod + def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: + """Loosely based on [`resolve_group_by`]. + + [`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 + """ + obj = cls.__new__(cls) + keys, schema_in = prepare_projection(grouper._keys, schema=context) + obj._keys, obj._schema_in = keys, schema_in + obj._key_names = tuple(e.name for e in keys) + obj._aggs, _ = prepare_projection(grouper._aggs, obj.key_names, schema=schema_in) + obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs)) + obj._drop_null_keys = grouper._drop_null_keys + return obj + + @property + def aggs(self) -> Seq[NamedIR]: + return self._aggs + + def evaluate(self, frame: DataFrameT) -> DataFrameT: + """Perform the `group_by` on `frame`.""" + return frame.group_by_resolver(self).agg(self.aggs) + + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def key_names(self) -> Seq[str]: + if names := self._key_names: + return names + if keys := self.keys: + return tuple(e.name for e in keys) + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + def requires_projection(self, *, allow_aliasing: bool = False) -> bool: + """Return True if group keys contain anything that is not a column selection. + + Notes: + If False is returned, we can just use the resolved key names as a fast-path to group. + + Arguments: + allow_aliasing: If False (default), any aliasing is not considered to be column selection.
+ """ + if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys): + if self._drop_null_keys: + msg = "drop_null_keys cannot be True when keys contains Expr or Series" + raise NotImplementedError(msg) + return True + return False + + @property + def schema(self) -> FrozenSchema: + return self._schema + + +class Resolved(GroupByResolver): + """Compliant-level `GroupBy` resolver.""" + + _drop_null_keys: bool = False + + +class Grouped(Grouper[Resolved]): + """Compliant-level `GroupBy` helper.""" + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool = False + + @property + def _resolver(self) -> type[Resolved]: + return Resolved diff --git a/narwhals/_plan/compliant/namespace.py b/narwhals/_plan/compliant/namespace.py new file mode 100644 index 0000000000..743312d5c1 --- /dev/null +++ b/narwhals/_plan/compliant/namespace.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload + +from narwhals._plan.compliant.typing import ( + ConcatT1, + ConcatT2, + EagerDataFrameT, + EagerExprT_co, + EagerScalarT_co, + ExprT_co, + FrameT, + HasVersion, + LazyExprT_co, + LazyScalarT_co, + ScalarT_co, + SeriesT, +) + +if TYPE_CHECKING: + from collections.abc import Iterable + + from typing_extensions import TypeIs + + from narwhals._plan import expressions as ir + from narwhals._plan.expressions import FunctionExpr, boolean, functions as F + from narwhals._plan.expressions.ranges import IntRange + from narwhals._plan.expressions.strings import ConcatStr + from narwhals._plan.series import Series + from narwhals.typing import ConcatMethod, NonNestedLiteral + + +class CompliantNamespace(HasVersion, Protocol[FrameT, ExprT_co, ScalarT_co]): + @property + def _expr(self) -> type[ExprT_co]: ... + @property + def _frame(self) -> type[FrameT]: ... + @property + def _scalar(self) -> type[ScalarT_co]: ... + def all_horizontal( + self, node: FunctionExpr[boolean.AllHorizontal], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + def any_horizontal( + self, node: FunctionExpr[boolean.AnyHorizontal], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + def col(self, node: ir.Column, frame: FrameT, name: str) -> ExprT_co: ... + def concat_str( + self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + def int_range( + self, node: ir.RangeExpr[IntRange], frame: FrameT, name: str + ) -> ExprT_co: ... + def len(self, node: ir.Len, frame: FrameT, name: str) -> ScalarT_co: ... + def lit( + self, node: ir.Literal[Any], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + def max_horizontal( + self, node: FunctionExpr[F.MaxHorizontal], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + def mean_horizontal( + self, node: FunctionExpr[F.MeanHorizontal], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + def min_horizontal( + self, node: FunctionExpr[F.MinHorizontal], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + def sum_horizontal( + self, node: FunctionExpr[F.SumHorizontal], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... + + +# NOTE: `mypy` is wrong +# error: Invariant type variable "ConcatT2" used in protocol where covariant one is expected [misc] +class Concat(Protocol[ConcatT1, ConcatT2]): # type: ignore[misc] + @overload + def concat(self, items: Iterable[ConcatT1], *, how: ConcatMethod) -> ConcatT1: ... 
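    # Illustrative note: in `EagerNamespace` below, `ConcatT1` binds to the dataframe
    # type and `ConcatT2` to the series type, while `LazyNamespace` uses the frame type
    # for both. Roughly:
    #
    #     ns.concat([df1, df2], how="diagonal")   # first overload: frames, any ConcatMethod
    #     ns.concat([s1, s2], how="vertical")     # second overload: series, vertical only
    #
    # with `ns`, `df*`, `s*` standing in for a concrete namespace / frames / series.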
+ # Series only supports vertical publicly (like in polars) + @overload + def concat( + self, items: Iterable[ConcatT2], *, how: Literal["vertical"] + ) -> ConcatT2: ... + def concat( + self, items: Iterable[ConcatT1 | ConcatT2], *, how: ConcatMethod + ) -> ConcatT1 | ConcatT2: ... + + +class EagerConcat(Concat[ConcatT1, ConcatT2], Protocol[ConcatT1, ConcatT2]): # type: ignore[misc] + def _concat_diagonal(self, items: Iterable[ConcatT1], /) -> ConcatT1: ... + # Series can be used here to go from [Series, Series] -> DataFrame + # but that is only available privately + def _concat_horizontal(self, items: Iterable[ConcatT1 | ConcatT2], /) -> ConcatT1: ... + def _concat_vertical( + self, items: Iterable[ConcatT1 | ConcatT2], / + ) -> ConcatT1 | ConcatT2: ... + + +class EagerNamespace( + EagerConcat[EagerDataFrameT, SeriesT], + CompliantNamespace[EagerDataFrameT, EagerExprT_co, EagerScalarT_co], + Protocol[EagerDataFrameT, SeriesT, EagerExprT_co, EagerScalarT_co], +): + @property + def _dataframe(self) -> type[EagerDataFrameT]: ... + @property + def _frame(self) -> type[EagerDataFrameT]: + return self._dataframe + + @property + def _series(self) -> type[SeriesT]: ... + def _is_dataframe(self, obj: Any) -> TypeIs[EagerDataFrameT]: + return isinstance(obj, self._dataframe) + + def _is_series(self, obj: Any) -> TypeIs[SeriesT]: + return isinstance(obj, self._series) + + def len(self, node: ir.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: + return self._scalar.from_python( + len(frame), name or node.name, dtype=None, version=frame.version + ) + + @overload + def lit( + self, node: ir.Literal[NonNestedLiteral], frame: EagerDataFrameT, name: str + ) -> EagerScalarT_co: ... + @overload + def lit( + self, node: ir.Literal[Series[Any]], frame: EagerDataFrameT, name: str + ) -> EagerExprT_co: ... + def lit( + self, node: ir.Literal[Any], frame: EagerDataFrameT, name: str + ) -> EagerExprT_co | EagerScalarT_co: ... + + +class LazyNamespace( + Concat[FrameT, FrameT], + CompliantNamespace[FrameT, LazyExprT_co, LazyScalarT_co], + Protocol[FrameT, LazyExprT_co, LazyScalarT_co], +): + @property + def _lazyframe(self) -> type[FrameT]: ... 
+ @property + def _frame(self) -> type[FrameT]: + return self._lazyframe diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py new file mode 100644 index 0000000000..19fda1f003 --- /dev/null +++ b/narwhals/_plan/compliant/scalar.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +from narwhals._plan.compliant.expr import CompliantExpr, EagerExpr, LazyExpr +from narwhals._plan.compliant.typing import FrameT_contra, LengthT, SeriesT, SeriesT_co + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan import expressions as ir + from narwhals._plan.expressions import aggregation as agg + from narwhals._utils import Version + from narwhals.typing import IntoDType, PythonLiteral + + +class CompliantScalar( + CompliantExpr[FrameT_contra, SeriesT_co], Protocol[FrameT_contra, SeriesT_co] +): + _name: str + + def _cast_float(self, node: ir.ExprIR, frame: FrameT_contra, name: str) -> Self: + """`polars` interpolates a single scalar as a float.""" + dtype = self.version.dtypes.Float64() + return self.cast(node.cast(dtype), frame, name) + + def _with_evaluated(self, evaluated: Any, name: str) -> Self: + """Expr is based on a series having these via accessors, but a scalar needs to keep passing through.""" + cls = type(self) + obj = cls.__new__(cls) + obj._evaluated = evaluated + obj._name = name or self.name + obj._version = self.version + return obj + + @property + def name(self) -> str: + return self._name + + @classmethod + def from_python( + cls, + value: PythonLiteral, + name: str = "literal", + /, + *, + dtype: IntoDType | None, + version: Version, + ) -> Self: ... + def arg_max(self, node: agg.ArgMax, frame: FrameT_contra, name: str) -> Self: + return self.from_python(0, name, dtype=None, version=self.version) + + def arg_min(self, node: agg.ArgMin, frame: FrameT_contra, name: str) -> Self: + return self.from_python(0, name, dtype=None, version=self.version) + + def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: + """Returns 0 if null, else 1.""" + ... 
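    # Illustrative summary of the scalar aggregation semantics implemented by the
    # methods around this point, mirroring how polars treats a literal such as `lit(2)`:
    #
    #     min/max/sum/first/last   -> 2      (pass the value through)
    #     mean/median/quantile     -> 2.0    (cast to Float64)
    #     len/n_unique             -> 1
    #     arg_min/arg_max          -> 0
    #     std/var                  -> null
    #     count                    -> 1 (0 if the scalar is null)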
+ + def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated, name) + + def last(self, node: agg.Last, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated, name) + + def len(self, node: agg.Len, frame: FrameT_contra, name: str) -> Self: + return self.from_python(1, name, dtype=None, version=self.version) + + def max(self, node: agg.Max, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated, name) + + def mean(self, node: agg.Mean, frame: FrameT_contra, name: str) -> Self: + return self._cast_float(node.expr, frame, name) + + def median(self, node: agg.Median, frame: FrameT_contra, name: str) -> Self: + return self._cast_float(node.expr, frame, name) + + def min(self, node: agg.Min, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated, name) + + def n_unique(self, node: agg.NUnique, frame: FrameT_contra, name: str) -> Self: + return self.from_python(1, name, dtype=None, version=self.version) + + def quantile(self, node: agg.Quantile, frame: FrameT_contra, name: str) -> Self: + return self._cast_float(node.expr, frame, name) + + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated, name) + + def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated, name) + + def std(self, node: agg.Std, frame: FrameT_contra, name: str) -> Self: + return self.from_python(None, name, dtype=None, version=self.version) + + def sum(self, node: agg.Sum, frame: FrameT_contra, name: str) -> Self: + return self._with_evaluated(self._evaluated, name) + + def var(self, node: agg.Var, frame: FrameT_contra, name: str) -> Self: + return self.from_python(None, name, dtype=None, version=self.version) + + # NOTE: `Filter` behaves the same, (maybe) no need to override + + +class EagerScalar( + CompliantScalar[FrameT_contra, SeriesT], + EagerExpr[FrameT_contra, SeriesT], + Protocol[FrameT_contra, SeriesT], +): + def __len__(self) -> int: + return 1 + + def to_python(self) -> PythonLiteral: ... + + +class LazyScalar( + CompliantScalar[FrameT_contra, SeriesT], + LazyExpr[FrameT_contra, SeriesT, LengthT], + Protocol[FrameT_contra, SeriesT, LengthT], +): ... diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py new file mode 100644 index 0000000000..8c5d2fe3b4 --- /dev/null +++ b/narwhals/_plan/compliant/series.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +from narwhals._plan.compliant.typing import HasVersion +from narwhals._plan.typing import NativeSeriesT +from narwhals._utils import Version + +if TYPE_CHECKING: + from collections.abc import Iterable + + from typing_extensions import Self + + from narwhals._plan.series import Series + from narwhals.dtypes import DType + from narwhals.typing import Into1DArray, IntoDType, _1DArray + + +class CompliantSeries(HasVersion, Protocol[NativeSeriesT]): + _native: NativeSeriesT + _name: str + + def __len__(self) -> int: + return len(self.native) + + def __narwhals_series__(self) -> Self: + return self + + def _with_native(self, native: NativeSeriesT) -> Self: + return self.from_native(native, self.name, version=self.version) + + @classmethod + def from_iterable( + cls, + data: Iterable[Any], + *, + version: Version, + name: str = "", + dtype: IntoDType | None = None, + ) -> Self: ... 
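    # Illustrative sketch: a concrete backend wraps a native series plus a name, e.g.
    # for the Arrow implementation in `narwhals._plan.arrow.series`:
    #
    #     s = ArrowSeries.from_native(pa.chunked_array([[1, 2, 3]]), "a")
    #     s.alias("b")        # same native data, new name
    #     s.to_narwhals()     # narwhals-level Series wrapping `s`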
+ @classmethod + def from_native( + cls, native: NativeSeriesT, name: str = "", /, *, version: Version = Version.MAIN + ) -> Self: + obj = cls.__new__(cls) + obj._native = native + obj._name = name + obj._version = version + return obj + + @classmethod + def from_numpy( + cls, data: Into1DArray, name: str = "", /, *, version: Version = Version.MAIN + ) -> Self: ... + @property + def dtype(self) -> DType: ... + @property + def name(self) -> str: + return self._name + + @property + def native(self) -> NativeSeriesT: + return self._native + + def alias(self, name: str) -> Self: + return self.from_native(self.native, name, version=self.version) + + def cast(self, dtype: IntoDType) -> Self: ... + def to_list(self) -> list[Any]: ... + def to_narwhals(self) -> Series[NativeSeriesT]: + from narwhals._plan.series import Series + + return Series[NativeSeriesT](self) + + def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... diff --git a/narwhals/_plan/compliant/typing.py b/narwhals/_plan/compliant/typing.py new file mode 100644 index 0000000000..91ad9320ed --- /dev/null +++ b/narwhals/_plan/compliant/typing.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +from narwhals._typing_compat import TypeVar + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals._plan.compliant.column import ExprDispatch + from narwhals._plan.compliant.dataframe import ( + CompliantDataFrame, + CompliantFrame, + EagerDataFrame, + ) + from narwhals._plan.compliant.expr import CompliantExpr, EagerExpr, LazyExpr + from narwhals._plan.compliant.group_by import GroupByResolver + from narwhals._plan.compliant.namespace import CompliantNamespace + from narwhals._plan.compliant.scalar import CompliantScalar, EagerScalar, LazyScalar + from narwhals._plan.compliant.series import CompliantSeries + from narwhals._utils import Version + +T = TypeVar("T") +R_co = TypeVar("R_co", covariant=True) +LengthT = TypeVar("LengthT") +NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) + +ConcatT1 = TypeVar("ConcatT1") +ConcatT2 = TypeVar("ConcatT2", default=ConcatT1) + +ColumnT = TypeVar("ColumnT") +ColumnT_co = TypeVar("ColumnT_co", covariant=True) + +ResolverT_co = TypeVar("ResolverT_co", bound="GroupByResolver", covariant=True) + +ExprAny: TypeAlias = "CompliantExpr[Any, Any]" +ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" +SeriesAny: TypeAlias = "CompliantSeries[Any]" +FrameAny: TypeAlias = "CompliantFrame[Any, Any]" +DataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any]" +NamespaceAny: TypeAlias = "CompliantNamespace[Any, Any, Any]" + +EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" +EagerScalarAny: TypeAlias = "EagerScalar[Any, Any]" +EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any]" + +LazyExprAny: TypeAlias = "LazyExpr[Any, Any, Any]" +LazyScalarAny: TypeAlias = "LazyScalar[Any, Any, Any]" + +ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) +ScalarT = TypeVar("ScalarT", bound=ScalarAny) +ScalarT_co = TypeVar("ScalarT_co", bound=ScalarAny, covariant=True) +SeriesT = TypeVar("SeriesT", bound=SeriesAny) +SeriesT_co = TypeVar("SeriesT_co", bound=SeriesAny, covariant=True) +FrameT = TypeVar("FrameT", bound=FrameAny) +FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) +FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) +DataFrameT = TypeVar("DataFrameT", bound=DataFrameAny) +NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", 
covariant=True) + +EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) +EagerScalarT_co = TypeVar("EagerScalarT_co", bound=EagerScalarAny, covariant=True) +EagerDataFrameT = TypeVar("EagerDataFrameT", bound=EagerDataFrameAny) + +LazyExprT_co = TypeVar("LazyExprT_co", bound=LazyExprAny, covariant=True) +LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True) + +Ctx: TypeAlias = "ExprDispatch[FrameT_contra, R_co, NamespaceAny]" +"""Type of an unknown expression dispatch context. + +- `FrameT_contra`: Compliant data/lazyframe +- `R_co`: Upper bound return type of the context +""" + + +class SupportsNarwhalsNamespace(Protocol[NamespaceT_co]): + def __narwhals_namespace__(self) -> NamespaceT_co: ... + + +def namespace(obj: SupportsNarwhalsNamespace[NamespaceT_co], /) -> NamespaceT_co: + """Return the compliant namespace.""" + return obj.__narwhals_namespace__() + + +# NOTE: Unlike `nw._utils._StoresVersion`, here the property is public +class HasVersion(Protocol): + _version: Version + + @property + def version(self) -> Version: + """Narwhals API version (V1 or MAIN).""" + return self._version diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 8956c33457..a914923226 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -10,7 +10,8 @@ from narwhals._plan.typing import ( IntoExpr, NativeDataFrameT, - NativeFrameT, + NativeDataFrameT_co, + NativeFrameT_co, NativeSeriesT, OneOrIterable, ) @@ -22,14 +23,16 @@ from collections.abc import Sequence import pyarrow as pa - from typing_extensions import Self + from typing_extensions import Self, TypeAlias - from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame - from narwhals.typing import NativeFrame + from narwhals._plan.arrow.typing import NativeArrowDataFrame + from narwhals._plan.compliant.dataframe import CompliantDataFrame, CompliantFrame +Incomplete: TypeAlias = Any -class BaseFrame(Generic[NativeFrameT]): - _compliant: CompliantBaseFrame[Any, NativeFrameT] + +class BaseFrame(Generic[NativeFrameT_co]): + _compliant: CompliantFrame[Any, NativeFrameT_co] _version: ClassVar[Version] = Version.MAIN @property @@ -47,30 +50,26 @@ def columns(self) -> list[str]: def __repr__(self) -> str: # pragma: no cover return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) - @classmethod - def from_native(cls, native: Any, /) -> Self: - raise NotImplementedError + def __init__(self, compliant: CompliantFrame[Any, NativeFrameT_co], /) -> None: + self._compliant = compliant - @classmethod - def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> Self: - obj = cls.__new__(cls) - obj._compliant = compliant - return obj + def _with_compliant(self, compliant: CompliantFrame[Any, Incomplete], /) -> Self: + return type(self)(compliant) - def to_native(self) -> NativeFrameT: + def to_native(self) -> NativeFrameT_co: return self._compliant.native def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema = prepare_projection( _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self ) - return self._from_compliant(self._compliant.select(schema.select_irs(named_irs))) + return self._with_compliant(self._compliant.select(schema.select_irs(named_irs))) def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema = prepare_projection( _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self ) - return 
self._from_compliant( + return self._with_compliant( self._compliant.with_columns(schema.with_columns_irs(named_irs)) ) @@ -85,32 +84,43 @@ def sort( by, *more_by, descending=descending, nulls_last=nulls_last ) named_irs, _ = prepare_projection(sort, schema=self) - return self._from_compliant(self._compliant.sort(named_irs, opts)) + return self._with_compliant(self._compliant.sort(named_irs, opts)) def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: - return self._from_compliant(self._compliant.drop(columns, strict=strict)) + return self._with_compliant(self._compliant.drop(columns, strict=strict)) def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: subset = [subset] if isinstance(subset, str) else subset - return self._from_compliant(self._compliant.drop_nulls(subset)) + return self._with_compliant(self._compliant.drop_nulls(subset)) -class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): - _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] +class DataFrame( + BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT] +): + _compliant: CompliantDataFrame[Any, NativeDataFrameT_co, NativeSeriesT] @property def _series(self) -> type[Series[NativeSeriesT]]: return Series[NativeSeriesT] - # NOTE: Gave up on trying to get typing working for now + @overload + @classmethod + def from_native( + cls: type[DataFrame[Any, Any]], native: NativeArrowDataFrame, / + ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: ... + @overload + @classmethod + def from_native( + cls: type[DataFrame[Any, Any]], native: NativeDataFrameT, / + ) -> DataFrame[NativeDataFrameT]: ... @classmethod - def from_native( # type: ignore[override] - cls, native: NativeFrame, / - ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: + def from_native( + cls: type[DataFrame[Any, Any]], native: NativeDataFrameT, / + ) -> DataFrame[Any, Any]: if is_pyarrow_table(native): from narwhals._plan.arrow.dataframe import ArrowDataFrame - return ArrowDataFrame.from_native(native, cls._version).to_narwhals() + return cls(ArrowDataFrame.from_native(native, cls._version)) raise NotImplementedError(type(native)) @@ -129,7 +139,7 @@ def to_dict( ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: if as_series: return { - key: self._series._from_compliant(value) + key: self._series(value) for key, value in self._compliant.to_dict(as_series=as_series).items() } return self._compliant.to_dict(as_series=as_series) diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index a898bf879b..1b283f81a6 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -29,12 +29,12 @@ if t.TYPE_CHECKING: from typing_extensions import Self + from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co from narwhals._plan.expressions.functions import MapBatches # noqa: F401 from narwhals._plan.expressions.literal import LiteralValue from narwhals._plan.expressions.selectors import Selector from narwhals._plan.expressions.window import Window from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions - from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals.dtypes import DType __all__ = [ diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 5e95bd484e..52ccd9ef6b 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Generic from narwhals._plan._parse 
import parse_into_seq_of_expr_ir -from narwhals._plan.protocols import GroupByResolver as Resolved, Grouper +from narwhals._plan.compliant.group_by import GroupByResolver as Resolved, Grouper from narwhals._plan.typing import DataFrameT if TYPE_CHECKING: @@ -25,7 +25,7 @@ def __init__(self, frame: DataFrameT, grouper: Grouped, /) -> None: def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: frame = self._frame - return frame._from_compliant( + return frame._with_compliant( self._grouper.agg(*aggs, **named_aggs) .resolve(frame) .evaluate(frame._compliant) @@ -35,7 +35,7 @@ def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: frame = self._frame resolver = self._grouper.agg().resolve(frame) for key, df in frame._compliant.group_by_resolver(resolver): - yield key, frame._from_compliant(df) + yield key, frame._with_compliant(df) class Grouped(Grouper["Resolved"]): diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py deleted file mode 100644 index cff5e790e8..0000000000 --- a/narwhals/_plan/protocols.py +++ /dev/null @@ -1,891 +0,0 @@ -"""TODO: Split this module up into `narwhals._plan.compliant.*`.""" - -from __future__ import annotations - -from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized -from itertools import chain -from typing import TYPE_CHECKING, Any, Literal, Protocol, overload - -from narwhals._plan._expansion import prepare_projection -from narwhals._plan._parse import parse_into_seq_of_expr_ir -from narwhals._plan.common import flatten_hash_safe, replace, temp -from narwhals._plan.typing import ( - IntoExpr, - NativeDataFrameT, - NativeFrameT, - NativeSeriesT, - Seq, -) -from narwhals._typing_compat import TypeVar -from narwhals._utils import Version -from narwhals.exceptions import ComputeError - -if TYPE_CHECKING: - from typing_extensions import Self, TypeAlias, TypeIs - - from narwhals._plan import expressions as ir - from narwhals._plan.dataframe import BaseFrame, DataFrame - from narwhals._plan.expressions import ( - BinaryExpr, - ExprIR, - FunctionExpr, - NamedIR, - aggregation as agg, - boolean, - functions as F, - ) - from narwhals._plan.expressions.boolean import IsBetween, IsFinite, IsNan, IsNull, Not - from narwhals._plan.expressions.ranges import IntRange - from narwhals._plan.expressions.strings import ConcatStr - from narwhals._plan.options import SortMultipleOptions - from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema - from narwhals._plan.series import Series - from narwhals._plan.typing import OneOrIterable - from narwhals.dtypes import DType - from narwhals.typing import ( - ConcatMethod, - Into1DArray, - IntoDType, - IntoSchema, - NonNestedLiteral, - PythonLiteral, - _1DArray, - ) - -T = TypeVar("T") -R_co = TypeVar("R_co", covariant=True) -LengthT = TypeVar("LengthT") -NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) - -ConcatT1 = TypeVar("ConcatT1") -ConcatT2 = TypeVar("ConcatT2", default=ConcatT1) - -ColumnT = TypeVar("ColumnT") -ColumnT_co = TypeVar("ColumnT_co", covariant=True) - -ResolverT_co = TypeVar("ResolverT_co", bound="GroupByResolver", covariant=True) - -ExprAny: TypeAlias = "CompliantExpr[Any, Any]" -ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" -SeriesAny: TypeAlias = "CompliantSeries[Any]" -FrameAny: TypeAlias = "CompliantBaseFrame[Any, Any]" -DataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any]" -NamespaceAny: TypeAlias = "CompliantNamespace[Any, Any, Any]" - -EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" -EagerScalarAny: 
TypeAlias = "EagerScalar[Any, Any]" -EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any]" - -LazyExprAny: TypeAlias = "LazyExpr[Any, Any, Any]" -LazyScalarAny: TypeAlias = "LazyScalar[Any, Any, Any]" - -ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) -ScalarT = TypeVar("ScalarT", bound=ScalarAny) -ScalarT_co = TypeVar("ScalarT_co", bound=ScalarAny, covariant=True) -SeriesT = TypeVar("SeriesT", bound=SeriesAny) -SeriesT_co = TypeVar("SeriesT_co", bound=SeriesAny, covariant=True) -FrameT = TypeVar("FrameT", bound=FrameAny) -FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) -FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) -DataFrameT = TypeVar("DataFrameT", bound=DataFrameAny) -NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) - -EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) -EagerScalarT_co = TypeVar("EagerScalarT_co", bound=EagerScalarAny, covariant=True) -EagerDataFrameT = TypeVar("EagerDataFrameT", bound=EagerDataFrameAny) - -LazyExprT_co = TypeVar("LazyExprT_co", bound=LazyExprAny, covariant=True) -LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True) - -Ctx: TypeAlias = "ExprDispatch[FrameT_contra, R_co, NamespaceAny]" -"""Type of an unknown expression dispatch context. - -- `FrameT_contra`: Compliant data/lazyframe -- `R_co`: Upper bound return type of the context -""" - - -class SupportsNarwhalsNamespace(Protocol[NamespaceT_co]): - def __narwhals_namespace__(self) -> NamespaceT_co: ... - - -def namespace(obj: SupportsNarwhalsNamespace[NamespaceT_co], /) -> NamespaceT_co: - """Return the compliant namespace.""" - return obj.__narwhals_namespace__() - - -# NOTE: Unlike the version in `nw._utils`, here `.version` it is public -class StoresVersion(Protocol): - _version: Version - - @property - def version(self) -> Version: - """Narwhals API version (V1 or MAIN).""" - return self._version - - -class SupportsBroadcast(Protocol[SeriesT, LengthT]): - """Minimal broadcasting for `Expr` results.""" - - @classmethod - def from_series(cls, series: SeriesT, /) -> Self: ... - def to_series(self) -> SeriesT: ... - def broadcast(self, length: LengthT, /) -> SeriesT: ... - def _length(self) -> LengthT: - """Return the length of the current expression.""" - ... - - @classmethod - def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT: - """Return the maximum length among `exprs`.""" - ... 
- - @classmethod - def _length_required( - cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / - ) -> LengthT | None: - """Return the broadcast length, if all lengths do not equal the maximum.""" - - @classmethod - def _length_all( - cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / - ) -> Sequence[LengthT]: - return [e._length() for e in exprs] - - @classmethod - def align( - cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]] - ) -> Iterator[SeriesT]: - exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs)) - length = cls._length_required(exprs) - if length is None: - for e in exprs: - yield e.to_series() - else: - for e in exprs: - yield e.broadcast(length) - - -class EagerBroadcast(Sized, SupportsBroadcast[SeriesT, int], Protocol[SeriesT]): - """Determines expression length via the size of the container.""" - - def _length(self) -> int: - return len(self) - - @classmethod - def _length_max(cls, lengths: Sequence[int], /) -> int: - return max(lengths) - - @classmethod - def _length_required( - cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], / - ) -> int | None: - lengths = cls._length_all(exprs) - max_length = cls._length_max(lengths) - required = any(len_ != max_length for len_ in lengths) - return max_length if required else None - - -class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): - @classmethod - def from_ir(cls, node: ir.ExprIR, frame: FrameT_contra, name: str) -> R_co: - obj = cls.__new__(cls) - obj._version = frame.version - return node.dispatch(obj, frame, name) - - @classmethod - def from_named_ir(cls, named_ir: NamedIR[ir.ExprIR], frame: FrameT_contra) -> R_co: - return cls.from_ir(named_ir.expr, frame, named_ir.name) - - # NOTE: Needs to stay `covariant` and never be used as a parameter - def __narwhals_namespace__(self) -> NamespaceT_co: ... - - -class CompliantExpr(StoresVersion, Protocol[FrameT_contra, SeriesT_co]): - """Everything common to `Expr`/`Series` and `Scalar` literal values.""" - - _evaluated: Any - """Compliant or native value.""" - - @property - def name(self) -> str: ... - @classmethod - def from_native( - cls, native: Any, name: str = "", /, version: Version = Version.MAIN - ) -> Self: ... - def _with_native(self, native: Any, name: str, /) -> Self: - return self.from_native(native, name or self.name, self.version) - - # series & scalar - def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ... - def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... - def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... - def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... - def fill_null( - self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str - ) -> Self: ... - def is_between( - self, node: FunctionExpr[IsBetween], frame: FrameT_contra, name: str - ) -> Self: ... - def is_finite( - self, node: FunctionExpr[IsFinite], frame: FrameT_contra, name: str - ) -> Self: ... - def is_nan( - self, node: FunctionExpr[IsNan], frame: FrameT_contra, name: str - ) -> Self: ... - def is_null( - self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str - ) -> Self: ... - def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... - def ternary_expr( - self, node: ir.TernaryExpr, frame: FrameT_contra, name: str - ) -> Self: ... - def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... 
- # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` - # e.g. `nw.col("a").first().over(order_by="b")` - def over_ordered( - self, node: ir.OrderedWindowExpr, frame: FrameT_contra, name: str - ) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ... - def map_batches( - self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str - ) -> Self: ... - def rolling_expr( - self, node: ir.RollingExpr, frame: FrameT_contra, name: str - ) -> Self: ... - # series only (section 3) - def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: ... - def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: ... - def filter(self, node: ir.Filter, frame: FrameT_contra, name: str) -> Self: ... - # series -> scalar - def first( - self, node: agg.First, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def last( - self, node: agg.Last, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def arg_min( - self, node: agg.ArgMin, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def arg_max( - self, node: agg.ArgMax, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def sum( - self, node: agg.Sum, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def n_unique( - self, node: agg.NUnique, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def std( - self, node: agg.Std, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def var( - self, node: agg.Var, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def quantile( - self, node: agg.Quantile, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def count( - self, node: agg.Count, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def len( - self, node: agg.Len, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def max( - self, node: agg.Max, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def mean( - self, node: agg.Mean, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def median( - self, node: agg.Median, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def min( - self, node: agg.Min, frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def all( - self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - def any( - self, node: FunctionExpr[boolean.Any], frame: FrameT_contra, name: str - ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - - -class CompliantScalar( - CompliantExpr[FrameT_contra, SeriesT_co], Protocol[FrameT_contra, SeriesT_co] -): - _name: str - - @property - def name(self) -> str: - return self._name - - @classmethod - def from_python( - cls, - value: PythonLiteral, - name: str = "literal", - /, - *, - dtype: IntoDType | None, - version: Version, - ) -> Self: ... 
- def _with_evaluated(self, evaluated: Any, name: str) -> Self: - """Expr is based on a series having these via accessors, but a scalar needs to keep passing through.""" - cls = type(self) - obj = cls.__new__(cls) - obj._evaluated = evaluated - obj._name = name or self.name - obj._version = self.version - return obj - - def max(self, node: agg.Max, frame: FrameT_contra, name: str) -> Self: - """Returns self.""" - return self._with_evaluated(self._evaluated, name) - - def min(self, node: agg.Min, frame: FrameT_contra, name: str) -> Self: - """Returns self.""" - return self._with_evaluated(self._evaluated, name) - - def sum(self, node: agg.Sum, frame: FrameT_contra, name: str) -> Self: - """Returns self.""" - return self._with_evaluated(self._evaluated, name) - - def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self: - """Returns self.""" - return self._with_evaluated(self._evaluated, name) - - def last(self, node: agg.Last, frame: FrameT_contra, name: str) -> Self: - """Returns self.""" - return self._with_evaluated(self._evaluated, name) - - def _cast_float(self, node: ir.ExprIR, frame: FrameT_contra, name: str) -> Self: - """`polars` interpolates a single scalar as a float.""" - dtype = self.version.dtypes.Float64() - return self.cast(node.cast(dtype), frame, name) - - def mean(self, node: agg.Mean, frame: FrameT_contra, name: str) -> Self: - return self._cast_float(node.expr, frame, name) - - def median(self, node: agg.Median, frame: FrameT_contra, name: str) -> Self: - return self._cast_float(node.expr, frame, name) - - def quantile(self, node: agg.Quantile, frame: FrameT_contra, name: str) -> Self: - return self._cast_float(node.expr, frame, name) - - def n_unique(self, node: agg.NUnique, frame: FrameT_contra, name: str) -> Self: - """Returns 1.""" - ... - - def std(self, node: agg.Std, frame: FrameT_contra, name: str) -> Self: - """Returns null.""" - ... - - def var(self, node: agg.Var, frame: FrameT_contra, name: str) -> Self: - """Returns null.""" - ... - - def arg_min(self, node: agg.ArgMin, frame: FrameT_contra, name: str) -> Self: - """Returns 0.""" - ... - - def arg_max(self, node: agg.ArgMax, frame: FrameT_contra, name: str) -> Self: - """Returns 0.""" - ... - - def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: - """Returns 0 if null, else 1.""" - ... - - def len(self, node: agg.Len, frame: FrameT_contra, name: str) -> Self: - """Returns 1.""" - ... - - def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - - def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - - # NOTE: `Filter` behaves the same, (maybe) no need to override - - -class EagerExpr( - EagerBroadcast[SeriesT], - CompliantExpr[FrameT_contra, SeriesT], - Protocol[FrameT_contra, SeriesT], -): ... - - -class LazyExpr( - SupportsBroadcast[SeriesT, LengthT], - CompliantExpr[FrameT_contra, SeriesT], - Protocol[FrameT_contra, SeriesT, LengthT], -): ... - - -class EagerScalar( - CompliantScalar[FrameT_contra, SeriesT], - EagerExpr[FrameT_contra, SeriesT], - Protocol[FrameT_contra, SeriesT], -): - def __len__(self) -> int: - return 1 - - def to_python(self) -> PythonLiteral: ... - - -class LazyScalar( - CompliantScalar[FrameT_contra, SeriesT], - LazyExpr[FrameT_contra, SeriesT, LengthT], - Protocol[FrameT_contra, SeriesT, LengthT], -): ... 
- - -# NOTE: `mypy` is wrong -# error: Invariant type variable "ConcatT2" used in protocol where covariant one is expected [misc] -class Concat(Protocol[ConcatT1, ConcatT2]): # type: ignore[misc] - @overload - def concat(self, items: Iterable[ConcatT1], *, how: ConcatMethod) -> ConcatT1: ... - # Series only supports vertical publicly (like in polars) - @overload - def concat( - self, items: Iterable[ConcatT2], *, how: Literal["vertical"] - ) -> ConcatT2: ... - def concat( - self, items: Iterable[ConcatT1 | ConcatT2], *, how: ConcatMethod - ) -> ConcatT1 | ConcatT2: ... - - -class EagerConcat(Concat[ConcatT1, ConcatT2], Protocol[ConcatT1, ConcatT2]): # type: ignore[misc] - def _concat_diagonal(self, items: Iterable[ConcatT1], /) -> ConcatT1: ... - # Series can be used here to go from [Series, Series] -> DataFrame - # but that is only available privately - def _concat_horizontal(self, items: Iterable[ConcatT1 | ConcatT2], /) -> ConcatT1: ... - def _concat_vertical( - self, items: Iterable[ConcatT1 | ConcatT2], / - ) -> ConcatT1 | ConcatT2: ... - - -class CompliantNamespace(StoresVersion, Protocol[FrameT, ExprT_co, ScalarT_co]): - @property - def _frame(self) -> type[FrameT]: ... - @property - def _expr(self) -> type[ExprT_co]: ... - @property - def _scalar(self) -> type[ScalarT_co]: ... - def col(self, node: ir.Column, frame: FrameT, name: str) -> ExprT_co: ... - def lit( - self, node: ir.Literal[Any], frame: FrameT, name: str - ) -> ExprT_co | ScalarT_co: ... - def len(self, node: ir.Len, frame: FrameT, name: str) -> ScalarT_co: ... - def any_horizontal( - self, node: FunctionExpr[boolean.AnyHorizontal], frame: FrameT, name: str - ) -> ExprT_co | ScalarT_co: ... - def all_horizontal( - self, node: FunctionExpr[boolean.AllHorizontal], frame: FrameT, name: str - ) -> ExprT_co | ScalarT_co: ... - def sum_horizontal( - self, node: FunctionExpr[F.SumHorizontal], frame: FrameT, name: str - ) -> ExprT_co | ScalarT_co: ... - def min_horizontal( - self, node: FunctionExpr[F.MinHorizontal], frame: FrameT, name: str - ) -> ExprT_co | ScalarT_co: ... - def max_horizontal( - self, node: FunctionExpr[F.MaxHorizontal], frame: FrameT, name: str - ) -> ExprT_co | ScalarT_co: ... - def mean_horizontal( - self, node: FunctionExpr[F.MeanHorizontal], frame: FrameT, name: str - ) -> ExprT_co | ScalarT_co: ... - def concat_str( - self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str - ) -> ExprT_co | ScalarT_co: ... - def int_range( - self, node: ir.RangeExpr[IntRange], frame: FrameT, name: str - ) -> ExprT_co: ... - - -class EagerNamespace( - EagerConcat[EagerDataFrameT, SeriesT], - CompliantNamespace[EagerDataFrameT, EagerExprT_co, EagerScalarT_co], - Protocol[EagerDataFrameT, SeriesT, EagerExprT_co, EagerScalarT_co], -): - @property - def _series(self) -> type[SeriesT]: ... - @property - def _dataframe(self) -> type[EagerDataFrameT]: ... - @property - def _frame(self) -> type[EagerDataFrameT]: - return self._dataframe - - def _is_series(self, obj: Any) -> TypeIs[SeriesT]: - return isinstance(obj, self._series) - - def _is_dataframe(self, obj: Any) -> TypeIs[EagerDataFrameT]: - return isinstance(obj, self._dataframe) - - @overload - def lit( - self, node: ir.Literal[NonNestedLiteral], frame: EagerDataFrameT, name: str - ) -> EagerScalarT_co: ... - @overload - def lit( - self, node: ir.Literal[Series[Any]], frame: EagerDataFrameT, name: str - ) -> EagerExprT_co: ... - def lit( - self, node: ir.Literal[Any], frame: EagerDataFrameT, name: str - ) -> EagerExprT_co | EagerScalarT_co: ... 
- def len(self, node: ir.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: - return self._scalar.from_python( - len(frame), name or node.name, dtype=None, version=frame.version - ) - - -class LazyNamespace( - Concat[FrameT, FrameT], - CompliantNamespace[FrameT, LazyExprT_co, LazyScalarT_co], - Protocol[FrameT, LazyExprT_co, LazyScalarT_co], -): - @property - def _lazyframe(self) -> type[FrameT]: ... - @property - def _frame(self) -> type[FrameT]: - return self._lazyframe - - -class CompliantBaseFrame(StoresVersion, Protocol[ColumnT_co, NativeFrameT]): - _native: NativeFrameT - - def __narwhals_namespace__(self) -> Any: ... - @property - def _group_by(self) -> type[CompliantGroupBy[Self]]: ... - @property - def native(self) -> NativeFrameT: - return self._native - - @property - def columns(self) -> list[str]: ... - def to_narwhals(self) -> BaseFrame[NativeFrameT]: ... - @classmethod - def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: - obj = cls.__new__(cls) - obj._native = native - obj._version = version - return obj - - def _with_native(self, native: NativeFrameT) -> Self: - return self.from_native(native, self.version) - - @property - def schema(self) -> Mapping[str, DType]: ... - def _evaluate_irs( - self, nodes: Iterable[NamedIR[ir.ExprIR]], / - ) -> Iterator[ColumnT_co]: ... - def select(self, irs: Seq[NamedIR]) -> Self: ... - def select_names(self, *column_names: str) -> Self: ... - def with_columns(self, irs: Seq[NamedIR]) -> Self: ... - def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... - def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ... - def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... - - -class CompliantDataFrame( - CompliantBaseFrame[SeriesT, NativeDataFrameT], - Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], -): - @property - def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... - @property - def _grouper(self) -> type[Grouped]: - return Grouped - - @classmethod - def from_dict( - cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None - ) -> Self: ... - def group_by_agg( - self, by: OneOrIterable[IntoExpr], aggs: OneOrIterable[IntoExpr], / - ) -> Self: - """Compliant-level `group_by(by).agg(agg)`, allows `Expr`.""" - return self._grouper.by(by).agg(aggs).resolve(self).evaluate(self) - - def group_by_names(self, names: Seq[str], /) -> DataFrameGroupBy[Self]: - """Compliant-level `group_by`, allowing only `str` keys.""" - return self._group_by.by_names(self, names) - - def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Self]: - """Narwhals-level resolved `group_by`. - - `keys`, `aggs` are already parsed and projections planned. - """ - return self._group_by.from_resolver(self, resolver) - - def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ... - @overload - def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... - @overload - def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... - @overload - def to_dict( - self, *, as_series: bool - ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... - def to_dict( - self, *, as_series: bool - ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... - def __len__(self) -> int: ... - def with_row_index(self, name: str) -> Self: ... - def row(self, index: int) -> tuple[Any, ...]: ... 
- - -class EagerDataFrame( - CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], - Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], -): - @property - def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... - def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... - def select(self, irs: Seq[NamedIR]) -> Self: - return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) - - def with_columns(self, irs: Seq[NamedIR]) -> Self: - return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) - - -class CompliantSeries(StoresVersion, Protocol[NativeSeriesT]): - _native: NativeSeriesT - _name: str - - def __narwhals_series__(self) -> Self: - return self - - @property - def native(self) -> NativeSeriesT: - return self._native - - @property - def dtype(self) -> DType: ... - @property - def name(self) -> str: - return self._name - - def to_narwhals(self) -> Series[NativeSeriesT]: - from narwhals._plan.series import Series - - return Series[NativeSeriesT]._from_compliant(self) - - @classmethod - def from_native( - cls, native: NativeSeriesT, name: str = "", /, *, version: Version = Version.MAIN - ) -> Self: - obj = cls.__new__(cls) - obj._native = native - obj._name = name - obj._version = version - return obj - - @classmethod - def from_numpy( - cls, data: Into1DArray, name: str = "", /, *, version: Version = Version.MAIN - ) -> Self: ... - @classmethod - def from_iterable( - cls, - data: Iterable[Any], - *, - version: Version, - name: str = "", - dtype: IntoDType | None = None, - ) -> Self: ... - def _with_native(self, native: NativeSeriesT) -> Self: - return self.from_native(native, self.name, version=self.version) - - def alias(self, name: str) -> Self: - return self.from_native(self.native, name, version=self.version) - - def cast(self, dtype: IntoDType) -> Self: ... - def __len__(self) -> int: - return len(self.native) - - def to_list(self) -> list[Any]: ... - def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... - - -class CompliantGroupBy(Protocol[FrameT_co]): - @property - def compliant(self) -> FrameT_co: ... - def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... - - -class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): - _keys: Seq[NamedIR] - _key_names: Seq[str] - - @classmethod - def from_resolver( - cls, df: DataFrameT, resolver: GroupByResolver, / - ) -> DataFrameGroupBy[DataFrameT]: ... - @classmethod - def by_names( - cls, df: DataFrameT, names: Seq[str], / - ) -> DataFrameGroupBy[DataFrameT]: ... - def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... 
- @property - def keys(self) -> Seq[NamedIR]: - return self._keys - - @property - def key_names(self) -> Seq[str]: - if names := self._key_names: - return names - msg = "at least one key is required in a group_by operation" - raise ComputeError(msg) - - -class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]): - _df: EagerDataFrameT - _key_names: Seq[str] - _key_names_original: Seq[str] - _column_names_original: Seq[str] - - @classmethod - def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: - obj = cls.__new__(cls) - obj._df = df - obj._keys = () - obj._key_names = names - obj._key_names_original = () - obj._column_names_original = tuple(df.columns) - return obj - - @classmethod - def from_resolver( - cls, df: EagerDataFrameT, resolver: GroupByResolver, / - ) -> EagerDataFrameGroupBy[EagerDataFrameT]: - key_names = resolver.key_names - if not resolver.requires_projection(): - df = df.drop_nulls(key_names) if resolver._drop_null_keys else df - return cls.by_names(df, key_names) - obj = cls.__new__(cls) - unique_names = temp.column_names(chain(key_names, df.columns)) - safe_keys = tuple( - replace(key, name=name) for key, name in zip(resolver.keys, unique_names) - ) - obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys)) - obj._keys = safe_keys - obj._key_names = tuple(e.name for e in safe_keys) - obj._key_names_original = key_names - obj._column_names_original = resolver._schema_in.names - return obj - - -class Grouper(Protocol[ResolverT_co]): - """`GroupBy` helper for collecting and forwarding `Expr`s for projection. - - - Uses `Expr` everywhere (no need to duplicate layers) - - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) - """ - - _keys: Seq[ExprIR] - _aggs: Seq[ExprIR] - _drop_null_keys: bool - - @classmethod - def by(cls, *by: OneOrIterable[IntoExpr]) -> Self: - obj = cls.__new__(cls) - obj._keys = parse_into_seq_of_expr_ir(*by) - return obj - - def agg(self, *aggs: OneOrIterable[IntoExpr]) -> Self: - self._aggs = parse_into_seq_of_expr_ir(*aggs) - return self - - @property - def _resolver(self) -> type[ResolverT_co]: ... - - def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: - """Project keys and aggs in `context`, expanding all `Expr` -> `NamedIR`.""" - return self._resolver.from_grouper(self, context) - - -class GroupByResolver: - """Narwhals-level `GroupBy` resolver.""" - - _schema_in: FrozenSchema - _keys: Seq[NamedIR] - _aggs: Seq[NamedIR] - _key_names: Seq[str] - _schema: FrozenSchema - _drop_null_keys: bool - - @property - def keys(self) -> Seq[NamedIR]: - return self._keys - - @property - def aggs(self) -> Seq[NamedIR]: - return self._aggs - - @property - def key_names(self) -> Seq[str]: - if names := self._key_names: - return names - if keys := self.keys: - return tuple(e.name for e in keys) - msg = "at least one key is required in a group_by operation" - raise ComputeError(msg) - - @property - def schema(self) -> FrozenSchema: - return self._schema - - def evaluate(self, frame: DataFrameT) -> DataFrameT: - """Perform the `group_by` on `frame`.""" - return frame.group_by_resolver(self).agg(self.aggs) - - @classmethod - def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: - """Loosely based on [`resolve_group_by`]. 
- - [`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 - """ - obj = cls.__new__(cls) - keys, schema_in = prepare_projection(grouper._keys, schema=context) - obj._keys, obj._schema_in = keys, schema_in - obj._key_names = tuple(e.name for e in keys) - obj._aggs, _ = prepare_projection(grouper._aggs, obj.key_names, schema=schema_in) - obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs)) - obj._drop_null_keys = grouper._drop_null_keys - return obj - - def requires_projection(self, *, allow_aliasing: bool = False) -> bool: - """Return True is group keys contain anything that is not a column selection. - - Notes: - If False is returned, we can just use the resolved key names as a fast-path to group. - - Arguments: - allow_aliasing: If False (default), any aliasing is not considered to be column selection. - """ - if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys): - if self._drop_null_keys: - msg = "drop_null_keys cannot be True when keys contains Expr or Series" - raise NotImplementedError(msg) - return True - return False - - -class Resolved(GroupByResolver): - """Compliant-level `GroupBy` resolver.""" - - _drop_null_keys: bool = False - - -class Grouped(Grouper[Resolved]): - """Compliant-level `GroupBy` builder.""" - - _keys: Seq[ExprIR] - _aggs: Seq[ExprIR] - _drop_null_keys: bool = False - - @property - def _resolver(self) -> type[Resolved]: - return Resolved diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 1ab9366ea3..5108d8ca4e 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -2,23 +2,19 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic -from narwhals._plan.typing import NativeSeriesT +from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co from narwhals._utils import Version from narwhals.dependencies import is_pyarrow_chunked_array if TYPE_CHECKING: from collections.abc import Iterator - import pyarrow as pa - from typing_extensions import Self - - from narwhals._plan.protocols import CompliantSeries + from narwhals._plan.compliant.series import CompliantSeries from narwhals.dtypes import DType - from narwhals.typing import NativeSeries -class Series(Generic[NativeSeriesT]): - _compliant: CompliantSeries[NativeSeriesT] +class Series(Generic[NativeSeriesT_co]): + _compliant: CompliantSeries[NativeSeriesT_co] _version: ClassVar[Version] = Version.MAIN @property @@ -33,27 +29,21 @@ def dtype(self) -> DType: def name(self) -> str: return self._compliant.name - # NOTE: Gave up on trying to get typing working for now + def __init__(self, compliant: CompliantSeries[NativeSeriesT_co], /) -> None: + self._compliant = compliant + @classmethod def from_native( - cls, native: NativeSeries, name: str = "", / - ) -> Series[pa.ChunkedArray[Any]]: + cls: type[Series[Any]], native: NativeSeriesT, name: str = "", / + ) -> Series[NativeSeriesT]: if is_pyarrow_chunked_array(native): from narwhals._plan.arrow.series import ArrowSeries - return ArrowSeries.from_native( - native, name, version=cls._version - ).to_narwhals() + return cls(ArrowSeries.from_native(native, name, version=cls._version)) raise NotImplementedError(type(native)) - @classmethod - def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self: - obj = cls.__new__(cls) - obj._compliant = compliant - return obj - - def to_native(self) -> NativeSeriesT: + def to_native(self) -> NativeSeriesT_co: 
return self._compliant.native def to_list(self) -> list[Any]: @@ -63,5 +53,5 @@ def __iter__(self) -> Iterator[Any]: yield from self.to_native() -class SeriesV1(Series[NativeSeriesT]): +class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 2a734488a6..843e7f4b54 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -87,10 +87,22 @@ "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" ) NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries") +NativeSeriesT_co = TypeVar( + "NativeSeriesT_co", bound="NativeSeries", covariant=True, default="NativeSeries" +) NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame", default="NativeFrame") +NativeFrameT_co = TypeVar( + "NativeFrameT_co", bound="NativeFrame", covariant=True, default="NativeFrame" +) NativeDataFrameT = TypeVar( "NativeDataFrameT", bound="NativeDataFrame", default="NativeDataFrame" ) +NativeDataFrameT_co = TypeVar( + "NativeDataFrameT_co", + bound="NativeDataFrame", + covariant=True, + default="NativeDataFrame", +) LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | Series[t.Any]", default=t.Any) MapIR: TypeAlias = "t.Callable[[ExprIR], ExprIR]" """A function to apply to all nodes in this tree.""" diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 7b7113e450..905f97ac40 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -540,6 +540,7 @@ def test_row_is_py_literal( if TYPE_CHECKING: + from typing_extensions import assert_type def test_protocol_expr() -> None: """Static test for all members implemented. @@ -554,3 +555,14 @@ def test_protocol_expr() -> None: scalar = ArrowScalar() assert expr assert scalar + + def test_dataframe_from_native_overloads() -> None: + """Ensure we can reveal the `NativeSeries` **without** a dependency.""" + data: dict[str, Any] = {} + native_good = pa.table(data) + result_good = nwp.DataFrame.from_native(native_good) + assert_type(result_good, "nwp.DataFrame[pa.Table, pa.ChunkedArray[Any]]") + + native_bad = native_good.to_batches()[0] + nwp.DataFrame.from_native(native_bad) # type: ignore[call-overload] + assert_type(native_bad, "pa.RecordBatch") From bde22aca49d8a0979cbfc5b7ea6d2802db114815 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 12 Oct 2025 16:15:30 +0000 Subject: [PATCH 363/368] feat(expr-ir): Acero `order_by`, `hashjoin` , `DataFrame.{filter,join}`, `Expr.is_{first,last}_distinct` (#3173) --- narwhals/_plan/_expr_ir.py | 4 + narwhals/_plan/_guards.py | 6 +- narwhals/_plan/_parse.py | 95 ++++++- narwhals/_plan/arrow/acero.py | 235 ++++++++++++++-- narwhals/_plan/arrow/dataframe.py | 37 ++- narwhals/_plan/arrow/expr.py | 45 ++- narwhals/_plan/arrow/functions.py | 46 +++- narwhals/_plan/arrow/series.py | 20 +- narwhals/_plan/arrow/typing.py | 11 + narwhals/_plan/common.py | 36 ++- narwhals/_plan/compliant/dataframe.py | 28 +- narwhals/_plan/compliant/expr.py | 16 +- narwhals/_plan/compliant/group_by.py | 2 +- narwhals/_plan/compliant/scalar.py | 13 +- narwhals/_plan/compliant/series.py | 10 +- narwhals/_plan/dataframe.py | 150 +++++++++- narwhals/_plan/exceptions.py | 2 +- narwhals/_plan/expr.py | 31 +-- narwhals/_plan/expressions/__init__.py | 6 +- narwhals/_plan/expressions/aggregation.py | 12 + narwhals/_plan/expressions/expr.py | 3 +- narwhals/_plan/expressions/selectors.py | 3 +- narwhals/_plan/options.py | 
5 +- narwhals/_plan/series.py | 29 +- narwhals/_plan/typing.py | 15 +- tests/plan/compliant_test.py | 26 +- tests/plan/frame_filter_test.py | 176 ++++++++++++ tests/plan/group_by_test.py | 9 +- tests/plan/is_first_last_distinct_test.py | 134 +++++++++ tests/plan/join_test.py | 321 ++++++++++++++++++++++ tests/plan/utils.py | 25 +- 31 files changed, 1411 insertions(+), 140 deletions(-) create mode 100644 tests/plan/frame_filter_test.py create mode 100644 tests/plan/is_first_last_distinct_test.py create mode 100644 tests/plan/join_test.py diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 862fb0c44a..f0c9a08548 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -304,3 +304,7 @@ def is_column(self, *, allow_aliasing: bool = False) -> bool: ir = self.expr return isinstance(ir, Column) and ((self.name == ir.name) or allow_aliasing) + + +def named_ir(name: str, expr: ExprIRT, /) -> NamedIR[ExprIRT]: + return NamedIR(expr=expr, name=name) diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index ed2762d993..780070f038 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -15,7 +15,7 @@ from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.expr import Expr from narwhals._plan.series import Series - from narwhals._plan.typing import NativeSeriesT, Seq + from narwhals._plan.typing import IntoExprColumn, NativeSeriesT, Seq from narwhals.typing import NonNestedLiteral T = TypeVar("T") @@ -67,6 +67,10 @@ def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]] return isinstance(obj, _series().Series) +def is_into_expr_column(obj: Any) -> TypeIs[IntoExprColumn]: + return isinstance(obj, (str, _expr().Expr, _series().Series)) + + def is_compliant_series( obj: CompliantSeries[NativeSeriesT] | Any, ) -> TypeIs[CompliantSeries[NativeSeriesT]]: diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py index 651166ebee..c2a5cc7c2f 100644 --- a/narwhals/_plan/_parse.py +++ b/narwhals/_plan/_parse.py @@ -6,13 +6,14 @@ from itertools import chain from typing import TYPE_CHECKING -from narwhals._plan._guards import is_expr, is_iterable_reject +from narwhals._plan._guards import is_expr, is_into_expr_column, is_iterable_reject from narwhals._plan.exceptions import ( invalid_into_expr_error, is_iterable_pandas_error, is_iterable_polars_error, ) from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from collections.abc import Iterator @@ -22,7 +23,13 @@ from typing_extensions import TypeAlias, TypeIs from narwhals._plan.expressions import ExprIR - from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq + from narwhals._plan.typing import ( + IntoExpr, + IntoExprColumn, + OneOrIterable, + PartialSeries, + Seq, + ) from narwhals.typing import IntoDType T = TypeVar("T") @@ -85,15 +92,33 @@ def parse_into_expr_ir( - input: IntoExpr, *, str_as_lit: bool = False, dtype: IntoDType | None = None + input: IntoExpr | list[Any], + *, + str_as_lit: bool = False, + list_as_series: PartialSeries | None = None, + dtype: IntoDType | None = None, ) -> ExprIR: - """Parse a single input into an `ExprIR` node.""" + """Parse a single input into an `ExprIR` node. + + Arguments: + input: The input to be parsed as an expression. + str_as_lit: Interpret string input as a string literal. If set to `False` (default), + strings are parsed as column names. 
+ list_as_series: Interpret list input as a Series literal, using the provided constructor. + If set to `None` (default), lists will raise when passed to `lit`. + dtype: If the input is expected to resolve to a literal with a known dtype, pass + this to the `lit` constructor. + """ from narwhals._plan import col, lit if is_expr(input): expr = input elif isinstance(input, str) and not str_as_lit: expr = col(input) + elif isinstance(input, list): + if list_as_series is None: + raise TypeError(input) + expr = lit(list_as_series(input)) else: expr = lit(input, dtype=dtype) return expr._ir @@ -105,50 +130,90 @@ def parse_into_seq_of_expr_ir( **named_inputs: IntoExpr, ) -> Seq[ExprIR]: """Parse variadic inputs into a flat sequence of `ExprIR` nodes.""" - return tuple(_parse_into_iter_expr_ir(first_input, *more_inputs, **named_inputs)) + return tuple( + _parse_into_iter_expr_ir( + first_input, *more_inputs, _list_as_series=None, **named_inputs + ) + ) def parse_predicates_constraints_into_expr_ir( - first_predicate: OneOrIterable[IntoExprColumn] = (), - *more_predicates: IntoExprColumn | _RaisesInvalidIntoExprError, + first_predicate: OneOrIterable[IntoExprColumn] | list[bool] = (), + *more_predicates: IntoExprColumn | list[bool] | _RaisesInvalidIntoExprError, + _list_as_series: PartialSeries | None = None, **constraints: IntoExpr, ) -> ExprIR: """Parse variadic predicates and constraints into an `ExprIR` node. The result is an AND-reduction of all inputs. """ - all_predicates = _parse_into_iter_expr_ir(first_predicate, *more_predicates) + all_predicates = _parse_into_iter_expr_ir( + first_predicate, *more_predicates, _list_as_series=_list_as_series + ) if constraints: chained = chain(all_predicates, _parse_constraints(constraints)) return _combine_predicates(chained) return _combine_predicates(all_predicates) +def parse_sort_by_into_seq_of_expr_ir( + by: OneOrIterable[IntoExprColumn] = (), *more_by: IntoExprColumn +) -> Seq[ExprIR]: + """Parse `DataFrame.sort` and `Expr.sort_by` keys into a flat sequence of `ExprIR` nodes.""" + return tuple(_parse_sort_by_into_iter_expr_ir(by, more_by)) + + +# TODO @dangotbanned: Review the rejection predicate +# It doesn't cover all length-changing expressions, only aggregations/literals +def _parse_sort_by_into_iter_expr_ir( + by: OneOrIterable[IntoExprColumn], more_by: Iterable[IntoExprColumn] +) -> Iterator[ExprIR]: + for e in _parse_into_iter_expr_ir(by, *more_by): + if e.is_scalar: + msg = f"All expressions sort keys must preserve length, but got:\n{e!r}" + raise InvalidOperationError(msg) + yield e + + def _parse_into_iter_expr_ir( - first_input: OneOrIterable[IntoExpr], *more_inputs: IntoExpr, **named_inputs: IntoExpr + first_input: OneOrIterable[IntoExpr], + *more_inputs: IntoExpr | list[Any], + _list_as_series: PartialSeries | None = None, + **named_inputs: IntoExpr, ) -> Iterator[ExprIR]: if not _is_empty_sequence(first_input): # NOTE: These need to be separated to introduce an intersection type # Otherwise, `str | bytes` always passes through typing if _is_iterable(first_input) and not is_iterable_reject(first_input): - if more_inputs: + if more_inputs and ( + _list_as_series is None or not isinstance(first_input, list) + ): raise invalid_into_expr_error(first_input, more_inputs, named_inputs) + # NOTE: Ensures `first_input = [False, True, True] -> lit(Series([False, True, True]))` + elif ( + _list_as_series is not None + and isinstance(first_input, list) + and not is_into_expr_column(first_input[0]) + ): + yield parse_into_expr_ir(first_input, 
list_as_series=_list_as_series) else: - yield from _parse_positional_inputs(first_input) + yield from _parse_positional_inputs(first_input, _list_as_series) else: - yield parse_into_expr_ir(first_input) + yield parse_into_expr_ir(first_input, list_as_series=_list_as_series) else: # NOTE: Passthrough case for no inputs - but gets skipped when calling next yield from () if more_inputs: - yield from _parse_positional_inputs(more_inputs) + yield from _parse_positional_inputs(more_inputs, _list_as_series) if named_inputs: yield from _parse_named_inputs(named_inputs) -def _parse_positional_inputs(inputs: Iterable[IntoExpr], /) -> Iterator[ExprIR]: +def _parse_positional_inputs( + inputs: Iterable[IntoExpr | list[Any]], /, list_as_series: PartialSeries | None = None +) -> Iterator[ExprIR]: for into in inputs: - yield parse_into_expr_ir(into) + yield parse_into_expr_ir(into, list_as_series=list_as_series) def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR]: diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 768248e312..f99fad5289 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -25,26 +25,41 @@ import pyarrow.compute as pc # ignore-banned-import from pyarrow.acero import Declaration as Decl -from narwhals._plan.typing import OneOrSeq -from narwhals.typing import SingleColSelector +from narwhals._plan.common import ensure_list_str, flatten_hash_safe, temp +from narwhals._plan.options import SortMultipleOptions +from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq +from narwhals._utils import check_column_names_are_unique +from narwhals.typing import JoinStrategy, SingleColSelector if TYPE_CHECKING: - from collections.abc import Callable, Collection, Iterable, Iterator + from collections.abc import ( + Callable, + Collection, + Iterable, + Iterator, + Mapping, + Sequence, + ) - from typing_extensions import TypeAlias + from typing_extensions import TypeAlias, TypeIs from narwhals._arrow.typing import ( # type: ignore[attr-defined] AggregateOptions as _AggregateOptions, Aggregation as _Aggregation, ) from narwhals._plan.arrow.group_by import AggSpec - from narwhals._plan.arrow.typing import NullPlacement + from narwhals._plan.arrow.typing import ( + ArrowAny, + JoinTypeSubset, + NullPlacement, + ScalarAny, + ) from narwhals._plan.typing import OneOrIterable, Order, Seq from narwhals.typing import NonNestedLiteral Incomplete: TypeAlias = Any Expr: TypeAlias = pc.Expression -IntoExpr: TypeAlias = "Expr | NonNestedLiteral" +IntoExpr: TypeAlias = "Expr | NonNestedLiteral | ScalarAny" Field: TypeAlias = Union[Expr, SingleColSelector] """Anything that passes as a single item in [`_compute._ensure_field_ref`]. @@ -57,12 +72,28 @@ Opts: TypeAlias = "AggregateOptions | None" OutputName: TypeAlias = str +IntoDecl: TypeAlias = Union[pa.Table, Decl] +"""An in-memory table, or a plan that began with one.""" + _THREAD_UNSAFE: Final = frozenset[Aggregation]( ("hash_first", "hash_last", "first", "last") ) col = pc.field -lit = cast("Callable[[NonNestedLiteral], Expr]", pc.scalar) -"""Alias for `pyarrow.compute.scalar`.""" +lit = cast("Callable[[NonNestedLiteral | ScalarAny], Expr]", pc.scalar) +"""Alias for `pyarrow.compute.scalar`. + +Extends the signature from `bool | float | str`. 
+ +See https://github.com/apache/arrow/pull/47609#discussion_r2392499842 +""" + +_HOW_JOIN: Mapping[JoinStrategy, JoinTypeSubset] = { + "inner": "inner", + "left": "left outer", + "full": "full outer", + "anti": "left anti", + "semi": "left semi", +} # NOTE: ATOW there are 304 valid function names, 46 can be used for some kind of agg @@ -72,8 +103,17 @@ def can_thread(function_name: str, /) -> bool: return function_name not in _THREAD_UNSAFE +def cols_iter(names: Iterable[str], /) -> Iterator[Expr]: + for name in names: + yield col(name) + + +def _is_expr(obj: Any) -> TypeIs[pc.Expression]: + return isinstance(obj, pc.Expression) + + def _parse_into_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr: - if isinstance(into, pc.Expression): + if _is_expr(into): return into if isinstance(into, str) and not str_as_lit: return col(into) @@ -99,6 +139,10 @@ def _parse_all_horizontal(predicates: Seq[Expr], constraints: dict[str, Any], /) return reduce(operator.and_, chain(predicates, it)) +def _into_decl(source: IntoDecl, /) -> Decl: + return source if not isinstance(source, pa.Table) else table_source(source) + + def table_source(native: pa.Table, /) -> Decl: """Start building a logical plan, using `native` as the source table. @@ -177,6 +221,31 @@ def project(**named_exprs: IntoExpr) -> Decl: return _project(names=named_exprs.keys(), exprs=exprs) +def _add_column(native: pa.Table, index: int, name: str, values: IntoExpr) -> Decl: + column = values if _is_expr(values) else lit(values) + schema = native.schema + schema_names = schema.names + if index == 0: + names: Sequence[str] = (name, *schema_names) + exprs = (column, *cols_iter(schema_names)) + elif index == native.num_columns: + names = (*schema_names, name) + exprs = (*cols_iter(schema_names), column) + else: + schema_names.insert(index, name) + names = schema_names + exprs = tuple(_parse_into_iter_expr(nm if nm != name else column for nm in names)) + return declare(table_source(native), _project(exprs, names)) + + +def append_column(native: pa.Table, name: str, values: IntoExpr) -> Decl: + return _add_column(native, native.num_columns, name, values) + + +def prepend_column(native: pa.Table, name: str, values: IntoExpr) -> Decl: + return _add_column(native, 0, name, values) + + def _order_by( sort_keys: Iterable[tuple[str, Order]] = (), *, @@ -189,23 +258,80 @@ def _order_by( return Decl("order_by", pac.OrderByNodeOptions(keys, null_placement=null_placement)) -# TODO @dangotbanned: Utilize `SortMultipleOptions.to_arrow_acero` -def sort_by(*args: Any, **kwds: Any) -> Decl: - msg = "Should convert from polars args -> use `_order_by" - raise NotImplementedError(msg) +def sort_by( + by: OneOrIterable[str], + *more_by: str, + descending: OneOrIterable[bool] = False, + nulls_last: bool = False, +) -> Decl: + return SortMultipleOptions.parse( + descending=descending, nulls_last=nulls_last + ).to_arrow_acero(tuple(flatten_hash_safe((by, more_by)))) + + +def _join_options( + how: NonCrossJoinStrategy, + left_on: OneOrIterable[str], + right_on: OneOrIterable[str], + suffix: str = "_right", + left_names: Iterable[str] | None = None, + right_names: Iterable[str] = (), + *, + coalesce_keys: bool = True, +) -> pac.HashJoinNodeOptions: + right_on = ensure_list_str(right_on) + rhs_names: Iterable[str] | None = None + # polars full join does not coalesce keys + if not (coalesce_keys and (how != "full")): + lhs_names = None + else: + lhs_names = left_names + if how in {"inner", "left"}: + rhs_names = (name for name in right_names if name not in 
right_on) + tp: Incomplete = pac.HashJoinNodeOptions + return tp( # type: ignore[no-any-return] + _HOW_JOIN[how], + left_keys=ensure_list_str(left_on), + right_keys=right_on, + left_output=lhs_names, + right_output=rhs_names, + output_suffix_for_right=suffix, + ) + + +def _hashjoin( + left: IntoDecl, right: IntoDecl, /, options: pac.HashJoinNodeOptions +) -> Decl: + return Decl("hashjoin", options, [_into_decl(left), _into_decl(right)]) -def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table: +def declare(*declarations: Decl) -> Decl: + """Compose one or more `Declaration` nodes for execution as a pipeline.""" + if len(declarations) == 1: + return declarations[0] + # NOTE: stubs + docs say `list`, but impl allows any iterable + decls: Incomplete = declarations + return Decl.from_sequence(decls) + + +def collect( + *declarations: Decl, + use_threads: bool = True, + ensure_unique_column_names: bool = False, +) -> pa.Table: """Compose and evaluate a logical plan. Arguments: *declarations: One or more `Declaration` nodes to execute as a pipeline. **The first node must be a `table_source`**. use_threads: Pass `False` if `declarations` contains any order-dependent aggregation(s). + ensure_unique_column_names: Pass `True` if `declarations` adds generated column names that were + not explicitly defined on the `narwhals`-side. E.g. `join(suffix=...)`. """ - # NOTE: stubs + docs say `list`, but impl allows any iterable - decls: Incomplete = declarations - return Decl.from_sequence(decls).to_table(use_threads=use_threads) + result = declare(*declarations).to_table(use_threads=use_threads) + if ensure_unique_column_names: + check_column_names_are_unique(result.column_names) + return result def group_by_table( @@ -251,3 +377,78 @@ def select_names_table( native: pa.Table, column_names: OneOrIterable[str], *more_names: str ) -> pa.Table: return collect(table_source(native), select_names(column_names, *more_names)) + + +def join_tables( + left: pa.Table, + right: pa.Table, + how: NonCrossJoinStrategy, + left_on: OneOrIterable[str], + right_on: OneOrIterable[str], + suffix: str = "_right", + *, + coalesce_keys: bool = True, +) -> pa.Table: + """Join two tables. 
+ + Based on: + - [`pyarrow.Table.join`] + - [`pyarrow.acero._perform_join`] + - [`narwhals._arrow.dataframe.DataFrame.join`] + + [`pyarrow.Table.join`]: https://github.com/apache/arrow/blob/f7320c9a40082639f9e0cf8b3075286e3fc6c0b9/python/pyarrow/table.pxi#L5764-L5772 + [`pyarrow.acero._perform_join`]: https://github.com/apache/arrow/blob/f7320c9a40082639f9e0cf8b3075286e3fc6c0b9/python/pyarrow/acero.py#L82-L260 + [`narwhals._arrow.dataframe.DataFrame.join`]: https://github.com/narwhals-dev/narwhals/blob/f4787d3f9e027306cb1786db7b471f63b393b8d1/narwhals/_arrow/dataframe.py#L393-L433 + """ + left_on = left_on or () + right_on = right_on or left_on + opts = _join_options( + how, + left_on, + right_on, + suffix, + left.schema.names, + right.schema.names, + coalesce_keys=coalesce_keys, + ) + return collect(_hashjoin(left, right, opts), ensure_unique_column_names=True) + + +def join_cross_tables( + left: pa.Table, right: pa.Table, suffix: str = "_right", *, coalesce_keys: bool = True +) -> pa.Table: + """Perform a cross join between tables.""" + left_names, right_names = left.column_names, right.column_names + on = temp.column_name(set().union(left_names, right_names)) + opts = _join_options( + how="inner", + left_on=on, + right_on=on, + suffix=suffix, + left_names=[on, *left_names], + right_names=right_names, + coalesce_keys=coalesce_keys, + ) + left_, right_ = prepend_column(left, on, 0), prepend_column(right, on, 0) + decl = _hashjoin(left_, right_, opts) + return collect(decl, ensure_unique_column_names=True).remove_column(0) + + +def _add_column_table( + native: pa.Table, index: int, name: str, values: IntoExpr | ArrowAny +) -> pa.Table: + if isinstance(values, (pa.ChunkedArray, pa.Array)): + return native.add_column(index, name, values) + return _add_column(native, index, name, values).to_table() + + +def append_column_table( + native: pa.Table, name: str, values: IntoExpr | ArrowAny +) -> pa.Table: + return _add_column_table(native, native.num_columns, name, values) + + +def prepend_column_table( + native: pa.Table, name: str, values: IntoExpr | ArrowAny +) -> pa.Table: + return _add_column_table(native, 0, name, values) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 668fd5330c..15b9dc80c0 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -9,14 +9,15 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype -from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow import acero, functions as fn +from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.compliant.dataframe import EagerDataFrame from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq -from narwhals._utils import Version, parse_columns_to_drop +from narwhals._utils import Implementation, Version, parse_columns_to_drop from narwhals.schema import Schema if TYPE_CHECKING: @@ -24,17 +25,17 @@ from typing_extensions import Self - from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401 - from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar + from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.expressions import ExprIR, 
NamedIR from narwhals._plan.options import SortMultipleOptions - from narwhals._plan.typing import Seq + from narwhals._plan.typing import NonCrossJoinStrategy, Seq from narwhals.dtypes import DType from narwhals.typing import IntoSchema class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]): + implementation = Implementation.PYARROW _native: pa.Table _version: Version @@ -144,3 +145,29 @@ def select_names(self, *column_names: str) -> Self: def row(self, index: int) -> tuple[Any, ...]: row = self.native.slice(index, 1) return tuple(chain.from_iterable(row.to_pydict().values())) + + def join( + self, + other: Self, + *, + how: NonCrossJoinStrategy, + left_on: Sequence[str], + right_on: Sequence[str], + suffix: str = "_right", + ) -> Self: + left, right = self.native, other.native + result = acero.join_tables(left, right, how, left_on, right_on, suffix=suffix) + return self._with_native(result) + + def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: + result = acero.join_cross_tables(self.native, other.native, suffix=suffix) + return self._with_native(result) + + def filter(self, predicate: NamedIR) -> Self: + mask: pc.Expression | ChunkedArrayAny + resolved = Expr.from_named_ir(predicate, self) + if isinstance(resolved, Expr): + mask = resolved.broadcast(len(self)).native + else: + mask = acero.lit(resolved.native) + return self._with_native(self.native.filter(mask)) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 9dc25f05cd..fb2bad1479 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -6,21 +6,17 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import narwhals_to_native_dtype +from narwhals._plan import expressions as ir from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co +from narwhals._plan.common import temp from narwhals._plan.compliant.column import ExprDispatch from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import NamedIR -from narwhals._utils import ( - Implementation, - Version, - _StoresNative, - generate_temporary_column_name, - not_implemented, -) +from narwhals._utils import Implementation, Version, _StoresNative, not_implemented from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: @@ -29,7 +25,6 @@ from typing_extensions import Self, TypeAlias from narwhals._arrow.typing import ChunkedArrayAny, Incomplete - from narwhals._plan import expressions as ir from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.expressions.aggregation import ( @@ -53,6 +48,8 @@ All, IsBetween, IsFinite, + IsFirstDistinct, + IsLastDistinct, IsNan, IsNull, Not, @@ -198,6 +195,9 @@ def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Sel return ArrowScalar.from_native(result, name, version=self.version) return self.from_native(result, name or self.name, self.version) + # NOTE: I'm not sure what I meant by + # > "isn't natively supported on `ChunkedArray`" + # Was that supposed to say "is only supported on `ChunkedArray`"? 
def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series: """Use instead of `_dispatch` *iff* an operation isn't natively supported on `ChunkedArray`. @@ -231,10 +231,8 @@ def sort(self, node: ir.Sort, frame: Frame, name: str) -> Expr: def sort_by(self, node: ir.SortBy, frame: Frame, name: str) -> Expr: series = self._dispatch_expr(node.expr, frame, name) - by = ( - self._dispatch_expr(e, frame, f"_{idx}") - for idx, e in enumerate(node.by) - ) + it_names = temp.column_names(frame) + by = (self._dispatch_expr(e, frame, nm) for e, nm in zip(node.by, it_names)) df = namespace(self)._concat_horizontal((series, *by)) names = df.columns[1:] indices = pc.sort_indices(df.native, options=node.options.to_arrow(names)) @@ -342,7 +340,7 @@ def over_ordered( # NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort` sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by) options = node.sort_options.to_multiple(len(node.order_by)) - idx_name = generate_temporary_column_name(8, frame.columns) + idx_name = temp.column_name(frame) sorted_context = frame.with_row_index(idx_name).sort(sort_by, options) evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name) if isinstance(evaluated, ArrowScalar): @@ -374,6 +372,27 @@ def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self: raise NotImplementedError + def _is_first_last_distinct( + self, + node: FunctionExpr[IsFirstDistinct | IsLastDistinct], + frame: Frame, + name: str, + ) -> Self: + idx_name = temp.column_name([name]) + expr_ir = fn.IS_FIRST_LAST_DISTINCT[type(node.function)](idx_name) + series = self._dispatch_expr(node.input[0], frame, name) + df = series.to_frame().with_row_index(idx_name) + distinct_index = ( + df.group_by_names((name,)) + .agg((ir.named_ir(idx_name, expr_ir),)) + .get_column(idx_name) + .native + ) + return self._with_native(fn.is_in(df.to_series().native, distinct_index), name) + + is_first_distinct = _is_first_last_distinct + is_last_distinct = _is_first_last_distinct + class ArrowScalar( _ArrowDispatch["ArrowScalar"], diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 1fd1942b2c..32255d37ec 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing as t +from collections.abc import Callable from typing import TYPE_CHECKING, Any import pyarrow as pa # ignore-banned-import @@ -13,6 +14,7 @@ chunked_array as _chunked_array, floordiv_compat as floordiv, ) +from narwhals._plan import expressions as ir from narwhals._plan.arrow import options from narwhals._plan.expressions import operators as ops from narwhals._utils import Implementation @@ -20,22 +22,22 @@ if TYPE_CHECKING: from collections.abc import Iterable, Mapping - from typing_extensions import TypeIs + from typing_extensions import TypeAlias, TypeIs from narwhals._arrow.dataframe import PromoteOptions - from narwhals._arrow.typing import ( - ArrayAny, - ArrayOrScalar, - ChunkedArrayAny, - Incomplete, - ) + from narwhals._arrow.typing import Incomplete from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.arrow.typing import ( + Array, + ArrayAny, + ArrowAny, BinaryComp, BinaryLogical, BinaryNumericTemporal, BinOp, ChunkedArray, + ChunkedArrayAny, + ChunkedOrArrayAny, ChunkedOrScalar, ChunkedOrScalarAny, DataType, @@ -56,6 +58,9 @@ BACKEND_VERSION = 
Implementation.PYARROW._backend_version() +IntoColumnAgg: TypeAlias = Callable[[str], ir.AggExpr] +"""Helper constructor for single-column aggregations.""" + is_null = pc.is_null is_not_null = t.cast("UnaryFunction[ScalarAny,pa.BooleanScalar]", pc.is_valid) is_nan = pc.is_nan @@ -111,6 +116,10 @@ def modulus(lhs: Any, rhs: Any) -> Any: "none": (gt, lt), "both": (gt_eq, lt_eq), } +IS_FIRST_LAST_DISTINCT: Mapping[type[ir.boolean.BooleanFunction], IntoColumnAgg] = { + ir.boolean.IsFirstDistinct: ir.min, + ir.boolean.IsLastDistinct: ir.max, +} @t.overload @@ -210,6 +219,27 @@ def is_between( return and_(fn_lhs(native, lower), fn_rhs(native, upper)) +@t.overload +def is_in( + values: ChunkedArrayAny, /, other: ChunkedOrArrayAny +) -> ChunkedArray[pa.BooleanScalar]: ... +@t.overload +def is_in(values: ArrayAny, /, other: ChunkedOrArrayAny) -> Array[pa.BooleanScalar]: ... +@t.overload +def is_in(values: ScalarAny, /, other: ChunkedOrArrayAny) -> pa.BooleanScalar: ... +def is_in(values: ArrowAny, /, other: ChunkedOrArrayAny) -> ArrowAny: + """Check if elements of `values` are present in `other`. + + Roughly equivalent to [`polars.Expr.is_in`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.is_in.html) + + Returns a mask with `len(values)` elements. + """ + # NOTE: Stubs don't include a `ChunkedArray` return + # NOTE: Replaced ambiguous parameter name (`value_set`) + is_in_: Incomplete = pc.is_in + return is_in_(values, other) # type: ignore[no-any-return] + + def binary( lhs: ChunkedOrScalarAny, op: type[ops.Operator], rhs: ChunkedOrScalarAny ) -> ChunkedOrScalarAny: @@ -257,7 +287,7 @@ def array( def chunked_array( - arr: ArrayOrScalar | list[Iterable[Any]], dtype: DataType | None = None, / + arr: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, / ) -> ChunkedArrayAny: return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index c068cf43ed..fa72df9e7f 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -5,7 +5,8 @@ from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn from narwhals._plan.compliant.series import CompliantSeries -from narwhals._utils import Version +from narwhals._plan.compliant.typing import namespace +from narwhals._utils import Implementation, Version from narwhals.dependencies import is_numpy_array_1d if TYPE_CHECKING: @@ -13,12 +14,27 @@ from typing_extensions import Self - from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401 + from narwhals._arrow.typing import ChunkedArrayAny + from narwhals._plan.arrow.dataframe import ArrowDataFrame as DataFrame + from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals.dtypes import DType from narwhals.typing import Into1DArray, IntoDType, _1DArray class ArrowSeries(CompliantSeries["ChunkedArrayAny"]): + implementation = Implementation.PYARROW + _native: ChunkedArrayAny + _version: Version + _name: str + + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._plan.arrow.namespace import ArrowNamespace + + return ArrowNamespace(self._version) + + def to_frame(self) -> DataFrame: + return namespace(self)._dataframe.from_dict({self.name: self.native}) + def to_list(self) -> list[Any]: return self.native.to_pylist() diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index dc0795c95b..a515b99091 100644 --- 
a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -121,12 +121,23 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot DataTypeT_co = TypeVar("DataTypeT_co", bound=DataType, covariant=True, default=Any) ScalarT_co = TypeVar("ScalarT_co", bound="pa.Scalar[Any]", covariant=True, default=Any) Scalar: TypeAlias = "pa.Scalar[DataTypeT_co]" +Array: TypeAlias = "pa.Array[ScalarT_co]" ChunkedArray: TypeAlias = "pa.ChunkedArray[ScalarT_co]" ChunkedOrScalar: TypeAlias = "ChunkedArray[ScalarT_co] | ScalarT_co" +ChunkedOrArray: TypeAlias = "ChunkedArray[ScalarT_co] | Array[ScalarT_co]" ScalarAny: TypeAlias = "Scalar[Any]" +ArrayAny: TypeAlias = "Array[Any]" +ChunkedArrayAny: TypeAlias = "ChunkedArray[Any]" ChunkedOrScalarAny: TypeAlias = "ChunkedOrScalar[ScalarAny]" +ChunkedOrArrayAny: TypeAlias = "ChunkedOrArray[ScalarAny]" +ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] StoresNativeT_co = TypeVar("StoresNativeT_co", bound=StoresNative[Any], covariant=True) DataTypeRemap: TypeAlias = Mapping[DataType, DataType] NullPlacement: TypeAlias = Literal["at_start", "at_end"] + +JoinTypeSubset: TypeAlias = Literal[ + "inner", "left outer", "full outer", "left anti", "left semi" +] +"""Only the `pyarrow` `JoinType`'s we use in narwhals""" diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index defe398f95..c29e8e8071 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, cast, overload from narwhals._plan._guards import is_iterable_reject -from narwhals._utils import _hasattr_static +from narwhals._utils import _hasattr_static, qualified_type_name from narwhals.dtypes import DType from narwhals.exceptions import NarwhalsError from narwhals.utils import Version @@ -22,12 +22,15 @@ from typing_extensions import TypeIs + from narwhals._plan.compliant.series import CompliantSeries + from narwhals._plan.series import Series from narwhals._plan.typing import ( DTypeT, ExprIRT, FunctionT, NonNestedDTypeT, OneOrIterable, + Seq, ) from narwhals._utils import _StoresColumns from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -109,9 +112,21 @@ def into_dtype(dtype: DTypeT | type[NonNestedDTypeT], /) -> DTypeT | NonNestedDT return dtype -# TODO @dangotbanned: Review again and try to work around (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021) +# NOTE: See (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021) # The issue is `T` possibly being `Iterable` # Ignoring here still leaks the issue to the caller, where you need to annotate the base case +@overload +def flatten_hash_safe(iterable: Iterable[OneOrIterable[str]], /) -> Iterator[str]: ... +@overload +def flatten_hash_safe( + iterable: Iterable[OneOrIterable[Series]], / +) -> Iterator[Series]: ... +@overload +def flatten_hash_safe( + iterable: Iterable[OneOrIterable[CompliantSeries]], / +) -> Iterator[CompliantSeries]: ... +@overload +def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: ... def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: """Fully unwrap all levels of nesting. 
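# --- editorial sketch (not part of the committed diff) -----------------------
# The overloads added above keep `flatten_hash_safe` usable with a single column
# name as well as arbitrarily nested iterables of names. A minimal illustration
# with hypothetical values, assuming str/bytes are treated as atoms (which its
# use with column names in `acero.sort_by` implies):
from narwhals._plan.common import flatten_hash_safe

list(flatten_hash_safe(("a", ["b", "c"])))      # -> ["a", "b", "c"]
list(flatten_hash_safe((("a", ("b",)), "c")))   # nesting is unwrapped fully -> ["a", "b", "c"]
# ------------------------------------------------------------------------------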
@@ -124,6 +139,23 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: yield element # type: ignore[misc] +def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: + msg = f"Expected one or an iterable of strings, but got: {qualified_type_name(obj)!r}\n{obj!r}" + return TypeError(msg) + + +def ensure_seq_str(obj: OneOrIterable[str], /) -> Seq[str]: + if not isinstance(obj, Iterable): + raise _not_one_or_iterable_str_error(obj) + return (obj,) if isinstance(obj, str) else tuple(obj) + + +def ensure_list_str(obj: OneOrIterable[str], /) -> list[str]: + if not isinstance(obj, Iterable): + raise _not_one_or_iterable_str_error(obj) + return [obj] if isinstance(obj, str) else list(obj) + + def _has_columns(obj: Any) -> TypeIs[_StoresColumns]: return _hasattr_static(obj, "columns") diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index cc7bab7501..45728fce47 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Protocol, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload from narwhals._plan.compliant.group_by import Grouped from narwhals._plan.compliant.typing import ColumnT_co, HasVersion, SeriesT @@ -9,6 +9,7 @@ NativeDataFrameT, NativeFrameT_co, NativeSeriesT, + NonCrossJoinStrategy, OneOrIterable, ) @@ -29,7 +30,8 @@ from narwhals._plan.expressions import NamedIR from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import Seq - from narwhals._utils import Version + from narwhals._typing import _EagerAllowedImpl + from narwhals._utils import Implementation, Version from narwhals.dtypes import DType from narwhals.typing import IntoSchema @@ -37,6 +39,8 @@ class CompliantFrame(HasVersion, Protocol[ColumnT_co, NativeFrameT_co]): + implementation: ClassVar[Implementation] + def __narwhals_namespace__(self) -> Any: ... def _evaluate_irs( self, nodes: Iterable[NamedIR[ir.ExprIR]], / @@ -53,6 +57,9 @@ def to_narwhals(self) -> BaseFrame[NativeFrameT_co]: ... def columns(self) -> list[str]: ... def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ... def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... + # Shouldn't *need* to be `NamedIR`, but current impl depends on a name being passed around + def filter(self, predicate: NamedIR, /) -> Self: ... + def rename(self, mapping: Mapping[str, str]) -> Self: ... @property def schema(self) -> Mapping[str, DType]: ... def select(self, irs: Seq[NamedIR]) -> Self: ... @@ -65,6 +72,7 @@ class CompliantDataFrame( CompliantFrame[SeriesT, NativeDataFrameT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): + implementation: ClassVar[_EagerAllowedImpl] _native: NativeDataFrameT def __len__(self) -> int: ... @@ -92,6 +100,7 @@ def native(self) -> NativeDataFrameT: def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... + def get_column(self, name: str) -> SeriesT: ... def group_by_agg( self, by: OneOrIterable[IntoExpr], aggs: OneOrIterable[IntoExpr], / ) -> Self: @@ -109,6 +118,17 @@ def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Se """ return self._group_by.from_resolver(self, resolver) + def filter(self, predicate: NamedIR, /) -> Self: ... 
+ def join( + self, + other: Self, + *, + how: NonCrossJoinStrategy, + left_on: Sequence[str], + right_on: Sequence[str], + suffix: str = "_right", + ) -> Self: ... + def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ... def row(self, index: int) -> tuple[Any, ...]: ... @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... @@ -126,6 +146,7 @@ def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: return DataFrame[NativeDataFrameT, NativeSeriesT](self) + def to_series(self, index: int = 0) -> SeriesT: ... def with_row_index(self, name: str) -> Self: ... @@ -141,3 +162,6 @@ def select(self, irs: Seq[NamedIR]) -> Self: def with_columns(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) + + def to_series(self, index: int = 0) -> SeriesT: + return self.get_column(self.columns[index]) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 229151284f..5defde7d61 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -24,7 +24,15 @@ boolean, functions as F, ) - from narwhals._plan.expressions.boolean import IsBetween, IsFinite, IsNan, IsNull, Not + from narwhals._plan.expressions.boolean import ( + IsBetween, + IsFinite, + IsFirstDistinct, + IsLastDistinct, + IsNan, + IsNull, + Not, + ) class CompliantExpr(HasVersion, Protocol[FrameT_contra, SeriesT_co]): @@ -55,6 +63,12 @@ def is_between( def is_finite( self, node: FunctionExpr[IsFinite], frame: FrameT_contra, name: str ) -> Self: ... + def is_first_distinct( + self, node: FunctionExpr[IsFirstDistinct], frame: FrameT_contra, name: str + ) -> Self: ... + def is_last_distinct( + self, node: FunctionExpr[IsLastDistinct], frame: FrameT_contra, name: str + ) -> Self: ... def is_nan( self, node: FunctionExpr[IsNan], frame: FrameT_contra, name: str ) -> Self: ... diff --git a/narwhals/_plan/compliant/group_by.py b/narwhals/_plan/compliant/group_by.py index 7ae5f3e966..8e05144393 100644 --- a/narwhals/_plan/compliant/group_by.py +++ b/narwhals/_plan/compliant/group_by.py @@ -25,7 +25,7 @@ class CompliantGroupBy(Protocol[FrameT_co]): - def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... + def agg(self, irs: Seq[NamedIR[Any]]) -> FrameT_co: ... @property def compliant(self) -> FrameT_co: ... 
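# --- editorial sketch (not part of the committed diff) -----------------------
# The `join`, `join_cross`, and `filter` protocol methods above back the
# user-facing `DataFrame.join` / `DataFrame.filter` added later in this patch.
# A minimal, hedged usage sketch against the pyarrow backend (hypothetical data;
# mirrors the style of the new tests in this patch):
import pyarrow as pa

import narwhals._plan as nwp

left = nwp.DataFrame.from_native(pa.table({"id": [1, 2], "x": [10, 20]}))
right = nwp.DataFrame.from_native(pa.table({"id": [2, 3], "y": ["a", "b"]}))

left.join(right, on="id", how="inner")  # keeps id=2 only; maps to acero "inner"
left.join(right, on="id", how="anti")   # keeps id=1 only; maps to acero "left anti"
left.filter(nwp.col("x") > 10)          # keeps the row where x == 20
# ------------------------------------------------------------------------------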
diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index 19fda1f003..abb873aa5e 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -9,7 +9,8 @@ from typing_extensions import Self from narwhals._plan import expressions as ir - from narwhals._plan.expressions import aggregation as agg + from narwhals._plan.expressions import FunctionExpr, aggregation as agg + from narwhals._plan.expressions.boolean import IsFirstDistinct, IsLastDistinct from narwhals._utils import Version from narwhals.typing import IntoDType, PythonLiteral @@ -60,6 +61,16 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) + def is_first_distinct( + self, node: FunctionExpr[IsFirstDistinct], frame: FrameT_contra, name: str + ) -> Self: + return self.from_python(True, name, dtype=None, version=self.version) + + def is_last_distinct( + self, node: FunctionExpr[IsLastDistinct], frame: FrameT_contra, name: str + ) -> Self: + return self.from_python(True, name, dtype=None, version=self.version) + def last(self, node: agg.Last, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 8c5d2fe3b4..b97e06331d 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, ClassVar, Protocol from narwhals._plan.compliant.typing import HasVersion from narwhals._plan.typing import NativeSeriesT @@ -9,20 +9,25 @@ if TYPE_CHECKING: from collections.abc import Iterable - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan.series import Series + from narwhals._typing import _EagerAllowedImpl from narwhals.dtypes import DType from narwhals.typing import Into1DArray, IntoDType, _1DArray +Incomplete: TypeAlias = Any + class CompliantSeries(HasVersion, Protocol[NativeSeriesT]): + implementation: ClassVar[_EagerAllowedImpl] _native: NativeSeriesT _name: str def __len__(self) -> int: return len(self.native) + def __narwhals_namespace__(self) -> Incomplete: ... def __narwhals_series__(self) -> Self: return self @@ -66,6 +71,7 @@ def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) def cast(self, dtype: IntoDType) -> Self: ... + def to_frame(self) -> Incomplete: ... def to_list(self) -> list[Any]: ... 
def to_narwhals(self) -> Series[NativeSeriesT]: from narwhals._plan.series import Series diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index a914923226..625f1990e0 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -1,32 +1,41 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, get_args, overload from narwhals._plan import _parse from narwhals._plan._expansion import prepare_projection -from narwhals._plan.expr import _parse_sort_by +from narwhals._plan.common import ensure_seq_str, temp from narwhals._plan.group_by import GroupBy, Grouped +from narwhals._plan.options import SortMultipleOptions from narwhals._plan.series import Series from narwhals._plan.typing import ( + ColumnNameOrSelector, IntoExpr, + IntoExprColumn, NativeDataFrameT, NativeDataFrameT_co, NativeFrameT_co, NativeSeriesT, + NonCrossJoinStrategy, OneOrIterable, + PartialSeries, + Seq, ) -from narwhals._utils import Version, generate_repr +from narwhals._utils import Implementation, Version, generate_repr from narwhals.dependencies import is_pyarrow_table from narwhals.schema import Schema +from narwhals.typing import IntoDType, JoinStrategy if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterable, Mapping, Sequence import pyarrow as pa - from typing_extensions import Self, TypeAlias + from typing_extensions import Self, TypeAlias, TypeIs from narwhals._plan.arrow.typing import NativeArrowDataFrame from narwhals._plan.compliant.dataframe import CompliantDataFrame, CompliantFrame + from narwhals._typing import _EagerAllowedImpl + Incomplete: TypeAlias = Any @@ -39,6 +48,10 @@ class BaseFrame(Generic[NativeFrameT_co]): def version(self) -> Version: return self._version + @property + def implementation(self) -> Implementation: + return self._compliant.implementation + @property def schema(self) -> Schema: return Schema(self._compliant.schema.items()) @@ -59,6 +72,17 @@ def _with_compliant(self, compliant: CompliantFrame[Any, Incomplete], /) -> Self def to_native(self) -> NativeFrameT_co: return self._compliant.native + def filter( + self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any + ) -> Self: + e = _parse.parse_predicates_constraints_into_expr_ir(*predicates, **constraints) + named_irs, _ = prepare_projection((e,), schema=self) + if len(named_irs) != 1: + # Should be unreachable, but I guess we will see + msg = f"Expected a single predicate after expansion, but got {len(named_irs)!r}\n\n{named_irs!r}" + raise ValueError(msg) + return self._with_compliant(self._compliant.filter(named_irs[0])) + def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema = prepare_projection( _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self @@ -75,34 +99,55 @@ def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> S def sort( self, - by: OneOrIterable[str], - *more_by: str, + by: OneOrIterable[ColumnNameOrSelector], + *more_by: ColumnNameOrSelector, descending: OneOrIterable[bool] = False, nulls_last: OneOrIterable[bool] = False, ) -> Self: - sort, opts = _parse_sort_by( - by, *more_by, descending=descending, nulls_last=nulls_last - ) + sort = _parse.parse_sort_by_into_seq_of_expr_ir(by, *more_by) + opts = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last) named_irs, _ = prepare_projection(sort, 
schema=self) return self._with_compliant(self._compliant.sort(named_irs, opts)) - def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: + def drop(self, *columns: str, strict: bool = True) -> Self: return self._with_compliant(self._compliant.drop(columns, strict=strict)) def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: subset = [subset] if isinstance(subset, str) else subset return self._with_compliant(self._compliant.drop_nulls(subset)) + def rename(self, mapping: Mapping[str, str]) -> Self: + return self._with_compliant(self._compliant.rename(mapping)) + class DataFrame( BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT] ): _compliant: CompliantDataFrame[Any, NativeDataFrameT_co, NativeSeriesT] + @property + def implementation(self) -> _EagerAllowedImpl: + return self._compliant.implementation + + def __len__(self) -> int: + return len(self._compliant) + @property def _series(self) -> type[Series[NativeSeriesT]]: return Series[NativeSeriesT] + def _partial_series( + self, *, dtype: IntoDType | None = None + ) -> PartialSeries[NativeSeriesT]: + it_names = temp.column_names(self.columns) + backend = self.implementation + series = self._series.from_iterable + + def fn(values: Iterable[Any], /) -> Series[NativeSeriesT]: + return series(values, name=next(it_names), dtype=dtype, backend=backend) + + return fn + @overload @classmethod def from_native( @@ -144,8 +189,11 @@ def to_dict( } return self._compliant.to_dict(as_series=as_series) - def __len__(self) -> int: - return len(self._compliant) + def to_series(self, index: int = 0) -> Series[NativeSeriesT]: + return self._series(self._compliant.to_series(index)) + + def get_column(self, name: str) -> Series[NativeSeriesT]: + return self._series(self._compliant.get_column(name)) @overload def group_by( @@ -172,3 +220,79 @@ def group_by( def row(self, index: int) -> tuple[Any, ...]: return self._compliant.row(index) + + def join( + self, + other: Self, + on: str | Sequence[str] | None = None, + how: JoinStrategy = "inner", + *, + left_on: str | Sequence[str] | None = None, + right_on: str | Sequence[str] | None = None, + suffix: str = "_right", + ) -> Self: + left, right = self._compliant, other._compliant + how = _validate_join_strategy(how) + if how == "cross": + if left_on is not None or right_on is not None or on is not None: + msg = "Can not pass `left_on`, `right_on` or `on` keys for cross join" + raise ValueError(msg) + return self._with_compliant(left.join_cross(right, suffix=suffix)) + left_on, right_on = normalize_join_on(on, how, left_on, right_on) + return self._with_compliant( + left.join(right, how=how, left_on=left_on, right_on=right_on, suffix=suffix) + ) + + def filter( + self, *predicates: OneOrIterable[IntoExprColumn] | list[bool], **constraints: Any + ) -> Self: + e = _parse.parse_predicates_constraints_into_expr_ir( + *predicates, + _list_as_series=self._partial_series(dtype=self.version.dtypes.Boolean()), + **constraints, + ) + named_irs, _ = prepare_projection((e,), schema=self) + if len(named_irs) != 1: + # Should be unreachable, but I guess we will see + msg = f"Expected a single predicate after expansion, but got {len(named_irs)!r}\n\n{named_irs!r}" + raise ValueError(msg) + return self._with_compliant(self._compliant.filter(named_irs[0])) + + +def _is_join_strategy(obj: Any) -> TypeIs[JoinStrategy]: + return obj in {"inner", "left", "full", "cross", "anti", "semi"} + + +def _validate_join_strategy(how: str, /) -> JoinStrategy: + if 
_is_join_strategy(how): + return how + msg = f"Only the following join strategies are supported: {get_args(JoinStrategy)}; found '{how}'." + raise NotImplementedError(msg) + + +def normalize_join_on( + on: OneOrIterable[str] | None, + how: NonCrossJoinStrategy, + left_on: OneOrIterable[str] | None, + right_on: OneOrIterable[str] | None, + /, +) -> tuple[Seq[str], Seq[str]]: + """Reduce the 3 potential key (`on*`) arguments to 2. + + Ensures the keys spelling is compatible with the join strategy. + """ + if on is None: + if left_on is None or right_on is None: + msg = f"Either (`left_on` and `right_on`) or `on` keys should be specified for {how}." + raise ValueError(msg) + left_on = ensure_seq_str(left_on) + right_on = ensure_seq_str(right_on) + if len(left_on) != len(right_on): + msg = "`left_on` and `right_on` must have the same length." + raise ValueError(msg) + return left_on, right_on + if left_on is not None or right_on is not None: + msg = f"If `on` is specified, `left_on` and `right_on` should be None for {how}." + raise ValueError(msg) + on = ensure_seq_str(on) + return on, on diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 8f4348aaa3..cfeb87644b 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -137,7 +137,7 @@ def over_row_separable_error( def invalid_into_expr_error( first_input: Iterable[IntoExpr], - more_inputs: tuple[IntoExpr, ...], + more_inputs: tuple[Any, ...], named_inputs: dict[str, IntoExpr], /, ) -> InvalidIntoExprError: diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 7695c1d92f..a56f58c7f6 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -10,6 +10,7 @@ parse_into_expr_ir, parse_into_seq_of_expr_ir, parse_predicates_constraints_into_expr_ir, + parse_sort_by_into_seq_of_expr_ir, ) from narwhals._plan.expressions import ( aggregation as agg, @@ -25,7 +26,7 @@ rolling_options, ) from narwhals._utils import Version -from narwhals.exceptions import ComputeError, InvalidOperationError +from narwhals.exceptions import ComputeError if TYPE_CHECKING: from typing_extensions import Never, Self @@ -50,21 +51,6 @@ ) -# NOTE: Trying to keep consistent logic between `DataFrame.sort` and `Expr.sort_by` -def _parse_sort_by( - by: OneOrIterable[IntoExpr] = (), - *more_by: IntoExpr, - descending: OneOrIterable[bool] = False, - nulls_last: OneOrIterable[bool] = False, -) -> tuple[Seq[ir.ExprIR], SortMultipleOptions]: - sort_by = parse_into_seq_of_expr_ir(by, *more_by) - if length_changing := next((e for e in sort_by if e.is_scalar), None): - msg = f"All expressions sort keys must preserve length, but got:\n{length_changing!r}" - raise InvalidOperationError(msg) - options = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last) - return sort_by, options - - # NOTE: Overly simplified placeholders for mocking typing # Entirely ignoring namespace + function binding class Expr: @@ -151,8 +137,8 @@ def quantile( def over( self, - *partition_by: OneOrIterable[IntoExpr], - order_by: OneOrIterable[IntoExpr] = None, + *partition_by: OneOrIterable[IntoExprColumn], + order_by: OneOrIterable[IntoExprColumn] | None = None, descending: bool = False, nulls_last: bool = False, ) -> Self: @@ -175,14 +161,13 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: def sort_by( self, - by: OneOrIterable[IntoExpr], - *more_by: IntoExpr, + by: OneOrIterable[IntoExprColumn], + *more_by: IntoExprColumn, descending: OneOrIterable[bool] = False, nulls_last: 
OneOrIterable[bool] = False, ) -> Self: - keys, opts = _parse_sort_by( - by, *more_by, descending=descending, nulls_last=nulls_last - ) + keys = parse_sort_by_into_seq_of_expr_ir(by, *more_by) + opts = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last) return self._from_ir(ir.SortBy(expr=self._ir, by=keys, options=opts)) def filter( diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 4444bbd6be..5e5bef9362 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -4,6 +4,7 @@ ExprIR, NamedIR, SelectorIR, + named_ir, ) from narwhals._plan._function import Function from narwhals._plan.expressions import ( @@ -13,7 +14,7 @@ operators, selectors, ) -from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr +from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr, max, min from narwhals._plan.expressions.expr import ( Alias, All, @@ -88,6 +89,9 @@ "cols", "functions", "index_columns", + "max", + "min", + "named_ir", "nth", "operators", "over", diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 0f26a82c10..129da889aa 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -63,3 +63,15 @@ class Std(AggExpr): class Var(AggExpr): __slots__ = (*AggExpr.__slots__, "ddof") ddof: int + + +def min(name: str, /) -> Min: + from narwhals._plan.expressions import col + + return Min(expr=col(name)) + + +def max(name: str, /) -> Max: + from narwhals._plan.expressions import col + + return Max(expr=col(name)) diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 1b283f81a6..ec5ffa36a1 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -143,8 +143,7 @@ class Exclude(_ColumnSelection, child=("expr",)): @staticmethod def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: - flat: t.Iterator[str] = flatten_hash_safe(names) - return Exclude(expr=expr, names=tuple(flat)) + return Exclude(expr=expr, names=tuple(flatten_hash_safe(names))) def __repr__(self) -> str: return f"{self.expr!r}.exclude({list(self.names)!r})" diff --git a/narwhals/_plan/expressions/selectors.py b/narwhals/_plan/expressions/selectors.py index 5d6bfa8292..d09d7af761 100644 --- a/narwhals/_plan/expressions/selectors.py +++ b/narwhals/_plan/expressions/selectors.py @@ -14,7 +14,6 @@ from narwhals._utils import Version, _parse_time_unit_and_time_zone if TYPE_CHECKING: - from collections.abc import Iterator from datetime import timezone from typing import TypeVar @@ -127,7 +126,7 @@ def from_string(pattern: str, /) -> Matches: @staticmethod def from_names(*names: OneOrIterable[str]) -> Matches: """Implements `cs.by_name` to support `__r__` with column selections.""" - it: Iterator[str] = flatten_hash_safe(names) + it = flatten_hash_safe(names) return Matches.from_string(f"^({'|'.join(re.escape(name) for name in it)})$") def __repr__(self) -> str: diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 303d07a097..ca0d101676 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -49,7 +49,10 @@ class FunctionFlags(enum.Flag): """ LENGTH_PRESERVING = 1 << 9 - """mutually exclusive with `RETURNS_SCALAR`""" + """In isolation, means that the function is dependent on the context of surrounding rows. + + Mutually exclusive with `RETURNS_SCALAR`. 
+ """ def is_elementwise(self) -> bool: return (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING) in self diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 5108d8ca4e..d3a17f29a4 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -3,14 +3,16 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co -from narwhals._utils import Version +from narwhals._utils import Implementation, Version, is_eager_allowed from narwhals.dependencies import is_pyarrow_chunked_array if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterable, Iterator from narwhals._plan.compliant.series import CompliantSeries + from narwhals._typing import EagerAllowed, IntoBackend from narwhals.dtypes import DType + from narwhals.typing import IntoDType class Series(Generic[NativeSeriesT_co]): @@ -32,6 +34,29 @@ def name(self) -> str: def __init__(self, compliant: CompliantSeries[NativeSeriesT_co], /) -> None: self._compliant = compliant + @classmethod + def from_iterable( + cls: type[Series[Any]], + values: Iterable[Any], + *, + name: str = "", + dtype: IntoDType | None = None, + backend: IntoBackend[EagerAllowed], + ) -> Series[Any]: + implementation = Implementation.from_backend(backend) + if is_eager_allowed(implementation): + if implementation is Implementation.PYARROW: + from narwhals._plan.arrow.series import ArrowSeries + + return cls( + ArrowSeries.from_iterable( + values, name=name, version=cls._version, dtype=dtype + ) + ) + raise NotImplementedError(implementation) + msg = f"{implementation} support in Narwhals is lazy-only" + raise ValueError(msg) + @classmethod def from_native( cls: type[Series[Any]], native: NativeSeriesT, name: str = "", / diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 843e7f4b54..5214c2f92c 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -5,13 +5,15 @@ from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: + from collections.abc import Callable, Iterable + from typing_extensions import TypeAlias from narwhals import dtypes from narwhals._plan._expr_ir import ExprIR, NamedIR, SelectorIR from narwhals._plan._function import Function from narwhals._plan.dataframe import DataFrame - from narwhals._plan.expr import Expr + from narwhals._plan.expr import Expr, Selector from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow from narwhals._plan.expressions.namespace import IRNamespace @@ -26,6 +28,7 @@ ) __all__ = [ + "ColumnNameOrSelector", "DataFrameT", "FunctionT", "IntoExpr", @@ -63,7 +66,7 @@ LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") -OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" +OperatorFn: TypeAlias = "Callable[[t.Any, t.Any], t.Any]" ExprIRT = TypeVar("ExprIRT", bound="ExprIR", default="ExprIR") ExprIRT2 = TypeVar("ExprIRT2", bound="ExprIR", default="ExprIR") NamedOrExprIRT = TypeVar("NamedOrExprIRT", "NamedIR[t.Any]", "ExprIR") @@ -87,6 +90,7 @@ "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" ) NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries") +NativeSeriesAnyT = TypeVar("NativeSeriesAnyT", bound="NativeSeries", default="t.Any") NativeSeriesT_co = TypeVar( "NativeSeriesT_co", 
bound="NativeSeries", covariant=True, default="NativeSeries" ) @@ -104,7 +108,7 @@ default="NativeDataFrame", ) LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | Series[t.Any]", default=t.Any) -MapIR: TypeAlias = "t.Callable[[ExprIR], ExprIR]" +MapIR: TypeAlias = "Callable[[ExprIR], ExprIR]" """A function to apply to all nodes in this tree.""" T = TypeVar("T") @@ -115,12 +119,15 @@ Using instead of `Sequence`, as a `list` can be passed there (can't break immutability promise). """ -Udf: TypeAlias = "t.Callable[[t.Any], t.Any]" +Udf: TypeAlias = "Callable[[t.Any], t.Any]" """Placeholder for `map_batches(function=...)`.""" IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" +ColumnNameOrSelector: TypeAlias = "str | Selector" OneOrIterable: TypeAlias = "T | t.Iterable[T]" OneOrSeq: TypeAlias = t.Union[T, Seq[T]] DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") Order: TypeAlias = t.Literal["ascending", "descending"] +NonCrossJoinStrategy: TypeAlias = t.Literal["inner", "left", "full", "semi", "anti"] +PartialSeries: TypeAlias = "Callable[[Iterable[t.Any]], Series[NativeSeriesAnyT]]" diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 905f97ac40..2d63cb99f2 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -15,7 +15,7 @@ from narwhals import _plan as nwp from narwhals._utils import Version from narwhals.exceptions import ComputeError -from tests.utils import assert_equal_data +from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: from collections.abc import Sequence @@ -411,9 +411,7 @@ def test_select( expected: dict[str, Any], data_small: dict[str, Any], ) -> None: - frame = pa.table(data_small) - df = nwp.DataFrame.from_native(frame) - result = df.select(expr).to_dict(as_series=False) + result = dataframe(data_small).select(expr) assert_equal_data(result, expected) @@ -485,9 +483,7 @@ def test_with_columns( expected: dict[str, Any], data_smaller: dict[str, Any], ) -> None: - frame = pa.table(data_smaller) - df = nwp.DataFrame.from_native(frame) - result = df.with_columns(expr).to_dict(as_series=False) + result = dataframe(data_smaller).with_columns(expr) assert_equal_data(result, expected) @@ -515,11 +511,11 @@ def test_first_last_expr_with_columns( ) -> None: """Related https://github.com/narwhals-dev/narwhals/pull/2528#discussion_r2225930065.""" height = len(next(iter(data_indexed.values()))) - expected_broadcast = height * [expected] - frame = nwp.DataFrame.from_native(pa.table(data_indexed)) + expected_full = {"result": height * [expected]} + frame = dataframe(data_indexed) expr = agg.over(order_by="idx").alias("result") - result = frame.with_columns(expr).select("result").to_dict(as_series=False) - assert_equal_data(result, {"result": expected_broadcast}) + result = frame.with_columns(expr).select("result") + assert_equal_data(result, expected_full) @pytest.mark.parametrize( @@ -528,7 +524,7 @@ def test_first_last_expr_with_columns( def test_row_is_py_literal( data_indexed: dict[str, Any], index: int, expected: tuple[PythonLiteral, ...] ) -> None: - frame = nwp.DataFrame.from_native(pa.table(data_indexed)) + frame = dataframe(data_indexed) result = frame.row(index) assert all(v is None or isinstance(v, (int, float)) for v in result) assert result == expected @@ -549,12 +545,18 @@ def test_protocol_expr() -> None: doesn't happen elsewhere at the moment. 
""" pytest.importorskip("pyarrow") + from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar + from narwhals._plan.arrow.series import ArrowSeries expr = ArrowExpr() scalar = ArrowScalar() + df = ArrowDataFrame() + ser = ArrowSeries() assert expr assert scalar + assert df + assert ser def test_dataframe_from_native_overloads() -> None: """Ensure we can reveal the `NativeSeries` **without** a dependency.""" diff --git a/tests/plan/frame_filter_test.py b/tests/plan/frame_filter_test.py new file mode 100644 index 0000000000..94edf66482 --- /dev/null +++ b/tests/plan/frame_filter_test.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +import pytest + +pytest.importorskip("pyarrow") + +import narwhals._plan as nwp +from narwhals.exceptions import ColumnNotFoundError, ShapeError +from tests.plan.utils import assert_equal_data, dataframe, series + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture +def data() -> Data: + return {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} + + +@pytest.fixture +def data_2() -> Data: + """Dataset used [upstream]. + + [upstream]: https://github.com/pola-rs/polars/blob/a4522d719de940be3ef99d494ccd1cd6067475c6/py-polars/tests/unit/lazyframe/test_lazyframe.py#L175-L182 + """ + return {"a": [1, 1, 1, 2, 2], "b": [1, 1, 2, 2, 2], "c": [1, 1, 2, 3, 4]} + + +@pytest.mark.parametrize( + "predicate", + [[False, True, True], series([False, True, True]), nwp.col("a") > 1], + ids=["list[bool]", "Series", "Expr"], +) +def test_filter_single( + data: Data, predicate: list[bool] | nwp.Series[Any] | nwp.Expr +) -> None: + expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} + result = dataframe(data).filter(predicate) + assert_equal_data(result, expected) + + +def test_filter_aggregated_predicate(data: Data) -> None: + # NOTE: Unclear why this isn't permitted on `main` + pytest.importorskip("polars") + import polars as pl + + expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} + expected_invert: Data = {"a": [], "b": [], "z": []} + df = dataframe(data) + df_pl = pl.DataFrame(data) + predicate = nwp.col("a").max() > 2 + predicate_pl = pl.col("a").max() > 2 + + result = df.filter(predicate) + result_pl = df_pl.filter(predicate_pl).to_dict(as_series=False) + assert_equal_data(result, expected) + assert_equal_data(result, result_pl) + + result = df.filter(~predicate) + result_pl = df_pl.filter(~predicate_pl).to_dict(as_series=False) + assert_equal_data(result, expected_invert) + assert_equal_data(result, result_pl) + + +def test_filter_raise_on_shape_mismatch(data: Data) -> None: + df = dataframe(data) + with pytest.raises(ShapeError): + df.filter(nwp.col("b").filter(nwp.col("b") < 6)) + + +def test_filter_with_constraints() -> None: + df = dataframe({"a": [1, 3, 2], "b": [4, 4, 6]}) + result_scalar = df.filter(a=3) + expected_scalar = {"a": [3], "b": [4]} + assert_equal_data(result_scalar, expected_scalar) + result_expr = df.filter(a=nwp.col("b") // 3) + expected_expr = {"a": [1, 2], "b": [4, 6]} + assert_equal_data(result_expr, expected_expr) + + +def test_filter_missing_column() -> None: + df = dataframe({"a": [1, 2], "b": [3, 4]}) + msg = ( + r"The following columns were not found: \[.*\]" + r"\n\nHint: Did you mean one of these columns: \['a', 'b'\]?" 
+ ) + with pytest.raises(ColumnNotFoundError, match=msg): + df.filter(c=5) + + +def test_filter_mask_mixed() -> None: + df = dataframe({"a": range(5), "b": [2, 2, 4, 2, 4]}) + mask = [True, False, True, True, False] + mask_2 = [True, True, False, True, False] + expected_mask_only = {"a": [0, 2, 3], "b": [2, 4, 2]} + expected_mixed = {"a": [0, 3], "b": [2, 2]} + + assert_equal_data(df.filter(mask), expected_mask_only) + + with pytest.raises( + ColumnNotFoundError, match=re.escape("not found: ['c', 'd', 'e', 'f', 'g']") + ): + df.filter(mask, c=1, d=2, e=3, f=4, g=5) + + assert_equal_data(df.filter(mask, b=2), expected_mixed) + assert_equal_data(df.filter(mask, nwp.col("b") == 2), expected_mixed) + assert_equal_data(df.filter(mask, mask_2), expected_mixed) + assert_equal_data(df.filter(mask, series(mask_2)), expected_mixed) + assert_equal_data(df.filter(mask, nwp.col("b") != 4, b=2), expected_mixed) + + +def test_filter_multiple_predicates(data_2: Data) -> None: + """https://github.com/pola-rs/polars/blob/a4522d719de940be3ef99d494ccd1cd6067475c6/py-polars/tests/unit/lazyframe/test_lazyframe.py#L175-L202.""" + df = dataframe(data_2) + + # multiple predicates + expected = {"a": [1, 1, 1], "b": [1, 1, 2], "c": [1, 1, 2]} + for out in ( + df.filter(nwp.col("a") == 1, nwp.col("b") <= 2), # positional/splat + df.filter([nwp.col("a") == 1, nwp.col("b") <= 2]), # as list + ): + assert_equal_data(out, expected) + + # multiple kwargs + assert_equal_data(df.filter(a=1, b=2), {"a": [1], "b": [2], "c": [2]}) + + # both positional and keyword args + assert_equal_data( + df.filter(nwp.col("c") < 4, a=2, b=2), {"a": [2], "b": [2], "c": [3]} + ) + + +def test_filter_string_predicate() -> None: + """https://github.com/pola-rs/polars/blob/a4522d719de940be3ef99d494ccd1cd6067475c6/py-polars/tests/unit/lazyframe/test_lazyframe.py#L204-L210.""" + data = {"description": ["eq", "gt", "ge"], "predicate": ["==", ">", ">="]} + expected = {"description": ["eq"], "predicate": ["=="]} + df = dataframe(data) + result = df.filter(predicate="==") + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + "predicate", + [ + [nwp.lit(True)], + iter([nwp.lit(True)]), + [True, True, True], + iter([True, True, True]), + (p for p in (nwp.col("z") < 10,)), + (p for p in (nwp.col("a") > 0, nwp.col("b") > 0)), + ], +) +def test_filter_seq_iterable_all_true(predicate: Any, data: Data) -> None: + """https://github.com/pola-rs/polars/blob/a4522d719de940be3ef99d494ccd1cd6067475c6/py-polars/tests/unit/lazyframe/test_lazyframe.py#L213-L233.""" + df = dataframe(data) + assert_equal_data(df.filter(predicate), df.to_dict(as_series=False)) + + +@pytest.mark.parametrize( + "predicate", + [ + [nwp.lit(False)], + iter([nwp.lit(False)]), + [False, False, False], + iter([False, False, False]), + (p for p in (nwp.col("z") > 10,)), + (p for p in (nwp.col("a") < 0, nwp.col("b") < 0)), + ], +) +def test_filter_seq_iterable_all_false(predicate: Any, data: Data) -> None: + df = dataframe(data) + expected: Data = {"a": [], "b": [], "z": []} + assert_equal_data(df.filter(predicate), expected) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 2b60c118db..e1e60f3605 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -9,6 +9,7 @@ from narwhals import _plan as nwp from narwhals._plan import selectors as npcs from narwhals.exceptions import InvalidOperationError +from tests.plan.utils import assert_equal_data, dataframe from tests.utils import PYARROW_VERSION, assert_equal_data as 
_assert_equal_data pytest.importorskip("pyarrow") @@ -22,14 +23,6 @@ from narwhals._plan.typing import IntoExpr -def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[Any, Any]: - return nwp.DataFrame.from_native(pa.table(data)) - - -def assert_equal_data(result: nwp.DataFrame, expected: Mapping[str, Any]) -> None: - _assert_equal_data(result.to_dict(as_series=False), expected) - - def test_group_by_iter() -> None: data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} df = dataframe(data) diff --git a/tests/plan/is_first_last_distinct_test.py b/tests/plan/is_first_last_distinct_test.py new file mode 100644 index 0000000000..0b23eac761 --- /dev/null +++ b/tests/plan/is_first_last_distinct_test.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture +def data() -> Data: + return {"a": [1, 1, 2, 3, 3, 2], "b": [1, 2, 3, 2, 1, 3]} + + +@pytest.fixture +def data_indexed(data: Data) -> Data: + return data | {"i": [None, 1, 2, 3, 4, 5]} + + +@pytest.fixture +def data_alt_1() -> Data: + return {"a": [1, 1, 2, 2, 2], "b": [1, 3, 3, 2, 3]} + + +@pytest.fixture +def data_alt_1_indexed(data_alt_1: Data) -> Data: + return data_alt_1 | {"i": [0, 1, 2, 3, 4]} + + +@pytest.fixture +def data_alt_2() -> Data: + return {"a": [1, 1, 2, 2, 2], "b": [1, 2, 2, 2, 1]} + + +@pytest.fixture +def data_alt_2_indexed(data_alt_2: Data) -> Data: + return data_alt_2 | {"i": [None, 1, 2, 3, 4]} + + +@pytest.fixture +def expected() -> Data: + return { + "a": [True, False, True, True, False, False], + "b": [True, True, True, False, False, False], + } + + +@pytest.fixture +def expected_invert(expected: Data) -> Data: + return {k: [not el for el in v] for k, v in expected.items()} + + +# NOTE: Isn't supported on `main` for `pyarrow` + lots of other cases (non-elementary group-by agg) +# Could be interesting to attempt here? 
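# A minimal sketch of the unsupported shape the xfail cases below target
# (same fixtures/column names as used in this module):
#
#     nwp.col("b").is_first_distinct().over("a", order_by="i")
#
# i.e. a window carrying *both* `partition_by` and `order_by`; the
# `order_by`-only form is covered by the passing tests in this module.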
+XFAIL_PARTITIONED_ORDER_BY = pytest.mark.xfail( + reason="Not supporting `over(*partition_by, order_by=...)` yet", + raises=NotImplementedError, +) + + +def test_is_first_distinct(data: Data, expected: Data) -> None: + result = dataframe(data).select(nwp.all().is_first_distinct()) + assert_equal_data(result, expected) + + +def test_is_last_distinct(data: Data, expected_invert: Data) -> None: + result = dataframe(data).select(nwp.all().is_last_distinct()) + assert_equal_data(result, expected_invert) + + +def test_is_first_distinct_order_by(data_indexed: Data, expected: Data) -> None: + result = ( + dataframe(data_indexed) + .select(nwp.col("a", "b").is_first_distinct().over(order_by="i"), "i") + .sort("i") + .drop("i") + ) + assert_equal_data(result, expected) + + +def test_is_last_distinct_order_by(data_indexed: Data, expected_invert: Data) -> None: + result = ( + dataframe(data_indexed) + .select(nwp.col("a", "b").is_last_distinct().over(order_by="i"), "i") + .sort("i") + .drop("i") + ) + assert_equal_data(result, expected_invert) + + +@XFAIL_PARTITIONED_ORDER_BY +def test_is_first_distinct_partitioned_order_by( + data_alt_1_indexed: Data, +) -> None: # pragma: no cover + expected = {"b": [True, True, True, True, False]} + result = ( + dataframe(data_alt_1_indexed) + .select(nwp.col("b").is_first_distinct().over("a", order_by="i"), "i") + .sort("i") + .drop("i") + ) + assert_equal_data(result, expected) + + +@XFAIL_PARTITIONED_ORDER_BY +def test_is_last_distinct_partitioned_order_by( + data_alt_1_indexed: Data, +) -> None: # pragma: no cover + expected = {"b": [True, True, False, True, True]} + result = ( + dataframe(data_alt_1_indexed) + .select(nwp.col("b").is_last_distinct().over("a", order_by="i"), "i") + .sort("i") + .drop("i") + ) + assert_equal_data(result, expected) + + +@XFAIL_PARTITIONED_ORDER_BY +def test_is_last_distinct_partitioned_order_by_nulls( + data_alt_2_indexed: Data, +) -> None: # pragma: no cover + expected = {"b": [True, True, False, True, True]} + result = ( + dataframe(data_alt_2_indexed) + .select(nwp.col("b").is_last_distinct().over("a", order_by="i"), "i") + .sort("i") + .drop("i") + ) + assert_equal_data(result, expected) diff --git a/tests/plan/join_test.py b/tests/plan/join_test.py new file mode 100644 index 0000000000..8ba63db9b9 --- /dev/null +++ b/tests/plan/join_test.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, TypedDict + +import pytest + +import narwhals._plan as nwp +from narwhals.exceptions import DuplicateError +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Sequence + + from typing_extensions import TypeAlias + + from narwhals.typing import JoinStrategy + from tests.conftest import Data + + On: TypeAlias = "str | Sequence[str] | None" + + +class Keywords(TypedDict, total=False): + """Arguments for `DataFrame.join`.""" + + on: On + how: JoinStrategy + left_on: On + right_on: On + suffix: str + + +@pytest.fixture +def data() -> Data: + return {"a": [1, 3, 2], "b": [4, 4, 6], "zor ro": [7.0, 8.0, 9.0]} + + +@pytest.fixture +def data_indexed(data: Data) -> Data: + return data | {"idx": [0, 1, 2]} + + +@pytest.fixture +def data_a_only(data: Data) -> Data: + return {"a": data["a"]} + + +LEFT_DATA_1 = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "age": [25, 30, 35]} +RIGHT_DATA_1 = { + "id": [2, 3, 4], + "department": ["HR", "Engineering", "Marketing"], + "salary": [50000, 60000, 70000], +} +EXPECTED_DATA_1 = { + "id": 
[1, 2, 3, None], + "name": ["Alice", "Bob", "Charlie", None], + "age": [25, 30, 35, None], + "id_right": [None, 2, 3, 4], + "department": [None, "HR", "Engineering", "Marketing"], + "salary": [None, 50000, 60000, 70000], +} + + +@pytest.mark.parametrize( + ("left_data", "right_data", "expected", "kwds"), + [ + ( + LEFT_DATA_1, + RIGHT_DATA_1, + EXPECTED_DATA_1, + Keywords(left_on=["id"], right_on=["id"]), + ), + (LEFT_DATA_1, RIGHT_DATA_1, EXPECTED_DATA_1, Keywords(on="id")), + ( + { + "id": [1, 2, 3, 4], + "year": [2020, 2021, 2022, 2023], + "value1": [100, 200, 300, 400], + }, + { + "id": [2, 3, 4, 5], + "year_foo": [2021, 2022, 2023, 2024], + "value2": [500, 600, 700, 800], + }, + { + "id": [1, 2, 3, 4, None], + "year": [2020, 2021, 2022, 2023, None], + "value1": [100, 200, 300, 400, None], + "id_right": [None, 2, 3, 4, 5], + # since year is different, don't apply suffix + "year_foo": [None, 2021, 2022, 2023, 2024], + "value2": [None, 500, 600, 700, 800], + }, + Keywords(left_on=["id", "year"], right_on=["id", "year_foo"]), + ), + ], + ids=["left_on-right_on-identical", "on", "left_on-right_on-different"], +) +def test_join_full( + left_data: Data, right_data: Data, expected: Data, kwds: Keywords +) -> None: + kwds["how"] = "full" + result = ( + dataframe(left_data) + .join(dataframe(right_data), **kwds) + .sort("id", nulls_last=True) + ) + assert_equal_data(result, expected) + + +def test_join_full_duplicate() -> None: + left = dataframe({"f": [1, 2, 3], "v": [1, 2, 3]}) + right = left.rename({"v": "f_right"}) + with pytest.raises(DuplicateError): + left.join(right, "f", how="full", suffix="_right") + + +def test_join_inner_x2_duplicate(data_indexed: Data) -> None: + df = dataframe(data_indexed) + with pytest.raises(DuplicateError): + df.join(df, "a").join(df, "a") + + +@pytest.mark.parametrize("kwds", [Keywords(left_on="a", right_on="a"), Keywords(on="a")]) +def test_join_inner_single_key(data_indexed: Data, kwds: Keywords) -> None: + df = dataframe(data_indexed) + result = df.join(df, **kwds).sort("idx").drop("idx_right") + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "zor ro": [7.0, 8.0, 9.0], + "idx": [0, 1, 2], + "b_right": [4, 4, 6], + "zor ro_right": [7.0, 8.0, 9.0], + } + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + "kwds", [Keywords(left_on=["a", "b"], right_on=["a", "b"]), Keywords(on=["a", "b"])] +) +def test_join_inner_two_keys(data_indexed: Data, kwds: Keywords) -> None: + df = dataframe(data_indexed) + result = df.join(df, **kwds).sort("idx").drop("idx_right") + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "zor ro": [7.0, 8.0, 9.0], + "idx": [0, 1, 2], + "zor ro_right": [7.0, 8.0, 9.0], + } + assert_equal_data(result, expected) + + +def test_join_left() -> None: + data_left = {"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0], "idx": [0.0, 1.0, 2.0]} + data_right = {"a": [1.0, 2.0, 3.0], "co": [4.0, 5.0, 7.0], "idx": [0.0, 1.0, 2.0]} + df_left = dataframe(data_left) + df_right = dataframe(data_right) + result = ( + df_left.join(df_right, left_on="b", right_on="co", how="left") + .sort("idx") + .drop("idx_right") + ) + expected = {"a": [1, 2, 3], "b": [4, 5, 6], "idx": [0, 1, 2], "a_right": [1, 2, None]} + result_on_list = df_left.join(df_right, ["a", "idx"], how="left").sort("idx") + expected_on_list = {"a": [1, 2, 3], "b": [4, 5, 6], "idx": [0, 1, 2], "co": [4, 5, 7]} + assert_equal_data(result, expected) + assert_equal_data(result_on_list, expected_on_list) + + +def test_join_left_multiple_column() -> None: + df = dataframe({"a": [1, 2, 3], 
"b": [4, 5, 6], "idx": [0, 1, 2]}) + right = df.rename({"b": "c"}) + result = ( + df.join(right, left_on=["a", "b"], right_on=["a", "c"], how="left") + .sort("idx") + .drop("idx_right") + ) + expected = {"a": [1, 2, 3], "b": [4, 5, 6], "idx": [0, 1, 2]} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("kwds", "expected"), + [ + ( + Keywords(left_on="b", right_on="c"), + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "d": [1, 4, 2], + "idx": [0, 1, 2], + "a_right": [1, 2, 3], + "d_right": [1, 4, 2], + }, + ), + ( + Keywords(left_on="a", right_on="d"), + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "d": [1, 4, 2], + "idx": [0, 1, 2], + "a_right": [1.0, 3.0, None], + "c": [4.0, 6.0, None], + }, + ), + ], +) +def test_join_left_overlapping_column(kwds: Keywords, expected: dict[str, Any]) -> None: + kwds["how"] = "left" + source = { + "a": [1.0, 2.0, 3.0], + "b": [4.0, 5.0, 6.0], + "d": [1.0, 4.0, 2.0], + "idx": [0.0, 1.0, 2.0], + } + df = dataframe(source) + right = df.rename({"b": "c"}) + result = df.join(right, **kwds).sort("idx").drop("idx_right") + assert_equal_data(result, expected) + + +def test_join_cross(data_a_only: Data) -> None: + df = dataframe(data_a_only) + result = df.join(df, how="cross").sort("a", "a_right") + expected = {"a": [1, 1, 1, 2, 2, 2, 3, 3, 3], "a_right": [1, 2, 3, 1, 2, 3, 1, 2, 3]} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize("how", ["inner", "left"]) +@pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) +def test_join_with_suffix(how: JoinStrategy, suffix: str, data: Data) -> None: + df = dataframe(data) + on = ["a", "b"] + result = df.join(df, left_on=on, right_on=on, how=how, suffix=suffix) + assert result.schema.names() == ["a", "b", "zor ro", f"zor ro{suffix}"] + + +@pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) +def test_join_cross_with_suffix(suffix: str, data_a_only: Data) -> None: + df = dataframe(data_a_only) + result = df.join(df, how="cross", suffix=suffix).sort("a", f"a{suffix}") + expected = { + "a": [1, 1, 1, 2, 2, 2, 3, 3, 3], + f"a{suffix}": [1, 2, 3, 1, 2, 3, 1, 2, 3], + } + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("on", "predicate", "expected"), + [ + (["a", "b"], (nwp.col("b") < 5), {"a": [1, 3], "b": [4, 4], "zor ro": [7, 8]}), + (["b"], (nwp.col("b") < 5), {"a": [1, 3], "b": [4, 4], "zor ro": [7, 8]}), + (["b"], (nwp.col("b") > 5), {"a": [2], "b": [6], "zor ro": [9]}), + ], +) +@pytest.mark.parametrize("how", ["anti", "semi"]) +def test_join_filter( + on: str | Sequence[str], + predicate: nwp.Expr, + how: Literal["anti", "semi"], + expected: Data, + data: Data, +) -> None: + # NOTE: "anti" and "semi" should be the inverse of each other + df = dataframe(data) + other = df.filter(predicate if how == "semi" else ~predicate) + result = df.join(other, on, how=how).sort(on) + assert_equal_data(result, expected) + + +EITHER_LR_OR_ON = r"`left_on` and `right_on`.+or.+`on`" +ONLY_ON = r"`on` is specified.+`left_on` and `right_on`.+be.+None" +SAME_LENGTH = r"`left_on` and `right_on`.+same length" + + +@pytest.mark.parametrize( + ("kwds", "message"), + [ + (Keywords(), EITHER_LR_OR_ON), + (Keywords(left_on="a"), EITHER_LR_OR_ON), + (Keywords(right_on="a"), EITHER_LR_OR_ON), + (Keywords(on="a", right_on="a"), ONLY_ON), + (Keywords(left_on=["a", "b"], right_on="a"), SAME_LENGTH), + ], +) +@pytest.mark.parametrize("how", ["inner", "left", "semi", "anti"]) +def test_join_keys_exceptions( + how: JoinStrategy, kwds: Keywords, message: str, data: Data +) -> None: + df 
= dataframe(data) + kwds["how"] = how + with pytest.raises(ValueError, match=message): + df.join(df, **kwds) + + +@pytest.mark.parametrize( + "kwds", + [ + Keywords(left_on="a"), + Keywords(on="a"), + Keywords(right_on="a"), + Keywords(left_on="a", right_on="a"), + ], +) +def test_join_cross_keys_exceptions(kwds: Keywords, data_a_only: Data) -> None: + df = dataframe(data_a_only) + kwds["how"] = "cross" + with pytest.raises(ValueError, match=r"not.+ `left_on`.+`right_on`.+`on`.+cross"): + df.join(df, **kwds) + + +def test_join_not_implemented(data_a_only: Data) -> None: + df = dataframe(data_a_only) + pattern = ( + r"supported.+'inner', 'left', 'full', 'cross', 'semi', 'anti'.+ found 'right'" + ) + with pytest.raises(NotImplementedError, match=(pattern)): + df.join(df, left_on="a", right_on="a", how="right") # type: ignore[arg-type] diff --git a/tests/plan/utils.py b/tests/plan/utils.py index d1ae2ce95e..8f857b694c 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -1,11 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any + +import pytest from narwhals import _plan as nwp from narwhals._plan import expressions as ir +from tests.utils import assert_equal_data as _assert_equal_data + +pytest.importorskip("pyarrow") + +import pyarrow as pa if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from typing_extensions import LiteralString @@ -51,3 +60,17 @@ def assert_expr_ir_equal( def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: """Helper constructor for test compare.""" return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) + + +def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[pa.Table, pa.ChunkedArray[Any]]: + return nwp.DataFrame.from_native(pa.table(data)) + + +def series(values: Iterable[Any], /) -> nwp.Series[pa.ChunkedArray[Any]]: + return nwp.Series.from_native(pa.chunked_array([values])) + + +def assert_equal_data( + result: nwp.DataFrame[Any, Any], expected: Mapping[str, Any] +) -> None: + _assert_equal_data(result.to_dict(as_series=False), expected) From 1a433a991be9740a7bbd217ec57f064159a8cb9e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 12 Oct 2025 16:20:29 +0000 Subject: [PATCH 364/368] fix: Update alias import --- narwhals/_plan/arrow/functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 32255d37ec..53b56d19b8 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -24,8 +24,7 @@ from typing_extensions import TypeAlias, TypeIs - from narwhals._arrow.dataframe import PromoteOptions - from narwhals._arrow.typing import Incomplete + from narwhals._arrow.typing import Incomplete, PromoteOptions from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.arrow.typing import ( Array, From bf544d4337678908f6a602c0bdc4342bd68a4f04 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 13 Oct 2025 15:59:52 +0000 Subject: [PATCH 365/368] fix(expr-ir): Ensure only `__slots__`, and not `__dict__` too (#3201) --- narwhals/_plan/_immutable.py | 146 ++++++++++------------ narwhals/_plan/_meta.py | 91 ++++++++++++++ narwhals/_plan/expr.py | 8 +- narwhals/_plan/expressions/__init__.py | 10 ++ narwhals/_plan/expressions/aggregation.py | 6 +- narwhals/_plan/expressions/expr.py | 2 +- 
narwhals/_plan/expressions/functions.py | 36 ++++-- narwhals/_plan/options.py | 4 +- narwhals/_plan/schema.py | 11 +- tests/plan/expr_parsing_test.py | 61 +++++++++ tests/plan/immutable_test.py | 71 ++++++++++- 11 files changed, 331 insertions(+), 115 deletions(-) create mode 100644 narwhals/_plan/_meta.py diff --git a/narwhals/_plan/_immutable.py b/narwhals/_plan/_immutable.py index 0abe0739b6..09cc7fc90a 100644 --- a/narwhals/_plan/_immutable.py +++ b/narwhals/_plan/_immutable.py @@ -1,44 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, TypeVar +from typing import TYPE_CHECKING + +# ruff: noqa: N806 +from narwhals._plan._meta import ImmutableMeta if TYPE_CHECKING: - from collections.abc import Iterator - from typing import Any, Callable - - from typing_extensions import Never, Self, dataclass_transform - -else: - # https://docs.python.org/3/library/typing.html#typing.dataclass_transform - def dataclass_transform( - *, - eq_default: bool = True, - order_default: bool = False, - kw_only_default: bool = False, - frozen_default: bool = False, - field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), - **kwargs: Any, - ) -> Callable[[T], T]: - def decorator(cls_or_fn: T) -> T: - cls_or_fn.__dataclass_transform__ = { - "eq_default": eq_default, - "order_default": order_default, - "kw_only_default": kw_only_default, - "frozen_default": frozen_default, - "field_specifiers": field_specifiers, - "kwargs": kwargs, - } - return cls_or_fn - - return decorator - - -T = TypeVar("T") -_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" - - -@dataclass_transform(kw_only_default=True, frozen_default=True) -class Immutable: + from collections.abc import Iterable, Iterator + from typing import Any, ClassVar, Final + + from typing_extensions import Never, Self + + +_HASH_NAME: Final = "__immutable_hash_value__" + + +class Immutable(metaclass=ImmutableMeta): """A poor man's frozen dataclass. - Keyword-only constructor (IDE supported) @@ -49,40 +26,43 @@ class Immutable: [`copy.replace`]: https://docs.python.org/3.13/library/copy.html#copy.replace """ - __slots__ = (_IMMUTABLE_HASH_NAME,) - __immutable_hash_value__: int + __slots__ = (_HASH_NAME,) + if not TYPE_CHECKING: + # NOTE: Trying to avoid this being added to synthesized `__init__` + # Seems to be the only difference when decorating the metaclass + __immutable_hash_value__: int - @property - def __immutable_keys__(self) -> Iterator[str]: - slots: tuple[str, ...] 
= self.__slots__ - for name in slots: - if name != _IMMUTABLE_HASH_NAME: - yield name + __immutable_keys__: ClassVar[tuple[str, ...]] @property def __immutable_values__(self) -> Iterator[Any]: + """Override to configure hash seed.""" + getattr_ = getattr for name in self.__immutable_keys__: - yield getattr(self, name) + yield getattr_(self, name) @property def __immutable_items__(self) -> Iterator[tuple[str, Any]]: + getattr_ = getattr for name in self.__immutable_keys__: - yield name, getattr(self, name) + yield name, getattr_(self, name) @property def __immutable_hash__(self) -> int: - if hasattr(self, _IMMUTABLE_HASH_NAME): - return self.__immutable_hash_value__ - hash_value = hash((self.__class__, *self.__immutable_values__)) - object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) - return self.__immutable_hash_value__ + HASH = _HASH_NAME + if hasattr(self, HASH): + hash_value: int = getattr(self, HASH) + else: + hash_value = hash((self.__class__, *self.__immutable_values__)) + object.__setattr__(self, HASH, hash_value) + return hash_value def __setattr__(self, name: str, value: Never) -> Never: msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." raise AttributeError(msg) def __replace__(self, **changes: Any) -> Self: - """https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415 + """https://docs.python.org/3.13/library/copy.html#copy.replace.""" if len(changes) == 1: # The most common case is a single field replacement. # Iff that field happens to be equal, we can noop, preserving the current object's hash. @@ -96,13 +76,6 @@ def __replace__(self, **changes: Any) -> Self: changes[name] = value_current return type(self)(**changes) - def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: - super().__init_subclass__(*args, **kwds) - if cls.__slots__: - ... - else: - cls.__slots__ = () - def __hash__(self) -> int: return self.__immutable_hash__ @@ -111,8 +84,9 @@ def __eq__(self, other: object) -> bool: return True if type(self) is not type(other): return False + getattr_ = getattr return all( - getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ + getattr_(self, key) == getattr_(other, key) for key in self.__immutable_keys__ ) def __str__(self) -> str: @@ -120,26 +94,16 @@ def __str__(self) -> str: return f"{type(self).__name__}({fields})" def __init__(self, **kwds: Any) -> None: - required: set[str] = set(self.__immutable_keys__) - if not required and not kwds: - # NOTE: Fastpath for empty slots - ... 
- elif required == set(kwds): - for name, value in kwds.items(): - object.__setattr__(self, name, value) - elif missing := required.difference(kwds): - msg = ( - f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n" - f"but missing values for {sorted(missing)!r}" - ) - raise TypeError(msg) - else: - extra = set(kwds).difference(required) - msg = ( - f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n" - f"but got unknown arguments {sorted(extra)!r}" - ) - raise TypeError(msg) + if (keys := self.__immutable_keys__) or kwds: + required = set(keys) + if required == kwds.keys(): + object__setattr__ = object.__setattr__ + for name, value in kwds.items(): + object__setattr__(self, name, value) + elif missing := required.difference(kwds): + raise _init_missing_error(self, required, missing) + else: + raise _init_extra_error(self, required, set(kwds).difference(required)) def _field_str(name: str, value: Any) -> str: @@ -149,3 +113,23 @@ def _field_str(name: str, value: Any) -> str: if isinstance(value, str): return f"{name}={value!r}" return f"{name}={value}" + + +def _init_missing_error( + obj: object, required: Iterable[str], missing: Iterable[str] +) -> TypeError: + msg = ( + f"{type(obj).__name__!r} requires attributes {sorted(required)!r}, \n" + f"but missing values for {sorted(missing)!r}" + ) + return TypeError(msg) + + +def _init_extra_error( + obj: object, required: Iterable[str], extra: Iterable[str] +) -> TypeError: + msg = ( + f"{type(obj).__name__!r} only supports attributes {sorted(required)!r}, \n" + f"but got unknown arguments {sorted(extra)!r}" + ) + return TypeError(msg) diff --git a/narwhals/_plan/_meta.py b/narwhals/_plan/_meta.py new file mode 100644 index 0000000000..9a0e7d10ec --- /dev/null +++ b/narwhals/_plan/_meta.py @@ -0,0 +1,91 @@ +"""Metaclasses and other unholy metaprogramming nonsense.""" + +from __future__ import annotations + +# ruff: noqa: N806 +from itertools import chain +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any, Final, TypeVar + + import _typeshed + from typing_extensions import dataclass_transform + + from narwhals._plan.typing import Seq + + T = TypeVar("T") + +else: + # https://docs.python.org/3/library/typing.html#typing.dataclass_transform + def dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + frozen_default: bool = False, + field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), + **kwargs: Any, + ) -> Callable[[T], T]: + def decorator(cls_or_fn: T) -> T: + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + + return decorator + + +__all__ = ["ImmutableMeta", "SlottedMeta", "dataclass_transform"] + +flatten = chain.from_iterable +_KEYS_NAME: Final = "__immutable_keys__" +_HASH_NAME: Final = "__immutable_hash_value__" + + +class SlottedMeta(type): + """Ensure [`__slots__`] are always defined to prevent `__dict__` creation. 
+ + [`__slots__`]: https://docs.python.org/3/reference/datamodel.html#object.__slots__ + """ + + # https://github.com/python/typeshed/blob/776508741d76b58f9dcb2aaf42f7d4596a48d580/stdlib/abc.pyi#L13-L19 + # https://github.com/python/typeshed/blob/776508741d76b58f9dcb2aaf42f7d4596a48d580/stdlib/_typeshed/__init__.pyi#L36-L40 + # https://github.com/astral-sh/ruff/issues/8353#issuecomment-1786238311 + # https://docs.python.org/3/reference/datamodel.html#creating-the-class-object + def __new__( + metacls: type[_typeshed.Self], + cls_name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + /, + **kwds: Any, + ) -> _typeshed.Self: + namespace.setdefault("__slots__", ()) + return super().__new__(metacls, cls_name, bases, namespace, **kwds) # type: ignore[no-any-return, misc] + + +@dataclass_transform(kw_only_default=True, frozen_default=True) +class ImmutableMeta(SlottedMeta): + def __new__( + metacls: type[_typeshed.Self], + cls_name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + /, + **kwds: Any, + ) -> _typeshed.Self: + KEYS, HASH = _KEYS_NAME, _HASH_NAME + getattr_: Callable[..., Seq[str]] = getattr + it_bases = (getattr_(b, KEYS, ()) for b in bases) + it_all = chain( + flatten(it_bases), namespace.get(KEYS, namespace.get("__slots__", ())) + ) + namespace[KEYS] = tuple(key for key in it_all if key != HASH) + return super().__new__(metacls, cls_name, bases, namespace, **kwds) # type: ignore[no-any-return, misc] diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index a56f58c7f6..f1dea8ac80 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -194,13 +194,11 @@ def hist( if bin_count is not None: msg = "can only provide one of `bin_count` or `bins`" raise ComputeError(msg) - node = F.HistBins(bins=tuple(bins), include_breakpoint=include_breakpoint) + node = F.Hist.from_bins(bins, include_breakpoint=include_breakpoint) elif bin_count is not None: - node = F.HistBinCount( - bin_count=bin_count, include_breakpoint=include_breakpoint - ) + node = F.Hist.from_bin_count(bin_count, include_breakpoint=include_breakpoint) else: - node = F.HistBinCount(include_breakpoint=include_breakpoint) + node = F.Hist.from_bin_count(include_breakpoint=include_breakpoint) return self._with_unary(node) def log(self, base: float = math.e) -> Self: diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 5e5bef9362..1bdae0224f 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -10,9 +10,14 @@ from narwhals._plan.expressions import ( aggregation, boolean, + categorical, functions, + lists, operators, selectors, + strings, + struct, + temporal, ) from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr, max, min from narwhals._plan.expressions.expr import ( @@ -85,10 +90,12 @@ "_ColumnSelection", "aggregation", "boolean", + "categorical", "col", "cols", "functions", "index_columns", + "lists", "max", "min", "named_ir", @@ -97,4 +104,7 @@ "over", "over_ordered", "selectors", + "strings", + "struct", + "temporal", ] diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 129da889aa..3b9ae41d11 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -50,18 +50,18 @@ class ArgMin(OrderableAggExpr): ... class ArgMax(OrderableAggExpr): ... 
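# Context for the hunk below: `ImmutableMeta` merges `__immutable_keys__` from
# all bases, so subclasses now declare only the slots they add themselves,
# e.g. `__slots__ = ("ddof",)` rather than `(*AggExpr.__slots__, "ddof")`.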
# fmt: on class Quantile(AggExpr): - __slots__ = (*AggExpr.__slots__, "interpolation", "quantile") + __slots__ = ("interpolation", "quantile") quantile: float interpolation: RollingInterpolationMethod class Std(AggExpr): - __slots__ = (*AggExpr.__slots__, "ddof") + __slots__ = ("ddof",) ddof: int class Var(AggExpr): - __slots__ = (*AggExpr.__slots__, "ddof") + __slots__ = ("ddof",) ddof: int diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index ec5ffa36a1..42d00bdbd7 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -407,7 +407,7 @@ class OrderedWindowExpr( child=("expr", "partition_by", "order_by"), config=ExprIROptions.renamed("over_ordered"), ): - __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 + __slots__ = ("order_by", "sort_options") expr: ExprIR partition_by: Seq[ExprIR] order_by: Seq[ExprIR] diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index f7ff80abe5..8e910113a1 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -9,8 +9,10 @@ from narwhals._plan.options import FunctionFlags, FunctionOptions if TYPE_CHECKING: + from collections.abc import Iterable from typing import Any + from _typeshed import ConvertibleToInt from typing_extensions import Self from narwhals._plan._expr_ir import ExprIR @@ -71,27 +73,35 @@ class Hist(Function): def __repr__(self) -> str: return "hist" - -class HistBins(Hist): - __slots__ = ("bins", *Hist.__slots__) - bins: Seq[float] - - def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None: + # NOTE: These constructors provide validation + defaults, and avoid + # repeating on every `__init__` afterwards + # They're also more widely defined to what will work at runtime + @staticmethod + def from_bins( + bins: Iterable[float], /, *, include_breakpoint: bool = True + ) -> HistBins: + bins = tuple(bins) for i in range(1, len(bins)): if bins[i - 1] >= bins[i]: raise hist_bins_monotonic_error(bins) - object.__setattr__(self, "bins", bins) - object.__setattr__(self, "include_breakpoint", include_breakpoint) + return HistBins(bins=bins, include_breakpoint=include_breakpoint) + + @staticmethod + def from_bin_count( + count: ConvertibleToInt = 10, /, *, include_breakpoint: bool = True + ) -> HistBinCount: + return HistBinCount(bin_count=int(count), include_breakpoint=include_breakpoint) + + +class HistBins(Hist): + __slots__ = ("bins",) + bins: Seq[float] class HistBinCount(Hist): - __slots__ = ("bin_count", *Hist.__slots__) + __slots__ = ("bin_count",) bin_count: int - def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> None: - object.__setattr__(self, "bin_count", bin_count) - object.__setattr__(self, "include_breakpoint", include_breakpoint) - class Log(Function, options=FunctionOptions.elementwise): __slots__ = ("base",) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index ca0d101676..f8ee347cac 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -312,7 +312,7 @@ def namespaced(cls, override_name: str = "", /) -> Self: class ExprIROptions(_BaseIROptions): - __slots__ = (*_BaseIROptions.__slots__, "allow_dispatch") + __slots__ = ("allow_dispatch",) allow_dispatch: bool @classmethod @@ -325,7 +325,7 @@ def no_dispatch() -> ExprIROptions: class FunctionExprOptions(_BaseIROptions): - __slots__ = (*_BaseIROptions.__slots__, "accessor_name") + __slots__ = 
("accessor_name",) accessor_name: Accessor | None """Namespace accessor name, if any.""" diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 67433db06b..10c6665d08 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload from narwhals._plan._expr_ir import NamedIR -from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable +from narwhals._plan._immutable import Immutable from narwhals._utils import _hasattr_static from narwhals.dtypes import Unknown @@ -86,12 +86,9 @@ def with_columns_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: return tuple(chain(it, named.values())) @property - def __immutable_hash__(self) -> int: - if hasattr(self, _IMMUTABLE_HASH_NAME): - return self.__immutable_hash_value__ - hash_value = hash((self.__class__, *tuple(self._mapping.items()))) - object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) - return self.__immutable_hash_value__ + def __immutable_values__(self) -> Iterator[Any]: + # Repurposed `self._mapping.items()` as a hash seed + yield from tuple(self.items()) @property def names(self) -> FrozenColumns: diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 5b525001a9..e556352268 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -16,6 +16,7 @@ from narwhals._plan.expressions import functions as F, operators as ops from narwhals._plan.expressions.literal import SeriesLiteral from narwhals.exceptions import ( + ComputeError, InvalidIntoExprError, InvalidOperationError, InvalidOperationError as LengthChangingExprError, @@ -468,3 +469,63 @@ def test_operators_left_right( assert isinstance(ir_2, ir.FunctionExpr) assert isinstance(ir_2.function, op) assert tuple(reversed(ir_2.input)) == ir_1.input + + +def test_hist_bins() -> None: + bins_values = (0, 1.5, 3.0, 4.5, 6.0) + a = nwp.col("a") + hist_1 = a.hist(deque(bins_values), include_breakpoint=False) + hist_2 = a.hist(list(bins_values), include_breakpoint=False) + + ir_1 = hist_1._ir + ir_2 = hist_2._ir + assert isinstance(ir_1, ir.FunctionExpr) + assert isinstance(ir_2, ir.FunctionExpr) + assert isinstance(ir_1.function, F.HistBins) + assert isinstance(ir_2.function, F.HistBins) + assert ir_1.function.include_breakpoint is False + assert_expr_ir_equal(ir_1, ir_2) + + +def test_hist_bin_count() -> None: + bin_count_default = 10 + include_breakpoint_default = True + a = nwp.col("a") + hist_1 = a.hist( + bin_count=bin_count_default, include_breakpoint=include_breakpoint_default + ) + hist_2 = a.hist() + hist_3 = a.hist(bin_count=5) + hist_4 = a.hist(include_breakpoint=False) + + ir_1 = hist_1._ir + ir_2 = hist_2._ir + ir_3 = hist_3._ir + ir_4 = hist_4._ir + assert isinstance(ir_1, ir.FunctionExpr) + assert isinstance(ir_2, ir.FunctionExpr) + assert isinstance(ir_3, ir.FunctionExpr) + assert isinstance(ir_4, ir.FunctionExpr) + assert isinstance(ir_1.function, F.HistBinCount) + assert isinstance(ir_2.function, F.HistBinCount) + assert isinstance(ir_3.function, F.HistBinCount) + assert isinstance(ir_4.function, F.HistBinCount) + assert ir_1.function.include_breakpoint is include_breakpoint_default + assert ir_2.function.bin_count == bin_count_default + assert_expr_ir_equal(ir_1, ir_2) + assert ir_3.function.include_breakpoint != ir_4.function.include_breakpoint + assert ir_4.function.bin_count != ir_3.function.bin_count + assert ir_4 != ir_2 + assert ir_3 != ir_1 + + +def test_hist_invalid() -> None: + a = nwp.col("a") + 
with pytest.raises(ComputeError, match=r"bin_count.+or.+bins"): + a.hist(bins=[1], bin_count=1) + with pytest.raises(ComputeError, match=r"bins.+monotonic"): + a.hist([1, 5, 4]) + with pytest.raises(ComputeError, match=r"bins.+monotonic"): + a.hist(deque((3, 2, 1))) + with pytest.raises(TypeError): + a.hist(1) # type: ignore[arg-type] diff --git a/tests/plan/immutable_test.py b/tests/plan/immutable_test.py index 6f9d0450ad..8e60759d0b 100644 --- a/tests/plan/immutable_test.py +++ b/tests/plan/immutable_test.py @@ -1,13 +1,19 @@ from __future__ import annotations +import re import string from itertools import repeat -from typing import Any +from typing import TYPE_CHECKING, Any, TypeVar import pytest +from narwhals._plan import when_then from narwhals._plan._immutable import Immutable +if TYPE_CHECKING: + from collections.abc import Iterator +T_co = TypeVar("T_co", covariant=True) + class Empty(Immutable): ... @@ -141,9 +147,68 @@ def test_immutable_hash_cache() -> None: obj = TwoSlot(a=int_long, b=str_long) with pytest.raises(AttributeError): - uncached = obj.__immutable_hash_value__ # noqa: F841 + _ = getattr(obj, "__immutable_hash_value__") # noqa: B009 hash_cache_miss = hash(obj) - cached = obj.__immutable_hash_value__ + cached = getattr(obj, "__immutable_hash_value__") # noqa: B009 hash_cache_hit = hash(obj) assert hash_cache_miss == cached == hash_cache_hit + + +def _collect_immutable_descendants() -> list[type[Immutable]]: + # NOTE: Will populate `__subclasses__` by bringing the defs into scope + from narwhals._plan import ( + _expansion, + _expr_ir, + _function, + expressions, + options, + schema, + when_then, + ) + + _ = expressions, schema, options, _expansion, _expr_ir, _function, when_then + return sorted(set(_iter_descendants(Immutable)), key=repr) + + +def _iter_descendants(*bases: type[T_co]) -> Iterator[type[T_co]]: + seen = set[T_co]() + for base in bases: + yield base + if (children := (base.__subclasses__())) and ( + unseen := set(children).difference(seen) + ): + yield from _iter_descendants(*unseen) + + +ALLOW_DICT_TO_AVOID_MULTIPLE_BASES_HAVE_INSTANCE_LAYOUT_CONFLICT_ERROR = frozenset( + (when_then.Then, when_then.ChainedThen) +) + + +@pytest.fixture(params=_collect_immutable_descendants(), ids=lambda tp: tp.__name__) +def immutable_type(request: pytest.FixtureRequest) -> type[Immutable]: + tp: type[Immutable] = request.param + request.applymarker( + pytest.mark.xfail( + tp in ALLOW_DICT_TO_AVOID_MULTIPLE_BASES_HAVE_INSTANCE_LAYOUT_CONFLICT_ERROR, + reason="Multiple inheritance + `__slots__` = bad", + ) + ) + return tp + + +def test_immutable___slots___(immutable_type: type[Immutable]) -> None: + featureless_instance = object.__new__(immutable_type) + + # NOTE: If this fails, `__setattr__` has been overridden + with pytest.raises(AttributeError, match=r"immutable"): + featureless_instance.i_dont_exist = 999 # type: ignore[assignment] + + # NOTE: If this fails, `__slots__` lose the size benefit + with pytest.raises(AttributeError, match=re.escape("has no attribute '__dict__'")): + _ = featureless_instance.__dict__ + + slots = immutable_type.__slots__ + if slots: + assert len(slots) != 0, slots From fa0899ad03aa01802f433b3efc80d256a083ee5a Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 15 Oct 2025 18:01:04 +0000 Subject: [PATCH 366/368] perf(expr-ir): Optimize `ExpansionFlags.from_ir` (#3206) --- narwhals/_plan/_expansion.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git 
a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index 6cbf061a98..20a886c3a3 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -130,16 +130,17 @@ def from_ir(ir: ExprIR, /) -> ExpansionFlags: has_selector: bool = False has_exclude: bool = False for e in ir.iter_left(): - if isinstance(e, (Columns, IndexColumns)): - multiple_columns = True - elif isinstance(e, Nth): - has_nth = True - elif isinstance(e, All): - has_wildcard = True - elif isinstance(e, SelectorIR): - has_selector = True - elif isinstance(e, Exclude): - has_exclude = True + if isinstance(e, (_ColumnSelection, SelectorIR)): + if isinstance(e, (Columns, IndexColumns)): + multiple_columns = True + elif isinstance(e, Nth): + has_nth = True + elif isinstance(e, All): + has_wildcard = True + elif isinstance(e, SelectorIR): + has_selector = True + elif isinstance(e, Exclude): + has_exclude = True return ExpansionFlags( multiple_columns=multiple_columns, has_nth=has_nth, From 34a1fd00d417cc63a04e9bbd73030896d34536e8 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 18 Oct 2025 15:53:18 +0000 Subject: [PATCH 367/368] refactor(expr-ir): Improve function dispatch (#3215) --- .pre-commit-config.yaml | 2 + narwhals/_plan/_dispatch.py | 188 ++++++++++++++++++++++ narwhals/_plan/_expr_ir.py | 44 ++--- narwhals/_plan/_function.py | 28 +--- narwhals/_plan/arrow/group_by.py | 9 +- narwhals/_plan/common.py | 45 +----- narwhals/_plan/compliant/expr.py | 3 + narwhals/_plan/compliant/scalar.py | 6 + narwhals/_plan/expressions/aggregation.py | 3 +- narwhals/_plan/expressions/expr.py | 8 +- narwhals/_plan/options.py | 22 ++- tests/plan/compliant_test.py | 3 +- tests/plan/dispatch_test.py | 105 ++++++++++++ 13 files changed, 342 insertions(+), 124 deletions(-) create mode 100644 narwhals/_plan/_dispatch.py create mode 100644 tests/plan/dispatch_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d342428700..961e48e892 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,6 +79,8 @@ repos: entry: "self: Self" language: pygrep files: ^narwhals/ + # mypy needs `Self` for `ExprIR.dispatch` + exclude: ^narwhals/_plan/.*\.py - id: dtypes-import name: don't import from narwhals.dtypes (use `Version.dtypes` instead) entry: | diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py new file mode 100644 index 0000000000..044b786e28 --- /dev/null +++ b/narwhals/_plan/_dispatch.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import re +from collections.abc import Callable +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, final, overload + +from narwhals._plan._guards import is_function_expr +from narwhals._plan.compliant.typing import FrameT_contra, R_co +from narwhals._typing_compat import TypeVar + +if TYPE_CHECKING: + from typing_extensions import Never, TypeAlias + + from narwhals._plan.compliant.typing import Ctx + from narwhals._plan.expressions import ExprIR, FunctionExpr + from narwhals._plan.typing import ExprIRT, FunctionT + +__all__ = ["Dispatcher", "get_dispatch_name"] + + +Node = TypeVar("Node", bound="ExprIR | FunctionExpr[Any]") +Node_contra = TypeVar( + "Node_contra", bound="ExprIR | FunctionExpr[Any]", contravariant=True +) +Raiser: TypeAlias = Callable[..., "Never"] + + +class Binder(Protocol[Node_contra]): + def __call__( + self, ctx: Ctx[FrameT_contra, R_co], / + ) -> BoundMethod[Node_contra, FrameT_contra, R_co]: ... 
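# Roughly how `Binder` and `BoundMethod` (below) fit together: a `Binder`
# curries the compliant context, and the resulting `BoundMethod` is what
# `Dispatcher` calls per node. Illustrative only (assumes a `binary_expr`
# method exists on the compliant ctx):
#
#     bind: Binder[ExprIR] = attrgetter("binary_expr")
#     method = bind(compliant_expr)      # BoundMethod
#     result = method(node, frame, name)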
+ + +class BoundMethod(Protocol[Node_contra, FrameT_contra, R_co]): + def __call__(self, node: Node_contra, frame: FrameT_contra, name: str, /) -> R_co: ... + + +@final +class Dispatcher(Generic[Node]): + """Translate class definitions into error-wrapped method calls. + + Operates over `ExprIR` and `Function` nodes. + + By default, we dispatch to the compliant-level by calling a method that is the + **snake_case**-equivalent of the class name: + + class BinaryExpr(ExprIR): ... + + class CompliantExpr(Protocol): + def binary_expr(self, *args: Any): ... + """ + + __slots__ = ("_bind", "_name") + _bind: Binder[Node] + _name: str + + @property + def name(self) -> str: + return self._name + + def __repr__(self) -> str: + return f"{type(self).__name__}<{self.name}>" + + def bind( + self, ctx: Ctx[FrameT_contra, R_co], / + ) -> BoundMethod[Node, FrameT_contra, R_co]: + """Retrieve the implementation of this expression from `ctx`. + + Binds an instance method, most commonly via: + + expr: CompliantExpr + method = getattr(expr, "method_name") + """ + try: + return self._bind(ctx) + except AttributeError: + raise self._not_implemented_error(ctx, "compliant") from None + + def __call__( + self, + node: Node, + ctx: Ctx[FrameT_contra, R_co], + frame: FrameT_contra, + name: str, + /, + ) -> R_co: + """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.""" + method = self.bind(ctx) + if result := method(node, frame, name): + return result + raise self._not_implemented_error(ctx, "context") + + @staticmethod + def from_expr_ir(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: + if not tp.__expr_ir_config__.allow_dispatch: + return Dispatcher._no_dispatch(tp) + return Dispatcher._from_type(tp) + + @staticmethod + def from_function(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: + return Dispatcher._from_type(tp) + + @staticmethod + @overload + def _from_type(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: ... + @staticmethod + @overload + def _from_type(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: ... 
+ @staticmethod + def _from_type(tp: type[ExprIRT | FunctionT], /) -> Dispatcher[Any]: + obj = Dispatcher.__new__(Dispatcher) + obj._name = _method_name(tp) + getter = attrgetter(obj._name) + is_namespaced = tp.__expr_ir_config__.is_namespaced + obj._bind = _via_namespace(getter) if is_namespaced else getter + return obj + + @staticmethod + def _no_dispatch(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: + obj = Dispatcher.__new__(Dispatcher) + obj._name = tp.__name__ + obj._bind = obj._make_no_dispatch_error() + return obj + + def _make_no_dispatch_error(self) -> Callable[[Any], Raiser]: + def _no_dispatch_error(node: Node, *_: Any) -> Never: + msg = ( + f"{self.name!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{node!r}" + ) + raise TypeError(msg) + + def getter(_: Any, /) -> Raiser: + return _no_dispatch_error + + return getter + + def _not_implemented_error( + self, ctx: object, /, missing: Literal["compliant", "context"] + ) -> NotImplementedError: + if missing == "context": + msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}" + else: + msg = ( + f"`{self.name}` has not been implemented at the compliant-level.\n" + f"Hint: Try adding `CompliantExpr.{self.name}()` or `CompliantNamespace.{self.name}()`" + ) + return NotImplementedError(msg) + + +def _via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: + def _(ctx: Any, /) -> Any: + return getter(ctx.__narwhals_namespace__()) + + return _ + + +def _pascal_to_snake_case(s: str) -> str: + """Convert a PascalCase string to snake_case. + + Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 + """ + # Handle the sequence of uppercase letters followed by a lowercase letter + snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) + # Insert an underscore between a lowercase letter and an uppercase letter + return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() + + +_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") +_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") + + +def _re_repl_snake(match: re.Match[str], /) -> str: + return f"{match.group(1)}_{match.group(2)}" + + +def _method_name(tp: type[ExprIRT | FunctionT]) -> str: + config = tp.__expr_ir_config__ + name = config.override_name or _pascal_to_snake_case(tp.__name__) + return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name + + +def get_dispatch_name(expr: ExprIR, /) -> str: + """Return the synthesized method name for `expr`.""" + return ( + repr(expr.function) if is_function_expr(expr) else expr.__expr_ir_dispatch__.name + ) diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index f0c9a08548..8f66f2d9fa 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -1,19 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, cast +from typing import TYPE_CHECKING, Generic +from narwhals._plan._dispatch import Dispatcher from narwhals._plan._guards import is_function_expr, is_literal from narwhals._plan._immutable import Immutable -from narwhals._plan.common import dispatch_getter, replace +from narwhals._plan.common import replace from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ExprIRT from narwhals.utils import Version if TYPE_CHECKING: - from collections.abc import Callable, Iterator + from collections.abc import Iterator from typing import Any, ClassVar - from 
typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co from narwhals._plan.expr import Expr, Selector @@ -22,29 +23,6 @@ from narwhals._plan.typing import ExprIRT2, MapIR, Seq from narwhals.dtypes import DType - Incomplete: TypeAlias = "Any" - - -def _dispatch_generate( - tp: type[ExprIRT], / -) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: - if not tp.__expr_ir_config__.allow_dispatch: - - def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: - msg = ( - f"{tp.__name__!r} should not appear at the compliant-level.\n\n" - f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" - ) - raise TypeError(msg) - - return _ - getter = dispatch_getter(tp) - - def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" @@ -53,9 +31,7 @@ class ExprIR(Immutable): """Nested node names, in iteration order.""" __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] - ] + __expr_ir_dispatch__: ClassVar[Dispatcher[Self]] def __init_subclass__( cls: type[Self], @@ -69,13 +45,13 @@ def __init_subclass__( cls._child = child if config: cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) + cls.__expr_ir_dispatch__ = Dispatcher.from_expr_ir(cls) def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / ) -> R_co: - """Evaluate expression in `frame`, using `ctx` for implementation(s).""" - return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] + """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.""" + return self.__expr_ir_dispatch__(self, ctx, frame, name) def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import expr diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index 332dbfc085..8b71a4dd8d 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -2,33 +2,21 @@ from typing import TYPE_CHECKING +from narwhals._plan._dispatch import Dispatcher from narwhals._plan._immutable import Immutable -from narwhals._plan.common import dispatch_getter, dispatch_method_name, replace +from narwhals._plan.common import replace from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from typing import Any, Callable, ClassVar - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.expressions import ExprIR, FunctionExpr - from narwhals._plan.typing import Accessor, FunctionT + from narwhals._plan.typing import Accessor __all__ = ["Function", "HorizontalFunction"] -Incomplete: TypeAlias = "Any" - - -def _dispatch_generate_function( - tp: type[FunctionT], / -) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: - getter = dispatch_getter(tp) - - def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - class Function(Immutable): """Shared by expr functions and namespace functions. 
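Aside (editor's sketch, not part of the patch): the `Dispatcher` added earlier in this patch pairs a PascalCase node class with a snake_case method looked up on the compliant object via `attrgetter`. The following standalone illustration reuses the same two regexes from the diff; `SortBy` and `FakeCompliantExpr` are hypothetical stand-ins, not the real narwhals types:

    import re
    from operator import attrgetter

    _UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])")
    _LOWER_UPPER = re.compile(r"([a-z])([A-Z])")

    def pascal_to_snake(name: str) -> str:
        # e.g. "SortBy" -> "sort_by", "EwmMean" -> "ewm_mean"
        repl = r"\1_\2"
        return _LOWER_UPPER.sub(repl, _UPPER_LOWER.sub(repl, name)).lower()

    class SortBy:  # stand-in for an ExprIR subclass
        ...

    class FakeCompliantExpr:  # stand-in for a backend implementation
        def sort_by(self, node: object, frame: object, name: str) -> str:
            return f"sort_by({name})"

    method = attrgetter(pascal_to_snake(SortBy.__name__))(FakeCompliantExpr())
    print(method(SortBy(), None, "a"))  # sort_by(a)

If the backend lacks the attribute, the lookup raises `AttributeError`, which is what `Dispatcher.bind` converts into the "has not been implemented at the compliant-level" hint.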
@@ -40,9 +28,7 @@ class Function(Immutable): FunctionOptions.default ) __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] - ] + __expr_ir_dispatch__: ClassVar[Dispatcher[FunctionExpr[Self]]] @property def function_options(self) -> FunctionOptions: @@ -72,10 +58,10 @@ def __init_subclass__( cls._function_options = staticmethod(options) if config: cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) + cls.__expr_ir_dispatch__ = Dispatcher.from_function(cls) def __repr__(self) -> str: - return dispatch_method_name(type(self)) + return self.__expr_ir_dispatch__.name class HorizontalFunction( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index b519261c53..df57f781c1 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -6,9 +6,10 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._plan import expressions as ir +from narwhals._plan._dispatch import get_dispatch_name from narwhals._plan._guards import is_agg_expr, is_function_expr from narwhals._plan.arrow import acero, functions as fn, options -from narwhals._plan.common import dispatch_method_name, temp +from narwhals._plan.common import temp from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy from narwhals._plan.expressions import aggregation as agg from narwhals._utils import Implementation @@ -132,11 +133,7 @@ def group_by_error( if reason == "too complex": msg = "Non-trivial complex aggregation found, which" else: - if is_function_expr(expr): - func_name = repr(expr.function) - else: - func_name = dispatch_method_name(type(expr)) - msg = f"`{func_name}()`" + msg = f"`{get_dispatch_name(expr)}()`" msg = f"{msg} is not supported in a `group_by` context for {backend!r}:\n{column_name}={expr!r}" return InvalidOperationError(msg) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index c29e8e8071..8ac2084034 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -1,11 +1,9 @@ from __future__ import annotations import datetime as dt -import re import sys from collections.abc import Iterable from decimal import Decimal -from operator import attrgetter from secrets import token_hex from typing import TYPE_CHECKING, cast, overload @@ -18,20 +16,13 @@ if TYPE_CHECKING: import reprlib from collections.abc import Iterator - from typing import Any, Callable, ClassVar, TypeVar + from typing import Any, ClassVar, TypeVar from typing_extensions import TypeIs from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.series import Series - from narwhals._plan.typing import ( - DTypeT, - ExprIRT, - FunctionT, - NonNestedDTypeT, - OneOrIterable, - Seq, - ) + from narwhals._plan.typing import DTypeT, NonNestedDTypeT, OneOrIterable, Seq from narwhals._utils import _StoresColumns from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -51,38 +42,6 @@ def replace(obj: T, /, **changes: Any) -> T: return func(obj, **changes) # type: ignore[no-any-return] -def pascal_to_snake_case(s: str) -> str: - """Convert a PascalCase, camelCase string to snake_case. 
- - Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 - """ - # Handle the sequence of uppercase letters followed by a lowercase letter - snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) - # Insert an underscore between a lowercase letter and an uppercase letter - return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - - -_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") -_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - - -def _re_repl_snake(match: re.Match[str], /) -> str: - return f"{match.group(1)}_{match.group(2)}" - - -def dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: - config = tp.__expr_ir_config__ - name = config.override_name or pascal_to_snake_case(tp.__name__) - return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name - - -def dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: - getter = attrgetter(dispatch_method_name(tp)) - if tp.__expr_ir_config__.origin == "expr": - return getter - return lambda ctx: getter(ctx.__narwhals_namespace__()) - - def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 5defde7d61..e8d0dffbed 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -54,6 +54,9 @@ def name(self) -> str: ... def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ... def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... + def ewm_mean( + self, node: FunctionExpr[F.EwmMean], frame: FrameT_contra, name: str + ) -> Self: ... def fill_null( self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str ) -> Self: ... diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index abb873aa5e..25c07d7de7 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -11,6 +11,7 @@ from narwhals._plan import expressions as ir from narwhals._plan.expressions import FunctionExpr, aggregation as agg from narwhals._plan.expressions.boolean import IsFirstDistinct, IsLastDistinct + from narwhals._plan.expressions.functions import EwmMean from narwhals._utils import Version from narwhals.typing import IntoDType, PythonLiteral @@ -58,6 +59,11 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" ... 
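Aside (editor's sketch, not part of the patch): both the removed `dispatch_getter` above and the `_via_namespace` helper added earlier rely on `operator.attrgetter` accepting dotted names, so a namespaced op such as `str.contains` resolves through the object returned by `__narwhals_namespace__()`. A minimal, self-contained illustration with hypothetical `Fake*` classes:

    from operator import attrgetter

    class FakeStringNamespace:
        def contains(self, node: object, frame: object, name: str) -> str:
            return f"str.contains({name})"

    class FakeNamespace:
        str = FakeStringNamespace()

    class FakeCompliantExpr:
        def __narwhals_namespace__(self) -> FakeNamespace:
            return FakeNamespace()

    getter = attrgetter("str.contains")  # dotted path: .str, then .contains

    def bind(ctx: FakeCompliantExpr) -> object:
        return getter(ctx.__narwhals_namespace__())

    print(bind(FakeCompliantExpr())(None, None, "a"))  # str.contains(a)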
+ def ewm_mean( + self, node: FunctionExpr[EwmMean], frame: FrameT_contra, name: str + ) -> Self: + return self._cast_float(node.input[0], frame, name) + def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 3b9ae41d11..92a563f586 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any from narwhals._plan._expr_ir import ExprIR -from narwhals._plan.common import pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: @@ -21,7 +20,7 @@ def is_scalar(self) -> bool: return True def __repr__(self) -> str: - return f"{self.expr!r}.{pascal_to_snake_case(type(self).__name__)}()" + return f"{self.expr!r}.{self.__expr_ir_dispatch__.name}()" def iter_output_name(self) -> Iterator[ExprIR]: yield from self.expr.iter_output_name() diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 42d00bdbd7..4fa2f6cf6e 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -314,9 +314,9 @@ def __init__( super().__init__(**dict(input=input, function=function, options=options, **kwds)) def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + return self.function.__expr_ir_dispatch__(self, ctx, frame, name) class RollingExpr(FunctionExpr[RollingT_co]): ... @@ -328,9 +328,9 @@ class AnonymousExpr( """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + return self.__expr_ir_dispatch__(self, ctx, frame, name) class RangeExpr(FunctionExpr[RangeT_co]): diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index f8ee347cac..739654e4bf 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -2,7 +2,7 @@ import enum from itertools import repeat -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable @@ -11,14 +11,12 @@ import pyarrow.acero import pyarrow.compute as pc - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.arrow.typing import NullPlacement from narwhals._plan.typing import Accessor, OneOrIterable, Order, Seq from narwhals.typing import RankMethod -DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] - class FunctionFlags(enum.Flag): ALLOW_GROUP_AWARE = 1 << 0 @@ -285,8 +283,8 @@ def rolling_options( class _BaseIROptions(Immutable): - __slots__ = ("origin", "override_name") - origin: DispatchOrigin + __slots__ = ("is_namespaced", "override_name") + is_namespaced: bool override_name: str def __repr__(self) -> str: @@ -294,7 +292,7 @@ def __repr__(self) -> str: @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="") + return cls(is_namespaced=False, 
override_name="") @classmethod def renamed(cls, name: str, /) -> Self: @@ -306,9 +304,7 @@ def renamed(cls, name: str, /) -> Self: def namespaced(cls, override_name: str = "", /) -> Self: from narwhals._plan.common import replace - return replace( - cls.default(), origin="__narwhals_namespace__", override_name=override_name - ) + return replace(cls.default(), is_namespaced=True, override_name=override_name) class ExprIROptions(_BaseIROptions): @@ -317,11 +313,11 @@ class ExprIROptions(_BaseIROptions): @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="", allow_dispatch=True) + return cls(is_namespaced=False, override_name="", allow_dispatch=True) @staticmethod def no_dispatch() -> ExprIROptions: - return ExprIROptions(origin="expr", override_name="", allow_dispatch=False) + return ExprIROptions(is_namespaced=False, override_name="", allow_dispatch=False) class FunctionExprOptions(_BaseIROptions): @@ -331,7 +327,7 @@ class FunctionExprOptions(_BaseIROptions): @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="", accessor_name=None) + return cls(is_namespaced=False, override_name="", accessor_name=None) FEOptions = FunctionExprOptions diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 2d63cb99f2..e011e5c547 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -549,7 +549,8 @@ def test_protocol_expr() -> None: from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.series import ArrowSeries - expr = ArrowExpr() + # NOTE: Intentionally leaving `ewm_mean` without a `not_implemented()` for another test + expr = ArrowExpr() # type: ignore[abstract] scalar = ArrowScalar() df = ArrowDataFrame() ser = ArrowSeries() diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py new file mode 100644 index 0000000000..db65d468a5 --- /dev/null +++ b/tests/plan/dispatch_test.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +import pytest + +pytest.importorskip("pyarrow") +import narwhals as nw +from narwhals import _plan as nwp +from narwhals._plan import expressions as ir, selectors as ncs +from narwhals._plan._dispatch import get_dispatch_name +from tests.plan.utils import assert_equal_data, dataframe, named_ir + +if TYPE_CHECKING: + import sys + + import pyarrow as pa + from typing_extensions import TypeAlias + + from narwhals._plan.dataframe import DataFrame + + if sys.version_info >= (3, 11): + _Flags: TypeAlias = "int | re.RegexFlag" + else: + _Flags: TypeAlias = int + + +@pytest.fixture +def data() -> dict[str, Any]: + return { + "a": [12.1, None, 4.0], + "b": [42, 10, None], + "c": [4, 5, 6], + "d": ["play", "swim", "walk"], + } + + +@pytest.fixture +def df(data: dict[str, Any]) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: + return dataframe(data) + + +def re_compile( + pattern: str, flags: _Flags = re.DOTALL | re.IGNORECASE +) -> re.Pattern[str]: + return re.compile(pattern, flags) + + +def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: + implemented_full = nwp.col("a").is_null() + forgot_to_expand = (named_ir("howdy", nwp.nth(3, 4).first()),) + aliased_after_expand: tuple[ir.NamedIR[Any]] = ( + ir.NamedIR.from_ir(ir.col("a").alias("b")), + ) + + assert_equal_data(df.select(implemented_full), {"a": [False, True, False]}) + + missing_backend = r"ewm_mean.+is not yet implemented for" + with pytest.raises(NotImplementedError, match=missing_backend): + 
+        df.select(nwp.col("c").ewm_mean())
+
+    missing_protocol = re_compile(
+        r"str\.contains.+has not been implemented.+compliant.+"
+        r"Hint.+try adding.+CompliantExpr\.str\.contains\(\)"
+    )
+    with pytest.raises(NotImplementedError, match=missing_protocol):
+        df.select(nwp.col("d").str.contains("a"))
+
+    with pytest.raises(
+        TypeError,
+        match=re_compile(r"IndexColumns.+not.+appear.+compliant.+expand.+expr.+first"),
+    ):
+        df._compliant.select(forgot_to_expand)
+
+    bad = re.escape("col('a').alias('b')")
+    with pytest.raises(TypeError, match=re_compile(rf"Alias.+not.+appear.+got.+{bad}")):
+        df._compliant.select(aliased_after_expand)
+
+    # Not a narwhals method, to make sure this doesn't allow arbitrary calls
+    with pytest.raises(AttributeError):
+        nwp.col("a").max().to_physical()  # type: ignore[attr-defined]
+
+
+@pytest.mark.parametrize(
+    ("expr", "expected"),
+    [
+        (nwp.col("a"), "col"),
+        (nwp.col("a").min().over("b"), "over"),
+        (nwp.col("a").first().over(order_by="b"), "over_ordered"),
+        (nwp.all_horizontal("a", "b", nwp.nth(4, 5, 6)), "all_horizontal"),
+        (nwp.int_range(10), "int_range"),
+        (nwp.col("a") + nwp.col("b") + 10, "binary_expr"),
+        (nwp.when(nwp.col("c")).then(5).when(nwp.col("d")).then(20), "ternary_expr"),
+        (nwp.col("a").cast(nw.String).str.starts_with("something"), ("str.starts_with")),
+        (nwp.mean("a"), "mean"),
+        (nwp.nth(1).first(), "first"),
+        (nwp.col("a").sum(), "sum"),
+        (nwp.col("a").drop_nulls().arg_min(), "arg_min"),
+        pytest.param(nwp.col("a").alias("b"), "Alias", id="no_dispatch-Alias"),
+        pytest.param(ncs.string(), "RootSelector", id="no_dispatch-RootSelector"),
+    ],
+)
+def test_dispatch_name(expr: nwp.Expr, expected: str) -> None:
+    assert get_dispatch_name(expr._ir) == expected

From 46abfe66cc808a06303ba3eff767e83d4ef114cf Mon Sep 17 00:00:00 2001
From: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
Date: Mon, 20 Oct 2025 18:23:24 +0000
Subject: [PATCH 368/368] chore: re-sync imports following (#3086)

---
 narwhals/_plan/arrow/typing.py | 2 +-
 narwhals/_plan/typing.py       | 9 ++-------
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py
index a515b99091..63333a49d4 100644
--- a/narwhals/_plan/arrow/typing.py
+++ b/narwhals/_plan/arrow/typing.py
@@ -23,7 +23,7 @@
     )
     from typing_extensions import TypeAlias
 
-    from narwhals.typing import NativeDataFrame, NativeSeries
+    from narwhals._native import NativeDataFrame, NativeSeries
 
 StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]"
 IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type"
diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py
index 5214c2f92c..1449d245ac 100644
--- a/narwhals/_plan/typing.py
+++ b/narwhals/_plan/typing.py
@@ -10,6 +10,7 @@
     from typing_extensions import TypeAlias
 
     from narwhals import dtypes
+    from narwhals._native import NativeDataFrame, NativeFrame, NativeSeries
     from narwhals._plan._expr_ir import ExprIR, NamedIR, SelectorIR
     from narwhals._plan._function import Function
     from narwhals._plan.dataframe import DataFrame
@@ -19,13 +20,7 @@
     from narwhals._plan.expressions.namespace import IRNamespace
     from narwhals._plan.expressions.ranges import RangeFunction
     from narwhals._plan.series import Series
-    from narwhals.typing import (
-        NativeDataFrame,
-        NativeFrame,
-        NativeSeries,
-        NonNestedDType,
-        NonNestedLiteral,
-    )
+    from narwhals.typing import NonNestedDType, NonNestedLiteral
 
 __all__ = [
"ColumnNameOrSelector",