Skip to content

Commit ee5daf9

Browse files
committed
refactor: Separate *Expr, *Scalar
Gonna be soooooooo much easier to work on them side-by-side now 🥳
1 parent 5a25d07 commit ee5daf9

File tree

5 files changed

+284
-278
lines changed

5 files changed

+284
-278
lines changed

narwhals/_plan/arrow/expr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from narwhals._plan.arrow.series import ArrowSeries as Series
1111
from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co
1212
from narwhals._plan.compliant.column import ExprDispatch
13+
from narwhals._plan.compliant.expr import EagerExpr
14+
from narwhals._plan.compliant.scalar import EagerScalar
1315
from narwhals._plan.compliant.typing import namespace
1416
from narwhals._plan.expressions import NamedIR
15-
from narwhals._plan.protocols import EagerExpr, EagerScalar
1617
from narwhals._utils import (
1718
Implementation,
1819
Version,

narwhals/_plan/compliant/expr.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Protocol
4+
5+
from narwhals._plan.compliant.column import EagerBroadcast, SupportsBroadcast
6+
from narwhals._plan.compliant.typing import (
7+
FrameT_contra,
8+
LengthT,
9+
SeriesT,
10+
SeriesT_co,
11+
StoresVersion,
12+
)
13+
from narwhals._utils import Version
14+
15+
if TYPE_CHECKING:
16+
from typing_extensions import Self
17+
18+
from narwhals._plan import expressions as ir
19+
from narwhals._plan.compliant.scalar import CompliantScalar
20+
from narwhals._plan.expressions import (
21+
BinaryExpr,
22+
FunctionExpr,
23+
aggregation as agg,
24+
boolean,
25+
functions as F,
26+
)
27+
from narwhals._plan.expressions.boolean import IsBetween, IsFinite, IsNan, IsNull, Not
28+
29+
30+
class CompliantExpr(StoresVersion, Protocol[FrameT_contra, SeriesT_co]):
31+
"""Everything common to `Expr`/`Series` and `Scalar` literal values."""
32+
33+
_evaluated: Any
34+
"""Compliant or native value."""
35+
36+
@property
37+
def name(self) -> str: ...
38+
@classmethod
39+
def from_native(
40+
cls, native: Any, name: str = "", /, version: Version = Version.MAIN
41+
) -> Self: ...
42+
def _with_native(self, native: Any, name: str, /) -> Self:
43+
return self.from_native(native, name or self.name, self.version)
44+
45+
# series & scalar
46+
def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ...
47+
def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ...
48+
def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ...
49+
def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ...
50+
def fill_null(
51+
self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str
52+
) -> Self: ...
53+
def is_between(
54+
self, node: FunctionExpr[IsBetween], frame: FrameT_contra, name: str
55+
) -> Self: ...
56+
def is_finite(
57+
self, node: FunctionExpr[IsFinite], frame: FrameT_contra, name: str
58+
) -> Self: ...
59+
def is_nan(
60+
self, node: FunctionExpr[IsNan], frame: FrameT_contra, name: str
61+
) -> Self: ...
62+
def is_null(
63+
self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str
64+
) -> Self: ...
65+
def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ...
66+
def ternary_expr(
67+
self, node: ir.TernaryExpr, frame: FrameT_contra, name: str
68+
) -> Self: ...
69+
def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ...
70+
# NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr`
71+
# e.g. `nw.col("a").first().over(order_by="b")`
72+
def over_ordered(
73+
self, node: ir.OrderedWindowExpr, frame: FrameT_contra, name: str
74+
) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ...
75+
def map_batches(
76+
self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str
77+
) -> Self: ...
78+
def rolling_expr(
79+
self, node: ir.RollingExpr, frame: FrameT_contra, name: str
80+
) -> Self: ...
81+
# series only (section 3)
82+
def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: ...
83+
def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: ...
84+
def filter(self, node: ir.Filter, frame: FrameT_contra, name: str) -> Self: ...
85+
# series -> scalar
86+
def first(
87+
self, node: agg.First, frame: FrameT_contra, name: str
88+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
89+
def last(
90+
self, node: agg.Last, frame: FrameT_contra, name: str
91+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
92+
def arg_min(
93+
self, node: agg.ArgMin, frame: FrameT_contra, name: str
94+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
95+
def arg_max(
96+
self, node: agg.ArgMax, frame: FrameT_contra, name: str
97+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
98+
def sum(
99+
self, node: agg.Sum, frame: FrameT_contra, name: str
100+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
101+
def n_unique(
102+
self, node: agg.NUnique, frame: FrameT_contra, name: str
103+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
104+
def std(
105+
self, node: agg.Std, frame: FrameT_contra, name: str
106+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
107+
def var(
108+
self, node: agg.Var, frame: FrameT_contra, name: str
109+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
110+
def quantile(
111+
self, node: agg.Quantile, frame: FrameT_contra, name: str
112+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
113+
def count(
114+
self, node: agg.Count, frame: FrameT_contra, name: str
115+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
116+
def len(
117+
self, node: agg.Len, frame: FrameT_contra, name: str
118+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
119+
def max(
120+
self, node: agg.Max, frame: FrameT_contra, name: str
121+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
122+
def mean(
123+
self, node: agg.Mean, frame: FrameT_contra, name: str
124+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
125+
def median(
126+
self, node: agg.Median, frame: FrameT_contra, name: str
127+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
128+
def min(
129+
self, node: agg.Min, frame: FrameT_contra, name: str
130+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
131+
def all(
132+
self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str
133+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
134+
def any(
135+
self, node: FunctionExpr[boolean.Any], frame: FrameT_contra, name: str
136+
) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
137+
138+
139+
class EagerExpr(
140+
EagerBroadcast[SeriesT],
141+
CompliantExpr[FrameT_contra, SeriesT],
142+
Protocol[FrameT_contra, SeriesT],
143+
): ...
144+
145+
146+
class LazyExpr(
147+
SupportsBroadcast[SeriesT, LengthT],
148+
CompliantExpr[FrameT_contra, SeriesT],
149+
Protocol[FrameT_contra, SeriesT, LengthT],
150+
): ...

narwhals/_plan/compliant/scalar.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,131 @@
11
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Protocol
4+
5+
from narwhals._plan.compliant.expr import CompliantExpr, EagerExpr, LazyExpr
6+
from narwhals._plan.compliant.typing import FrameT_contra, LengthT, SeriesT, SeriesT_co
7+
8+
if TYPE_CHECKING:
9+
from typing_extensions import Self
10+
11+
from narwhals._plan import expressions as ir
12+
from narwhals._plan.expressions import aggregation as agg
13+
from narwhals._utils import Version
14+
from narwhals.typing import IntoDType, PythonLiteral
15+
16+
17+
class CompliantScalar(
18+
CompliantExpr[FrameT_contra, SeriesT_co], Protocol[FrameT_contra, SeriesT_co]
19+
):
20+
_name: str
21+
22+
@property
23+
def name(self) -> str:
24+
return self._name
25+
26+
@classmethod
27+
def from_python(
28+
cls,
29+
value: PythonLiteral,
30+
name: str = "literal",
31+
/,
32+
*,
33+
dtype: IntoDType | None,
34+
version: Version,
35+
) -> Self: ...
36+
def _with_evaluated(self, evaluated: Any, name: str) -> Self:
37+
"""Expr is based on a series having these via accessors, but a scalar needs to keep passing through."""
38+
cls = type(self)
39+
obj = cls.__new__(cls)
40+
obj._evaluated = evaluated
41+
obj._name = name or self.name
42+
obj._version = self.version
43+
return obj
44+
45+
def max(self, node: agg.Max, frame: FrameT_contra, name: str) -> Self:
46+
"""Returns self."""
47+
return self._with_evaluated(self._evaluated, name)
48+
49+
def min(self, node: agg.Min, frame: FrameT_contra, name: str) -> Self:
50+
"""Returns self."""
51+
return self._with_evaluated(self._evaluated, name)
52+
53+
def sum(self, node: agg.Sum, frame: FrameT_contra, name: str) -> Self:
54+
"""Returns self."""
55+
return self._with_evaluated(self._evaluated, name)
56+
57+
def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self:
58+
"""Returns self."""
59+
return self._with_evaluated(self._evaluated, name)
60+
61+
def last(self, node: agg.Last, frame: FrameT_contra, name: str) -> Self:
62+
"""Returns self."""
63+
return self._with_evaluated(self._evaluated, name)
64+
65+
def _cast_float(self, node: ir.ExprIR, frame: FrameT_contra, name: str) -> Self:
66+
"""`polars` interpolates a single scalar as a float."""
67+
dtype = self.version.dtypes.Float64()
68+
return self.cast(node.cast(dtype), frame, name)
69+
70+
def mean(self, node: agg.Mean, frame: FrameT_contra, name: str) -> Self:
71+
return self._cast_float(node.expr, frame, name)
72+
73+
def median(self, node: agg.Median, frame: FrameT_contra, name: str) -> Self:
74+
return self._cast_float(node.expr, frame, name)
75+
76+
def quantile(self, node: agg.Quantile, frame: FrameT_contra, name: str) -> Self:
77+
return self._cast_float(node.expr, frame, name)
78+
79+
def n_unique(self, node: agg.NUnique, frame: FrameT_contra, name: str) -> Self:
80+
"""Returns 1."""
81+
...
82+
83+
def std(self, node: agg.Std, frame: FrameT_contra, name: str) -> Self:
84+
"""Returns null."""
85+
...
86+
87+
def var(self, node: agg.Var, frame: FrameT_contra, name: str) -> Self:
88+
"""Returns null."""
89+
...
90+
91+
def arg_min(self, node: agg.ArgMin, frame: FrameT_contra, name: str) -> Self:
92+
"""Returns 0."""
93+
...
94+
95+
def arg_max(self, node: agg.ArgMax, frame: FrameT_contra, name: str) -> Self:
96+
"""Returns 0."""
97+
...
98+
99+
def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self:
100+
"""Returns 0 if null, else 1."""
101+
...
102+
103+
def len(self, node: agg.Len, frame: FrameT_contra, name: str) -> Self:
104+
"""Returns 1."""
105+
...
106+
107+
def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self:
108+
return self._with_evaluated(self._evaluated, name)
109+
110+
def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self:
111+
return self._with_evaluated(self._evaluated, name)
112+
113+
# NOTE: `Filter` behaves the same, (maybe) no need to override
114+
115+
116+
class EagerScalar(
117+
CompliantScalar[FrameT_contra, SeriesT],
118+
EagerExpr[FrameT_contra, SeriesT],
119+
Protocol[FrameT_contra, SeriesT],
120+
):
121+
def __len__(self) -> int:
122+
return 1
123+
124+
def to_python(self) -> PythonLiteral: ...
125+
126+
127+
class LazyScalar(
128+
CompliantScalar[FrameT_contra, SeriesT],
129+
LazyExpr[FrameT_contra, SeriesT, LengthT],
130+
Protocol[FrameT_contra, SeriesT, LengthT],
131+
): ...

narwhals/_plan/compliant/typing.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,11 @@
1313
CompliantDataFrame,
1414
EagerDataFrame,
1515
)
16+
from narwhals._plan.compliant.expr import CompliantExpr, EagerExpr, LazyExpr
1617
from narwhals._plan.compliant.group_by import GroupByResolver
1718
from narwhals._plan.compliant.namespace import CompliantNamespace
19+
from narwhals._plan.compliant.scalar import CompliantScalar, EagerScalar, LazyScalar
1820
from narwhals._plan.compliant.series import CompliantSeries
19-
from narwhals._plan.protocols import (
20-
CompliantExpr,
21-
CompliantScalar,
22-
EagerExpr,
23-
EagerScalar,
24-
LazyExpr,
25-
LazyScalar,
26-
)
2721
from narwhals._utils import Version
2822

2923
T = TypeVar("T")

0 commit comments

Comments
 (0)