11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Any , Callable , Literal , Protocol , cast
3+ # ruff: noqa: N806
4+ import operator as op
5+ from typing import TYPE_CHECKING , Any , Callable , Literal , Protocol
46
57from narwhals ._compliant .expr import LazyExpr
68from narwhals ._compliant .typing import (
@@ -276,7 +278,7 @@ def func(
276278 }
277279 return [
278280 self ._when (
279- self ._window_expression ( # type: ignore[operator]
281+ self ._window_expression (
280282 self ._function ("count" , expr ), ** window_kwargs
281283 )
282284 >= self ._lit (min_samples ),
@@ -501,9 +503,7 @@ def round(self, decimals: int) -> Self:
501503 def sqrt (self ) -> Self :
502504 def _sqrt (expr : NativeExprT ) -> NativeExprT :
503505 return self ._when (
504- expr < self ._lit (0 ), # type: ignore[operator]
505- self ._lit (float ("nan" )),
506- self ._function ("sqrt" , expr ),
506+ expr < self ._lit (0 ), self ._lit (float ("nan" )), self ._function ("sqrt" , expr )
507507 )
508508
509509 return self ._with_elementwise (_sqrt )
@@ -513,13 +513,14 @@ def exp(self) -> Self:
513513
514514 def log (self , base : float ) -> Self :
515515 def _log (expr : NativeExprT ) -> NativeExprT :
516+ F = self ._function
516517 return self ._when (
517- expr < self ._lit (0 ), # type: ignore[operator]
518+ expr < self ._lit (0 ),
518519 self ._lit (float ("nan" )),
519520 self ._when (
520- cast ( "NativeExprT" , expr == self ._lit (0 ) ),
521+ expr == self ._lit (0 ),
521522 self ._lit (float ("-inf" )),
522- self . _function ( "log" , expr ) / self . _function ("log" , self ._lit (base )), # type: ignore[operator]
523+ op . truediv ( F ( "log" , expr ), F ("log" , self ._lit (base ))),
523524 ),
524525 )
525526
@@ -577,11 +578,10 @@ def diff(self) -> Self:
577578 def func (
578579 df : SQLLazyFrameT , inputs : WindowInputs [NativeExprT ]
579580 ) -> Sequence [NativeExprT ]:
581+ F = self ._function
582+ window = self ._window_expression
580583 return [
581- expr # type: ignore[operator]
582- - self ._window_expression (
583- self ._function ("lag" , expr ), inputs .partition_by , inputs .order_by
584- )
584+ op .sub (expr , window (F ("lag" , expr ), inputs .partition_by , inputs .order_by ))
585585 for expr in self (df )
586586 ]
587587
@@ -606,15 +606,12 @@ def func(
606606 ) -> Sequence [NativeExprT ]:
607607 # pyright checkers think the return type is `list[bool]` because of `==`
608608 return [
609- cast (
610- "NativeExprT" ,
611- self ._window_expression (
612- self ._function ("row_number" ),
613- (* inputs .partition_by , expr ),
614- inputs .order_by ,
615- )
616- == self ._lit (1 ),
609+ self ._window_expression (
610+ self ._function ("row_number" ),
611+ (* inputs .partition_by , expr ),
612+ inputs .order_by ,
617613 )
614+ == self ._lit (1 )
618615 for expr in self (df )
619616 ]
620617
@@ -625,17 +622,14 @@ def func(
625622 df : SQLLazyFrameT , inputs : WindowInputs [NativeExprT ]
626623 ) -> Sequence [NativeExprT ]:
627624 return [
628- cast (
629- "NativeExprT" ,
630- self ._window_expression (
631- self ._function ("row_number" ),
632- (* inputs .partition_by , expr ),
633- inputs .order_by ,
634- descending = [True ] * len (inputs .order_by ),
635- nulls_last = [True ] * len (inputs .order_by ),
636- )
637- == self ._lit (1 ),
625+ self ._window_expression (
626+ self ._function ("row_number" ),
627+ (* inputs .partition_by , expr ),
628+ inputs .order_by ,
629+ descending = [True ] * len (inputs .order_by ),
630+ nulls_last = [True ] * len (inputs .order_by ),
638631 )
632+ == self ._lit (1 )
639633 for expr in self (df )
640634 ]
641635
@@ -665,20 +659,27 @@ def _rank(
665659 "nulls_last" : nulls_last ,
666660 }
667661 count_window_kwargs : dict [str , Any ] = {"partition_by" : (* partition_by , expr )}
662+ window = self ._window_expression
663+ F = self ._function
668664 if method == "max" :
669- rank_expr = (
670- self ._window_expression (func , ** window_kwargs ) # type: ignore[operator]
671- + self ._window_expression (count_expr , ** count_window_kwargs )
672- - self ._lit (1 )
665+ rank_expr = op .sub (
666+ op .add (
667+ window (func , ** window_kwargs ),
668+ window (count_expr , ** count_window_kwargs ),
669+ ),
670+ self ._lit (1 ),
673671 )
674672 elif method == "average" :
675- rank_expr = self ._window_expression (func , ** window_kwargs ) + (
676- self ._window_expression (count_expr , ** count_window_kwargs ) # type: ignore[operator]
677- - self ._lit (1 )
678- ) / self ._lit (2.0 )
673+ rank_expr = op .add (
674+ window (func , ** window_kwargs ),
675+ op .truediv (
676+ op .sub (window (count_expr , ** count_window_kwargs ), self ._lit (1 )),
677+ self ._lit (2.0 ),
678+ ),
679+ )
679680 else :
680- rank_expr = self . _window_expression (func , ** window_kwargs )
681- return self ._when (~ self . _function ("isnull" , expr ), rank_expr ) # type: ignore[operator]
681+ rank_expr = window (func , ** window_kwargs )
682+ return self ._when (~ F ("isnull" , expr ), rank_expr ) # type: ignore[operator]
682683
683684 def _unpartitioned_rank (expr : NativeExprT ) -> NativeExprT :
684685 return _rank (expr , descending = [descending ], nulls_last = [True ])
@@ -707,11 +708,9 @@ def is_unique(self) -> Self:
707708 def _is_unique (
708709 expr : NativeExprT , * partition_by : str | NativeExprT
709710 ) -> NativeExprT :
710- return cast (
711- "NativeExprT" ,
712- self ._window_expression (self ._count_star (), (expr , * partition_by ))
713- == self ._lit (1 ),
714- )
711+ return self ._window_expression (
712+ self ._count_star (), (expr , * partition_by )
713+ ) == self ._lit (1 )
715714
716715 def _unpartitioned_is_unique (expr : NativeExprT ) -> NativeExprT :
717716 return _is_unique (expr )
0 commit comments