|
3 | 3 | import builtins |
4 | 4 | import typing as t |
5 | 5 |
|
| 6 | +from narwhals._plan import aggregation as agg |
6 | 7 | from narwhals._plan import boolean |
7 | 8 | from narwhals._plan import functions as F # noqa: N812 |
8 | 9 | from narwhals._plan.dummy import DummySeries |
|
17 | 18 | from narwhals._plan.strings import ConcatHorizontal |
18 | 19 | from narwhals.dtypes import DType |
19 | 20 | from narwhals.dtypes import Unknown |
| 21 | +from narwhals.exceptions import OrderDependentExprError |
20 | 22 | from narwhals.utils import flatten |
21 | 23 |
|
22 | 24 | if t.TYPE_CHECKING: |
| 25 | + from typing_extensions import TypeIs |
| 26 | + |
23 | 27 | from narwhals._plan.dummy import DummyExpr |
| 28 | + from narwhals._plan.expr import SortBy |
| 29 | + from narwhals._plan.expr import WindowExpr |
24 | 30 | from narwhals.typing import NonNestedLiteral |
25 | 31 |
|
26 | 32 |
|
@@ -124,3 +130,43 @@ def concat_str( |
124 | 130 | .to_function_expr(*it) |
125 | 131 | .to_narwhals() |
126 | 132 | ) |
| 133 | + |
| 134 | + |
| 135 | +def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]: |
| 136 | + """In theory, we could add other nodes to this check.""" |
| 137 | + from narwhals._plan.expr import SortBy |
| 138 | + |
| 139 | + allowed = (SortBy,) |
| 140 | + return isinstance(obj, allowed) |
| 141 | + |
| 142 | + |
| 143 | +def _is_order_enforcing_next(obj: t.Any) -> TypeIs[WindowExpr]: |
| 144 | + """Not sure how this one would work.""" |
| 145 | + from narwhals._plan.expr import WindowExpr |
| 146 | + |
| 147 | + return isinstance(obj, WindowExpr) and obj.order_by is not None |
| 148 | + |
| 149 | + |
| 150 | +def _order_dependent_error(node: agg.OrderableAgg) -> OrderDependentExprError: |
| 151 | + previous = node.expr |
| 152 | + method = repr(node).removeprefix(f"{previous!r}.") |
| 153 | + msg = ( |
| 154 | + f"{method} is order-dependent and requires an ordering operation for lazy backends.\n" |
| 155 | + f"Hint:\nInstead of:\n" |
| 156 | + f" {node!r}\n\n" |
| 157 | + "If you want to aggregate to a single value, try:\n" |
| 158 | + f" {previous!r}.sort_by(...).{method}\n\n" |
| 159 | + "Otherwise, try:\n" |
| 160 | + f" {node!r}.over(order_by=...)" |
| 161 | + ) |
| 162 | + return OrderDependentExprError(msg) |
| 163 | + |
| 164 | + |
| 165 | +def ensure_orderable_rules(*exprs: DummyExpr) -> tuple[DummyExpr, ...]: |
| 166 | + for expr in exprs: |
| 167 | + node = expr._ir |
| 168 | + if isinstance(node, agg.OrderableAgg): |
| 169 | + previous = node.expr |
| 170 | + if not _is_order_enforcing_previous(previous): |
| 171 | + raise _order_dependent_error(node) |
| 172 | + return exprs |
0 commit comments