99from typing import TYPE_CHECKING
1010from typing import Any
1111from typing import Callable
12+ from typing import Literal
1213from typing import Sequence
1314from typing import TypeVar
1415from typing import overload
@@ -331,8 +332,8 @@ class ExprKind(Enum):
331332 Commutative composition rules are:
332333 - LITERAL vs LITERAL -> LITERAL
333334 - CHANGES_LENGTH vs (LITERAL | AGGREGATION) -> CHANGES_LENGTH
334- - CHANGES_LENGTH vs (CHANGES_LENGTH | TRANSFORM) -> raise
335- - TRANSFORM vs (LITERAL | AGGREGATION) -> TRANSFORM
335+ - CHANGES_LENGTH vs (CHANGES_LENGTH | TRANSFORM | WINDOW ) -> raise
336+ - ( TRANSFORM | WINDOW) vs (LITERAL | AGGREGATION) -> TRANSFORM
336337 - AGGREGATION vs (LITERAL | AGGREGATION) -> AGGREGATION
337338 """
338339
@@ -343,18 +344,49 @@ class ExprKind(Enum):
343344 """e.g. `nw.col('a').mean()`"""
344345
345346 TRANSFORM = auto ()
346- """length-preserving, e.g. `nw.col('a').round()`"""
347+ """preserves length, e.g. `nw.col('a').round()`"""
348+
349+ WINDOW = auto ()
350+ """transform in which last node is order-dependent
351+
352+ examples:
353+ - `nw.col('a').cum_sum()`
354+ - `(nw.col('a')+1).cum_sum()`
355+
356+ non-examples:
357+ - `nw.col('a').cum_sum()+1`
358+ - `nw.col('a').cum_sum().mean()`
359+ """
347360
348361 CHANGES_LENGTH = auto ()
349362 """e.g. `nw.col('a').drop_nulls()`"""
350363
364+ def preserves_length (self ) -> bool :
365+ return self in {ExprKind .TRANSFORM , ExprKind .WINDOW }
366+
367+ def is_window (self ) -> bool :
368+ return self is ExprKind .WINDOW
369+
370+ def is_changes_length (self ) -> bool :
371+ return self is ExprKind .CHANGES_LENGTH
372+
373+ def is_scalar_like (self ) -> bool :
374+ return is_scalar_like (self )
375+
376+
377+ def is_scalar_like (
378+ kind : ExprKind ,
379+ ) -> TypeIs [Literal [ExprKind .AGGREGATION , ExprKind .LITERAL ]]:
380+ # Like ExprKind.is_scalar_like, but uses TypeIs for better type checking.
381+ return kind in {ExprKind .AGGREGATION , ExprKind .LITERAL }
382+
351383
352384class ExprMetadata :
353- __slots__ = ("_kind" , "_order_dependent " )
385+ __slots__ = ("_kind" , "_n_open_windows " )
354386
355- def __init__ (self , kind : ExprKind , / , * , order_dependent : bool ) -> None :
387+ def __init__ (self , kind : ExprKind , / , * , n_open_windows : int ) -> None :
356388 self ._kind : ExprKind = kind
357- self ._order_dependent : bool = order_dependent
389+ self ._n_open_windows = n_open_windows
358390
359391 def __init_subclass__ (cls , / , * args : Any , ** kwds : Any ) -> Never : # pragma: no cover
360392 msg = f"Cannot subclass { cls .__name__ !r} "
@@ -364,105 +396,98 @@ def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no c
364396 def kind (self ) -> ExprKind :
365397 return self ._kind
366398
367- def is_order_dependent (self ) -> bool :
368- return self ._order_dependent
369-
370- def is_transform (self ) -> bool :
371- return self .kind is ExprKind .TRANSFORM
372-
373- def is_aggregation_or_literal (self ) -> bool :
374- return self .kind in {ExprKind .AGGREGATION , ExprKind .LITERAL }
375-
376- def is_changes_length (self ) -> bool :
377- return self .kind is ExprKind .CHANGES_LENGTH
399+ @property
400+ def n_open_windows (self ) -> int :
401+ return self ._n_open_windows
378402
379403 def with_kind (self , kind : ExprKind , / ) -> ExprMetadata :
380404 """Change metadata kind, leaving all other attributes the same."""
381- return ExprMetadata (kind , order_dependent = self .is_order_dependent () )
405+ return ExprMetadata (kind , n_open_windows = self ._n_open_windows )
382406
383- def with_order_dependence (self ) -> ExprMetadata :
384- """Set `order_dependent` to True, leaving all other attributes the same."""
385- return ExprMetadata (self .kind , order_dependent = True )
407+ def with_extra_open_window (self ) -> ExprMetadata :
408+ """Increment `n_open_windows` leaving other attributes the same."""
409+ return ExprMetadata (self .kind , n_open_windows = self . _n_open_windows + 1 )
386410
387- def with_kind_and_order_dependence (self , kind : ExprKind , / ) -> ExprMetadata :
388- """Change kind and set `order_dependent` to True ."""
389- return ExprMetadata (kind , order_dependent = True )
411+ def with_kind_and_extra_open_window (self , kind : ExprKind , / ) -> ExprMetadata :
412+ """Change metadata kind and increment `n_open_windows` ."""
413+ return ExprMetadata (kind , n_open_windows = self . _n_open_windows + 1 )
390414
391415 @staticmethod
392416 def selector () -> ExprMetadata :
393- return ExprMetadata (ExprKind .TRANSFORM , order_dependent = False )
417+ return ExprMetadata (ExprKind .TRANSFORM , n_open_windows = 0 )
394418
395419
396420def combine_metadata (* args : IntoExpr | object | None , str_as_lit : bool ) -> ExprMetadata :
397421 # Combine metadata from `args`.
398422
399423 n_changes_length = 0
400- has_transforms = False
424+ has_transforms_or_windows = False
401425 has_aggregations = False
402426 has_literals = False
403- result_is_order_dependent = False
427+ result_n_open_windows = 0
404428
405429 for arg in args :
406430 if isinstance (arg , str ) and not str_as_lit :
407- has_transforms = True
431+ has_transforms_or_windows = True
408432 elif is_expr (arg ):
409- if arg ._metadata .is_order_dependent () :
410- result_is_order_dependent = True
433+ if arg ._metadata .n_open_windows :
434+ result_n_open_windows += 1
411435 kind = arg ._metadata .kind
412436 if kind is ExprKind .AGGREGATION :
413437 has_aggregations = True
414438 elif kind is ExprKind .LITERAL :
415439 has_literals = True
416440 elif kind is ExprKind .CHANGES_LENGTH :
417441 n_changes_length += 1
418- elif kind is ExprKind . TRANSFORM :
419- has_transforms = True
442+ elif kind . preserves_length () :
443+ has_transforms_or_windows = True
420444 else : # pragma: no cover
421445 msg = "unreachable code"
422446 raise AssertionError (msg )
423447 if (
424448 has_literals
425449 and not has_aggregations
426- and not has_transforms
450+ and not has_transforms_or_windows
427451 and not n_changes_length
428452 ):
429453 result_kind = ExprKind .LITERAL
430454 elif n_changes_length > 1 :
431455 msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
432456 raise LengthChangingExprError (msg )
433- elif n_changes_length and has_transforms :
457+ elif n_changes_length and has_transforms_or_windows :
434458 msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
435459 raise ShapeError (msg )
436460 elif n_changes_length :
437461 result_kind = ExprKind .CHANGES_LENGTH
438- elif has_transforms :
462+ elif has_transforms_or_windows :
439463 result_kind = ExprKind .TRANSFORM
440464 else :
441465 result_kind = ExprKind .AGGREGATION
442466
443- return ExprMetadata (result_kind , order_dependent = result_is_order_dependent )
467+ return ExprMetadata (result_kind , n_open_windows = result_n_open_windows )
444468
445469
446- def check_expressions_transform (* args : IntoExpr , function_name : str ) -> None :
470+ def check_expressions_preserve_length (* args : IntoExpr , function_name : str ) -> None :
447471 # Raise if any argument in `args` isn't length-preserving.
448472 # For Series input, we don't raise (yet), we let such checks happen later,
449473 # as this function works lazily and so can't evaluate lengths.
450474 from narwhals .series import Series
451475
452476 if not all (
453- (is_expr (x ) and x ._metadata .is_transform ()) or isinstance (x , (str , Series ))
477+ (is_expr (x ) and x ._metadata .kind .preserves_length ())
478+ or isinstance (x , (str , Series ))
454479 for x in args
455480 ):
456481 msg = f"Expressions which aggregate or change length cannot be passed to '{ function_name } '."
457482 raise ShapeError (msg )
458483
459484
460- def all_exprs_are_aggs_or_literals (* args : IntoExpr , ** kwargs : IntoExpr ) -> bool :
485+ def all_exprs_are_scalar_like (* args : IntoExpr , ** kwargs : IntoExpr ) -> bool :
461486 # Raise if any argument in `args` isn't an aggregation or literal.
462487 # For Series input, we don't raise (yet), we let such checks happen later,
463488 # as this function works lazily and so can't evaluate lengths.
464489 exprs = chain (args , kwargs .values ())
465- return all (is_expr (x ) and x ._metadata .is_aggregation_or_literal () for x in exprs )
490+ return all (is_expr (x ) and x ._metadata .kind . is_scalar_like () for x in exprs )
466491
467492
468493def infer_kind (obj : IntoExpr | _1DArray | object , * , str_as_lit : bool ) -> ExprKind :
@@ -489,7 +514,7 @@ def apply_n_ary_operation(
489514 )
490515 kinds = [infer_kind (comparand , str_as_lit = str_as_lit ) for comparand in comparands ]
491516
492- broadcast = any (kind is ExprKind . TRANSFORM for kind in kinds )
517+ broadcast = any (kind . preserves_length () for kind in kinds )
493518 compliant_exprs = (
494519 compliant_expr .broadcast (kind )
495520 if broadcast
0 commit comments