Skip to content

Commit d085b3a

Browse files
committed
feat: Get 3x over() rules passing
Aaaaand added comments on rule origin
1 parent b13ebc6 commit d085b3a

File tree

3 files changed

+64
-16
lines changed

3 files changed

+64
-16
lines changed

narwhals/_plan/expr.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
RollingT,
2323
SelectorOperatorT,
2424
)
25+
from narwhals.exceptions import InvalidOperationError
2526

2627
if t.TYPE_CHECKING:
2728
from typing_extensions import Self
@@ -403,6 +404,34 @@ def iter_right(self) -> t.Iterator[ExprIR]:
403404
yield from e.iter_right()
404405
yield from self.expr.iter_right()
405406

407+
def __init__(
408+
self,
409+
*,
410+
expr: ExprIR,
411+
partition_by: Seq[ExprIR],
412+
order_by: tuple[Seq[ExprIR], SortOptions] | None,
413+
options: Window,
414+
) -> None:
415+
if isinstance(expr, WindowExpr):
416+
msg = "Cannot nest `over` statements."
417+
raise InvalidOperationError(msg)
418+
419+
if isinstance(expr, FunctionExpr):
420+
if expr.options.is_elementwise():
421+
msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}"
422+
raise InvalidOperationError(msg)
423+
if expr.options.is_row_separable():
424+
msg = f"Cannot use `over` on expressions which change length.\n{expr!r}"
425+
raise InvalidOperationError(msg)
426+
427+
kwds = {
428+
"expr": expr,
429+
"partition_by": partition_by,
430+
"order_by": order_by,
431+
"options": options,
432+
}
433+
super().__init__(**kwds)
434+
406435

407436
class Len(ExprIR):
408437
@property

narwhals/_plan/options.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class FunctionFlags(enum.Flag):
2727
"""Automatically explode on unit length if it ran as final aggregation."""
2828

2929
ROW_SEPARABLE = 1 << 8
30-
"""Not sure lol.
30+
"""`drop_nulls` is the only one we've got that is *just* this.
3131
3232
https://github.com/pola-rs/polars/pull/22573
3333
"""
@@ -36,14 +36,17 @@ class FunctionFlags(enum.Flag):
3636
"""mutually exclusive with `RETURNS_SCALAR`"""
3737

3838
def is_elementwise(self) -> bool:
39-
return self in (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING)
39+
return (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING) in self
4040

4141
def returns_scalar(self) -> bool:
4242
return FunctionFlags.RETURNS_SCALAR in self
4343

4444
def is_length_preserving(self) -> bool:
4545
return FunctionFlags.LENGTH_PRESERVING in self
4646

47+
def is_row_separable(self) -> bool:
48+
return FunctionFlags.ROW_SEPARABLE in self
49+
4750
@staticmethod
4851
def default() -> FunctionFlags:
4952
return FunctionFlags.ALLOW_GROUP_AWARE
@@ -75,6 +78,9 @@ def returns_scalar(self) -> bool:
7578
def is_length_preserving(self) -> bool:
7679
return self.flags.is_length_preserving()
7780

81+
def is_row_separable(self) -> bool:
82+
return self.flags.is_row_separable()
83+
7884
def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions:
7985
if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags:
8086
msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive."

tests/plan/expr_parsing_test.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import re
34
from typing import TYPE_CHECKING, Callable, Iterable
45

56
import pytest
@@ -81,17 +82,10 @@ def test_function_expr_horizontal(
8182
assert sequence_node != unrelated_node
8283

8384

84-
# TODO @dangotbanned: Get partity with the existing tests
85+
# TODO @dangotbanned: Get parity with the existing tests
8586
# https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L48-L105
8687

8788

88-
def test_misleading_order_by() -> None:
89-
with pytest.raises(InvalidOperationError):
90-
nw.col("a").mean().over(order_by="b")
91-
with pytest.raises(InvalidOperationError):
92-
nw.col("a").rank().over(order_by="b")
93-
94-
9589
# `test_double_over` is already covered in the later `test_nested_over`
9690

9791

@@ -107,6 +101,9 @@ def test_invalid_repeat_agg() -> None:
107101
nwd.col("a").all().quantile(0.5, "linear")
108102

109103

104+
# TODO @dangotbanned: Weirdly, `polars` suggestion **does** resolve it
105+
# InvalidOperationError: Series idx, length 1 doesn't match the DataFrame height of 9
106+
# If you want expression: col("idx").mean().drop_nulls() to be broadcasted, ensure it is a scalar (for instance by adding '.first()')
110107
def test_filter_aggregation() -> None:
111108
with pytest.raises(InvalidOperationError):
112109
nwd.col("a").mean().drop_nulls()
@@ -118,32 +115,48 @@ def test_head_aggregation() -> None:
118115
nwd.col("a").mean().head() # type: ignore[attr-defined]
119116

120117

118+
# TODO @dangotbanned: (Same as `test_filter_aggregation`)
121119
def test_rank_aggregation() -> None:
122120
with pytest.raises(InvalidOperationError):
123121
nwd.col("a").mean().rank()
124122

125123

124+
# TODO @dangotbanned: No error in `polars`, but results in all `null`s
126125
def test_diff_aggregation() -> None:
127126
with pytest.raises(InvalidOperationError):
128127
nwd.col("a").mean().diff()
129128

130129

131-
def test_invalid_over() -> None:
130+
# TODO @dangotbanned: Non-`polars`` rule
131+
def test_misleading_order_by() -> None:
132+
with pytest.raises(InvalidOperationError):
133+
nwd.col("a").mean().over(order_by="b")
132134
with pytest.raises(InvalidOperationError):
135+
nwd.col("a").rank().over(order_by="b")
136+
137+
138+
# NOTE: Non-`polars`` rule
139+
def test_invalid_over() -> None:
140+
pattern = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE)
141+
with pytest.raises(InvalidOperationError, match=pattern):
133142
nwd.col("a").fill_null(3).over("b")
134143

135144

136145
def test_nested_over() -> None:
137-
with pytest.raises(InvalidOperationError):
146+
pattern = re.compile(r"cannot nest.+over", re.IGNORECASE)
147+
with pytest.raises(InvalidOperationError, match=pattern):
138148
nwd.col("a").mean().over("b").over("c")
139-
with pytest.raises(InvalidOperationError):
149+
with pytest.raises(InvalidOperationError, match=pattern):
140150
nwd.col("a").mean().over("b").over("c", order_by="i")
141151

142152

153+
# NOTE: This *can* error in polars, but only if the length **actualy changes**
154+
# The rule then breaks down to needing the same length arrays in all parts of the over
143155
def test_filtration_over() -> None:
144-
with pytest.raises(InvalidOperationError):
156+
pattern = re.compile(r"cannot use.+over.+change length", re.IGNORECASE)
157+
with pytest.raises(InvalidOperationError, match=pattern):
145158
nwd.col("a").drop_nulls().over("b")
146-
with pytest.raises(InvalidOperationError):
159+
with pytest.raises(InvalidOperationError, match=pattern):
147160
nwd.col("a").drop_nulls().over("b", order_by="i")
148-
with pytest.raises(InvalidOperationError):
161+
with pytest.raises(InvalidOperationError, match=pattern):
149162
nwd.col("a").diff().drop_nulls().over("b", order_by="i")

0 commit comments

Comments
 (0)