33from typing import TYPE_CHECKING , Any , Callable , Literal , Protocol , cast
44
55from narwhals ._compliant .expr import LazyExpr
6- from narwhals ._compliant .typing import (
7- AliasNames ,
8- EvalNames ,
9- EvalSeries ,
10- WindowFunction ,
11- )
6+ from narwhals ._compliant .typing import AliasNames , WindowFunction
127from narwhals ._compliant .window import WindowInputs
138from narwhals ._expression_parsing import (
149 combine_alias_output_names ,
1510 combine_evaluate_output_names ,
1611)
17- from narwhals ._sql .typing import SQLLazyFrameT , NativeSQLExprT
12+ from narwhals ._sql .typing import NativeSQLExprT , SQLLazyFrameT
1813from narwhals ._utils import Implementation , Version , not_implemented
1914
2015if TYPE_CHECKING :
2116 from collections .abc import Iterable , Sequence
2217
2318 from typing_extensions import Self , TypeIs
2419
25- from narwhals ._compliant .typing import AliasNames , WindowFunction
20+ from narwhals ._compliant .typing import (
21+ AliasNames ,
22+ EvalNames ,
23+ EvalSeries ,
24+ WindowFunction ,
25+ )
2626 from narwhals ._expression_parsing import ExprMetadata
2727 from narwhals ._sql .namespace import SQLNamespace
2828 from narwhals .typing import NumericLiteral , PythonLiteral , RankMethod , TemporalLiteral
2929
30- class SQLExpr (LazyExpr [SQLLazyFrameT , NativeSQLExprT ], Protocol [SQLLazyFrameT , NativeSQLExprT ]):
30+
31+ class SQLExpr (
32+ LazyExpr [SQLLazyFrameT , NativeSQLExprT ], Protocol [SQLLazyFrameT , NativeSQLExprT ]
33+ ):
3134 _call : EvalSeries [SQLLazyFrameT , NativeSQLExprT ]
3235 _evaluate_output_names : EvalNames [SQLLazyFrameT ]
3336 _alias_output_names : AliasNames | None
@@ -173,7 +176,9 @@ def default_window_func(
173176
174177 return self ._window_function or default_window_func
175178
176- def _function (self , name : str , * args : NativeSQLExprT | PythonLiteral ) -> NativeSQLExprT :
179+ def _function (
180+ self , name : str , * args : NativeSQLExprT | PythonLiteral
181+ ) -> NativeSQLExprT :
177182 return self .__narwhals_namespace__ ()._function (name , * args )
178183
179184 def _lit (self , value : Any ) -> NativeSQLExprT :
@@ -273,7 +278,7 @@ def func(
273278 }
274279 return [
275280 self ._when (
276- self ._window_expression (
281+ self ._window_expression (
277282 self ._function ("count" , expr ), ** window_kwargs
278283 )
279284 >= self ._lit (min_samples ),
@@ -494,13 +499,12 @@ def round(self, decimals: int) -> Self:
494499 return self ._with_elementwise (
495500 lambda expr : self ._function ("round" , expr , self ._lit (decimals ))
496501 )
497- # WIP: trying new NativeSQLExprT
502+
503+ # WIP: trying new NativeSQLExprT
498504 def sqrt (self ) -> Self :
499505 def _sqrt (expr : NativeSQLExprT ) -> NativeSQLExprT :
500506 return self ._when (
501- expr < self ._lit (0 ),
502- self ._lit (float ("nan" )),
503- self ._function ("sqrt" , expr ),
507+ expr < self ._lit (0 ), self ._lit (float ("nan" )), self ._function ("sqrt" , expr )
504508 )
505509
506510 return self ._with_elementwise (_sqrt )
@@ -511,12 +515,12 @@ def exp(self) -> Self:
511515 def log (self , base : float ) -> Self :
512516 def _log (expr : NativeSQLExprT ) -> NativeSQLExprT :
513517 return self ._when (
514- expr < self ._lit (0 ),
518+ expr < self ._lit (0 ),
515519 self ._lit (float ("nan" )),
516520 self ._when (
517521 cast ("NativeSQLExprT" , expr == self ._lit (0 )),
518522 self ._lit (float ("-inf" )),
519- self ._function ("log" , expr ) / self ._function ("log" , self ._lit (base )),
523+ self ._function ("log" , expr ) / self ._function ("log" , self ._lit (base )),
520524 ),
521525 )
522526
@@ -664,19 +668,19 @@ def _rank(
664668 count_window_kwargs : dict [str , Any ] = {"partition_by" : (* partition_by , expr )}
665669 if method == "max" :
666670 rank_expr = (
667- self ._window_expression (func , ** window_kwargs )
671+ self ._window_expression (func , ** window_kwargs )
668672 + self ._window_expression (count_expr , ** count_window_kwargs )
669673 - self ._lit (1 )
670674 )
671675 elif method == "average" :
672676 rank_expr = self ._window_expression (func , ** window_kwargs ) + (
673- self ._window_expression (count_expr , ** count_window_kwargs )
677+ self ._window_expression (count_expr , ** count_window_kwargs )
674678 - self ._lit (1 )
675679 ) / self ._lit (2.0 )
676680 else :
677681 rank_expr = self ._window_expression (func , ** window_kwargs )
678- # TODO: @mp, thought I added this to NativeSQLExprT but not working?
679- return self ._when (~ self ._function ("isnull" , expr ), rank_expr ) # type: ignore[operator]
682+ # TODO: @mp, thought I added this to NativeSQLExprT but not working?
683+ return self ._when (~ self ._function ("isnull" , expr ), rank_expr ) # type: ignore[operator]
680684
681685 def _unpartitioned_rank (expr : NativeSQLExprT ) -> NativeSQLExprT :
682686 return _rank (expr , descending = [descending ], nulls_last = [True ])
0 commit comments