Skip to content

Commit 48cd720

Browse files
authored
chore: add has_open_windows to ExprMetadata instead of is_order_dependent (#2078)
1 parent 45038b0 commit 48cd720

File tree

11 files changed

+164
-103
lines changed

11 files changed

+164
-103
lines changed

narwhals/_arrow/expr.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from narwhals._arrow.series import ArrowSeries
1515
from narwhals._expression_parsing import ExprKind
1616
from narwhals._expression_parsing import evaluate_output_names_and_aliases
17+
from narwhals._expression_parsing import is_scalar_like
1718
from narwhals._expression_parsing import reuse_series_implementation
1819
from narwhals.dependencies import get_numpy
1920
from narwhals.dependencies import is_numpy_array
@@ -414,10 +415,8 @@ def clip(self: Self, lower_bound: Any | None, upper_bound: Any | None) -> Self:
414415
)
415416

416417
def over(self: Self, keys: list[str], kind: ExprKind) -> Self:
417-
if kind is ExprKind.TRANSFORM:
418-
msg = (
419-
"Elementwise operations in `over` context are not supported for PyArrow."
420-
)
418+
if not is_scalar_like(kind):
419+
msg = "Only aggregation or literal operations are supported in `over` context for PyArrow."
421420
raise NotImplementedError(msg)
422421

423422
def func(df: ArrowDataFrame) -> list[ArrowSeries]:

narwhals/_expression_parsing.py

Lines changed: 66 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TYPE_CHECKING
1010
from typing import Any
1111
from typing import Callable
12+
from typing import Literal
1213
from typing import Sequence
1314
from typing import TypeVar
1415
from typing import overload
@@ -331,8 +332,8 @@ class ExprKind(Enum):
331332
Commutative composition rules are:
332333
- LITERAL vs LITERAL -> LITERAL
333334
- CHANGES_LENGTH vs (LITERAL | AGGREGATION) -> CHANGES_LENGTH
334-
- CHANGES_LENGTH vs (CHANGES_LENGTH | TRANSFORM) -> raise
335-
- TRANSFORM vs (LITERAL | AGGREGATION) -> TRANSFORM
335+
- CHANGES_LENGTH vs (CHANGES_LENGTH | TRANSFORM | WINDOW) -> raise
336+
- (TRANSFORM | WINDOW) vs (LITERAL | AGGREGATION) -> TRANSFORM
336337
- AGGREGATION vs (LITERAL | AGGREGATION) -> AGGREGATION
337338
"""
338339

@@ -343,18 +344,49 @@ class ExprKind(Enum):
343344
"""e.g. `nw.col('a').mean()`"""
344345

345346
TRANSFORM = auto()
346-
"""length-preserving, e.g. `nw.col('a').round()`"""
347+
"""preserves length, e.g. `nw.col('a').round()`"""
348+
349+
WINDOW = auto()
350+
"""transform in which last node is order-dependent
351+
352+
examples:
353+
- `nw.col('a').cum_sum()`
354+
- `(nw.col('a')+1).cum_sum()`
355+
356+
non-examples:
357+
- `nw.col('a').cum_sum()+1`
358+
- `nw.col('a').cum_sum().mean()`
359+
"""
347360

348361
CHANGES_LENGTH = auto()
349362
"""e.g. `nw.col('a').drop_nulls()`"""
350363

364+
def preserves_length(self) -> bool:
365+
return self in {ExprKind.TRANSFORM, ExprKind.WINDOW}
366+
367+
def is_window(self) -> bool:
368+
return self is ExprKind.WINDOW
369+
370+
def is_changes_length(self) -> bool:
371+
return self is ExprKind.CHANGES_LENGTH
372+
373+
def is_scalar_like(self) -> bool:
374+
return is_scalar_like(self)
375+
376+
377+
def is_scalar_like(
378+
kind: ExprKind,
379+
) -> TypeIs[Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]]:
380+
# Like ExprKind.is_scalar_like, but uses TypeIs for better type checking.
381+
return kind in {ExprKind.AGGREGATION, ExprKind.LITERAL}
382+
351383

352384
class ExprMetadata:
353-
__slots__ = ("_kind", "_order_dependent")
385+
__slots__ = ("_kind", "_n_open_windows")
354386

355-
def __init__(self, kind: ExprKind, /, *, order_dependent: bool) -> None:
387+
def __init__(self, kind: ExprKind, /, *, n_open_windows: int) -> None:
356388
self._kind: ExprKind = kind
357-
self._order_dependent: bool = order_dependent
389+
self._n_open_windows = n_open_windows
358390

359391
def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no cover
360392
msg = f"Cannot subclass {cls.__name__!r}"
@@ -364,105 +396,98 @@ def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no c
364396
def kind(self) -> ExprKind:
365397
return self._kind
366398

367-
def is_order_dependent(self) -> bool:
368-
return self._order_dependent
369-
370-
def is_transform(self) -> bool:
371-
return self.kind is ExprKind.TRANSFORM
372-
373-
def is_aggregation_or_literal(self) -> bool:
374-
return self.kind in {ExprKind.AGGREGATION, ExprKind.LITERAL}
375-
376-
def is_changes_length(self) -> bool:
377-
return self.kind is ExprKind.CHANGES_LENGTH
399+
@property
400+
def n_open_windows(self) -> int:
401+
return self._n_open_windows
378402

379403
def with_kind(self, kind: ExprKind, /) -> ExprMetadata:
380404
"""Change metadata kind, leaving all other attributes the same."""
381-
return ExprMetadata(kind, order_dependent=self.is_order_dependent())
405+
return ExprMetadata(kind, n_open_windows=self._n_open_windows)
382406

383-
def with_order_dependence(self) -> ExprMetadata:
384-
"""Set `order_dependent` to True, leaving all other attributes the same."""
385-
return ExprMetadata(self.kind, order_dependent=True)
407+
def with_extra_open_window(self) -> ExprMetadata:
408+
"""Increment `n_open_windows` leaving other attributes the same."""
409+
return ExprMetadata(self.kind, n_open_windows=self._n_open_windows + 1)
386410

387-
def with_kind_and_order_dependence(self, kind: ExprKind, /) -> ExprMetadata:
388-
"""Change kind and set `order_dependent` to True."""
389-
return ExprMetadata(kind, order_dependent=True)
411+
def with_kind_and_extra_open_window(self, kind: ExprKind, /) -> ExprMetadata:
412+
"""Change metadata kind and increment `n_open_windows`."""
413+
return ExprMetadata(kind, n_open_windows=self._n_open_windows + 1)
390414

391415
@staticmethod
392416
def selector() -> ExprMetadata:
393-
return ExprMetadata(ExprKind.TRANSFORM, order_dependent=False)
417+
return ExprMetadata(ExprKind.TRANSFORM, n_open_windows=0)
394418

395419

396420
def combine_metadata(*args: IntoExpr | object | None, str_as_lit: bool) -> ExprMetadata:
397421
# Combine metadata from `args`.
398422

399423
n_changes_length = 0
400-
has_transforms = False
424+
has_transforms_or_windows = False
401425
has_aggregations = False
402426
has_literals = False
403-
result_is_order_dependent = False
427+
result_n_open_windows = 0
404428

405429
for arg in args:
406430
if isinstance(arg, str) and not str_as_lit:
407-
has_transforms = True
431+
has_transforms_or_windows = True
408432
elif is_expr(arg):
409-
if arg._metadata.is_order_dependent():
410-
result_is_order_dependent = True
433+
if arg._metadata.n_open_windows:
434+
result_n_open_windows += 1
411435
kind = arg._metadata.kind
412436
if kind is ExprKind.AGGREGATION:
413437
has_aggregations = True
414438
elif kind is ExprKind.LITERAL:
415439
has_literals = True
416440
elif kind is ExprKind.CHANGES_LENGTH:
417441
n_changes_length += 1
418-
elif kind is ExprKind.TRANSFORM:
419-
has_transforms = True
442+
elif kind.preserves_length():
443+
has_transforms_or_windows = True
420444
else: # pragma: no cover
421445
msg = "unreachable code"
422446
raise AssertionError(msg)
423447
if (
424448
has_literals
425449
and not has_aggregations
426-
and not has_transforms
450+
and not has_transforms_or_windows
427451
and not n_changes_length
428452
):
429453
result_kind = ExprKind.LITERAL
430454
elif n_changes_length > 1:
431455
msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
432456
raise LengthChangingExprError(msg)
433-
elif n_changes_length and has_transforms:
457+
elif n_changes_length and has_transforms_or_windows:
434458
msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
435459
raise ShapeError(msg)
436460
elif n_changes_length:
437461
result_kind = ExprKind.CHANGES_LENGTH
438-
elif has_transforms:
462+
elif has_transforms_or_windows:
439463
result_kind = ExprKind.TRANSFORM
440464
else:
441465
result_kind = ExprKind.AGGREGATION
442466

443-
return ExprMetadata(result_kind, order_dependent=result_is_order_dependent)
467+
return ExprMetadata(result_kind, n_open_windows=result_n_open_windows)
444468

445469

446-
def check_expressions_transform(*args: IntoExpr, function_name: str) -> None:
470+
def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None:
447471
# Raise if any argument in `args` isn't length-preserving.
448472
# For Series input, we don't raise (yet), we let such checks happen later,
449473
# as this function works lazily and so can't evaluate lengths.
450474
from narwhals.series import Series
451475

452476
if not all(
453-
(is_expr(x) and x._metadata.is_transform()) or isinstance(x, (str, Series))
477+
(is_expr(x) and x._metadata.kind.preserves_length())
478+
or isinstance(x, (str, Series))
454479
for x in args
455480
):
456481
msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'."
457482
raise ShapeError(msg)
458483

459484

460-
def all_exprs_are_aggs_or_literals(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
485+
def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
461486
# Raise if any argument in `args` isn't an aggregation or literal.
462487
# For Series input, we don't raise (yet), we let such checks happen later,
463488
# as this function works lazily and so can't evaluate lengths.
464489
exprs = chain(args, kwargs.values())
465-
return all(is_expr(x) and x._metadata.is_aggregation_or_literal() for x in exprs)
490+
return all(is_expr(x) and x._metadata.kind.is_scalar_like() for x in exprs)
466491

467492

468493
def infer_kind(obj: IntoExpr | _1DArray | object, *, str_as_lit: bool) -> ExprKind:
@@ -489,7 +514,7 @@ def apply_n_ary_operation(
489514
)
490515
kinds = [infer_kind(comparand, str_as_lit=str_as_lit) for comparand in comparands]
491516

492-
broadcast = any(kind is ExprKind.TRANSFORM for kind in kinds)
517+
broadcast = any(kind.preserves_length() for kind in kinds)
493518
compliant_exprs = (
494519
compliant_expr.broadcast(kind)
495520
if broadcast

narwhals/_pandas_like/expr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from narwhals._expression_parsing import ExprKind
1111
from narwhals._expression_parsing import evaluate_output_names_and_aliases
12+
from narwhals._expression_parsing import is_scalar_like
1213
from narwhals._expression_parsing import is_simple_aggregation
1314
from narwhals._expression_parsing import reuse_series_implementation
1415
from narwhals._pandas_like.expr_cat import PandasLikeExprCatNamespace
@@ -475,9 +476,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
475476
)
476477
)
477478
return [result_frame[name] for name in aliases]
478-
elif kind is ExprKind.TRANSFORM:
479+
elif not is_scalar_like(kind):
479480
msg = (
480-
"Elementwise operations are only supported in `over` context "
481+
"Length-preserving operations are only supported in `over` context "
481482
"for pandas if they are elementary "
482483
"(e.g. `nw.col('a').cum_sum().over('b'))`)."
483484
)

narwhals/dataframe.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
from warnings import warn
1717

1818
from narwhals._expression_parsing import ExprKind
19-
from narwhals._expression_parsing import all_exprs_are_aggs_or_literals
20-
from narwhals._expression_parsing import check_expressions_transform
19+
from narwhals._expression_parsing import all_exprs_are_scalar_like
20+
from narwhals._expression_parsing import check_expressions_preserve_length
2121
from narwhals._expression_parsing import infer_kind
22+
from narwhals._expression_parsing import is_scalar_like
2223
from narwhals.dependencies import get_polars
2324
from narwhals.dependencies import is_numpy_array
2425
from narwhals.dependencies import is_numpy_array_1d
@@ -137,9 +138,7 @@ def with_columns(
137138
) -> Self:
138139
compliant_exprs, kinds = self._flatten_and_extract(*exprs, **named_exprs)
139140
compliant_exprs = [
140-
compliant_expr.broadcast(kind)
141-
if (kind is ExprKind.LITERAL or kind is ExprKind.AGGREGATION)
142-
else compliant_expr
141+
compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr
143142
for compliant_expr, kind in zip(compliant_exprs, kinds)
144143
]
145144
return self._from_compliant_dataframe(
@@ -166,14 +165,12 @@ def select(
166165
missing_columns, available_columns
167166
) from e
168167
compliant_exprs, kinds = self._flatten_and_extract(*flat_exprs, **named_exprs)
169-
if compliant_exprs and all_exprs_are_aggs_or_literals(*flat_exprs, **named_exprs):
168+
if compliant_exprs and all_exprs_are_scalar_like(*flat_exprs, **named_exprs):
170169
return self._from_compliant_dataframe(
171170
self._compliant_frame.aggregate(*compliant_exprs),
172171
)
173172
compliant_exprs = [
174-
compliant_expr.broadcast(kind)
175-
if (kind is ExprKind.LITERAL or kind is ExprKind.AGGREGATION)
176-
else compliant_expr
173+
compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr
177174
for compliant_expr, kind in zip(compliant_exprs, kinds)
178175
]
179176
return self._from_compliant_dataframe(
@@ -205,7 +202,7 @@ def filter(
205202
and all(isinstance(x, bool) for x in predicates[0])
206203
):
207204
flat_predicates = flatten(predicates)
208-
check_expressions_transform(*flat_predicates, function_name="filter")
205+
check_expressions_preserve_length(*flat_predicates, function_name="filter")
209206
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
210207
plx = self.__narwhals_namespace__()
211208
predicate = plx.all_horizontal(
@@ -2204,18 +2201,18 @@ def _extract_compliant(self: Self, arg: Any) -> Any:
22042201
plx = self.__narwhals_namespace__()
22052202
return plx.col(arg)
22062203
if isinstance(arg, Expr):
2207-
if arg._metadata.is_order_dependent():
2204+
if arg._metadata.n_open_windows > 0:
22082205
msg = (
22092206
"Order-dependent expressions are not supported for use in LazyFrame.\n\n"
22102207
"Hints:\n"
22112208
"- Instead of `lf.select(nw.col('a').sort())`, use `lf.select('a').sort()\n"
22122209
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
22132210
"- `Expr.cum_sum`, and other such expressions, are not currently supported.\n"
2214-
" In a future version of Narwhals, a `order_by` argument will be added and \n"
2215-
" they will be supported."
2211+
" In a future version of Narwhals, a `order_by` argument will be added to\n"
2212+
" `over` and they will be supported."
22162213
)
22172214
raise OrderDependentExprError(msg)
2218-
if arg._metadata.is_changes_length():
2215+
if arg._metadata.kind.is_changes_length():
22192216
msg = (
22202217
"Length-changing expressions are not supported for use in LazyFrame, unless\n"
22212218
"followed by an aggregation.\n\n"

0 commit comments

Comments
 (0)