Skip to content

Commit 8b9b8d3

Browse files
committed
fix: Enforce no repeat aggs, fix flags
Forgot that `rust` has the `contains` check the other way
1 parent 6e7f9bc commit 8b9b8d3

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

narwhals/_plan/aggregation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Any
44

55
from narwhals._plan.common import ExprIR
6+
from narwhals.exceptions import InvalidOperationError
67

78
if TYPE_CHECKING:
89
from typing import Iterator
@@ -35,6 +36,12 @@ def iter_right(self) -> Iterator[ExprIR]:
3536
yield self
3637
yield from self.expr.iter_right()
3738

39+
def __init__(self, *, expr: ExprIR, **kwds: Any) -> None:
40+
if expr.is_scalar:
41+
msg = "Can't apply aggregations to scalar-like expressions."
42+
raise InvalidOperationError(msg)
43+
super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue]
44+
3845

3946
class Count(Agg): ...
4047

narwhals/_plan/options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def is_elementwise(self) -> bool:
3939
return self in (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING)
4040

4141
def returns_scalar(self) -> bool:
42-
return self in FunctionFlags.RETURNS_SCALAR
42+
return FunctionFlags.RETURNS_SCALAR in self
4343

4444
def is_length_preserving(self) -> bool:
45-
return self in FunctionFlags.LENGTH_PRESERVING
45+
return FunctionFlags.LENGTH_PRESERVING in self
4646

4747
@staticmethod
4848
def default() -> FunctionFlags:

tests/plan/expr_parsing_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from narwhals._plan.common import ExprIR, Function
1414
from narwhals._plan.dummy import DummyExpr
1515
from narwhals._plan.expr import FunctionExpr
16+
from narwhals.exceptions import InvalidOperationError
1617

1718
if TYPE_CHECKING:
1819
from narwhals._plan.common import IntoExpr, Seq
@@ -78,3 +79,14 @@ def test_function_expr_horizontal(
7879
assert isinstance(variadic_node.function, ir_node)
7980
assert variadic_node == sequence_node
8081
assert sequence_node != unrelated_node
82+
83+
84+
def test_invalid_repeat_agg() -> None:
85+
with pytest.raises(InvalidOperationError):
86+
nwd.col("a").mean().mean()
87+
with pytest.raises(InvalidOperationError):
88+
nwd.col("a").first().max()
89+
with pytest.raises(InvalidOperationError):
90+
nwd.col("a").any().std()
91+
with pytest.raises(InvalidOperationError):
92+
nwd.col("a").all().quantile(0.5, "linear")

0 commit comments

Comments
 (0)