Skip to content

Commit 23eb022

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7ad112a commit 23eb022

File tree

2 files changed

+32
-30
lines changed

2 files changed

+32
-30
lines changed

narwhals/_sql/expr.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,34 @@
33
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, cast
44

55
from narwhals._compliant.expr import LazyExpr
6-
from narwhals._compliant.typing import (
7-
AliasNames,
8-
EvalNames,
9-
EvalSeries,
10-
WindowFunction,
11-
)
6+
from narwhals._compliant.typing import AliasNames, WindowFunction
127
from narwhals._compliant.window import WindowInputs
138
from narwhals._expression_parsing import (
149
combine_alias_output_names,
1510
combine_evaluate_output_names,
1611
)
17-
from narwhals._sql.typing import SQLLazyFrameT, NativeSQLExprT
12+
from narwhals._sql.typing import NativeSQLExprT, SQLLazyFrameT
1813
from narwhals._utils import Implementation, Version, not_implemented
1914

2015
if TYPE_CHECKING:
2116
from collections.abc import Iterable, Sequence
2217

2318
from typing_extensions import Self, TypeIs
2419

25-
from narwhals._compliant.typing import AliasNames, WindowFunction
20+
from narwhals._compliant.typing import (
21+
AliasNames,
22+
EvalNames,
23+
EvalSeries,
24+
WindowFunction,
25+
)
2626
from narwhals._expression_parsing import ExprMetadata
2727
from narwhals._sql.namespace import SQLNamespace
2828
from narwhals.typing import NumericLiteral, PythonLiteral, RankMethod, TemporalLiteral
2929

30-
class SQLExpr(LazyExpr[SQLLazyFrameT, NativeSQLExprT], Protocol[SQLLazyFrameT, NativeSQLExprT]):
30+
31+
class SQLExpr(
32+
LazyExpr[SQLLazyFrameT, NativeSQLExprT], Protocol[SQLLazyFrameT, NativeSQLExprT]
33+
):
3134
_call: EvalSeries[SQLLazyFrameT, NativeSQLExprT]
3235
_evaluate_output_names: EvalNames[SQLLazyFrameT]
3336
_alias_output_names: AliasNames | None
@@ -173,7 +176,9 @@ def default_window_func(
173176

174177
return self._window_function or default_window_func
175178

176-
def _function(self, name: str, *args: NativeSQLExprT | PythonLiteral) -> NativeSQLExprT:
179+
def _function(
180+
self, name: str, *args: NativeSQLExprT | PythonLiteral
181+
) -> NativeSQLExprT:
177182
return self.__narwhals_namespace__()._function(name, *args)
178183

179184
def _lit(self, value: Any) -> NativeSQLExprT:
@@ -273,7 +278,7 @@ def func(
273278
}
274279
return [
275280
self._when(
276-
self._window_expression(
281+
self._window_expression(
277282
self._function("count", expr), **window_kwargs
278283
)
279284
>= self._lit(min_samples),
@@ -494,13 +499,12 @@ def round(self, decimals: int) -> Self:
494499
return self._with_elementwise(
495500
lambda expr: self._function("round", expr, self._lit(decimals))
496501
)
497-
# WIP: trying new NativeSQLExprT
502+
503+
# WIP: trying new NativeSQLExprT
498504
def sqrt(self) -> Self:
499505
def _sqrt(expr: NativeSQLExprT) -> NativeSQLExprT:
500506
return self._when(
501-
expr < self._lit(0),
502-
self._lit(float("nan")),
503-
self._function("sqrt", expr),
507+
expr < self._lit(0), self._lit(float("nan")), self._function("sqrt", expr)
504508
)
505509

506510
return self._with_elementwise(_sqrt)
@@ -511,12 +515,12 @@ def exp(self) -> Self:
511515
def log(self, base: float) -> Self:
512516
def _log(expr: NativeSQLExprT) -> NativeSQLExprT:
513517
return self._when(
514-
expr < self._lit(0),
518+
expr < self._lit(0),
515519
self._lit(float("nan")),
516520
self._when(
517521
cast("NativeSQLExprT", expr == self._lit(0)),
518522
self._lit(float("-inf")),
519-
self._function("log", expr) / self._function("log", self._lit(base)),
523+
self._function("log", expr) / self._function("log", self._lit(base)),
520524
),
521525
)
522526

@@ -664,19 +668,19 @@ def _rank(
664668
count_window_kwargs: dict[str, Any] = {"partition_by": (*partition_by, expr)}
665669
if method == "max":
666670
rank_expr = (
667-
self._window_expression(func, **window_kwargs)
671+
self._window_expression(func, **window_kwargs)
668672
+ self._window_expression(count_expr, **count_window_kwargs)
669673
- self._lit(1)
670674
)
671675
elif method == "average":
672676
rank_expr = self._window_expression(func, **window_kwargs) + (
673-
self._window_expression(count_expr, **count_window_kwargs)
677+
self._window_expression(count_expr, **count_window_kwargs)
674678
- self._lit(1)
675679
) / self._lit(2.0)
676680
else:
677681
rank_expr = self._window_expression(func, **window_kwargs)
678-
# TODO: @mp, thought I added this to NativeSQLExprT but not working?
679-
return self._when(~self._function("isnull", expr), rank_expr) # type: ignore[operator]
682+
# TODO: @mp, thought I added this to NativeSQLExprT but not working?
683+
return self._when(~self._function("isnull", expr), rank_expr) # type: ignore[operator]
680684

681685
def _unpartitioned_rank(expr: NativeSQLExprT) -> NativeSQLExprT:
682686
return _rank(expr, descending=[descending], nulls_last=[True])

narwhals/_sql/typing.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, TypeVar, Protocol
3+
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
44

55
from narwhals._compliant.expr import NativeExpr
66

77
if TYPE_CHECKING:
8+
from typing_extensions import Self
9+
810
from narwhals._sql.dataframe import SQLLazyFrame
911
from narwhals._sql.expr import SQLExpr
10-
from narwhals.dtypes import Boolean
11-
from typing_extensions import Self
1212

13-
# TODO: check we
13+
# TODO: check we
1414
SQLExprAny = SQLExpr[Any, Any]
1515
SQLLazyFrameAny = SQLLazyFrame[Any, Any, Any]
1616

1717
SQLExprT = TypeVar("SQLExprT", bound="SQLExprAny")
1818
SQLExprT_contra = TypeVar("SQLExprT_contra", bound="SQLExprAny", contravariant=True)
1919
SQLLazyFrameT = TypeVar("SQLLazyFrameT", bound="SQLLazyFrameAny")
2020
# TODO: @mp, should this be contravariant as to do with function arguments? think through!
21-
NativeSQLExprT = TypeVar("NativeSQLExprT", bound="NativeSQLExpr")
21+
NativeSQLExprT = TypeVar("NativeSQLExprT", bound="NativeSQLExpr")
22+
2223

2324
class NativeSQLExpr(NativeExpr, Protocol):
24-
# both Self because we're comparing an expression with an expression?
25+
# both Self because we're comparing an expression with an expression?
2526
def __gt__(self, value: Any, /) -> Self: ...
2627

2728
def __lt__(self, value: Any, /) -> Self: ...
@@ -45,6 +46,3 @@ def __truediv__(self, value: Any, /) -> Self: ...
4546
# def __mul__(self, value: Self) -> Self: ...
4647

4748
# def __invert__(self, value: Self) -> Self: ...
48-
49-
50-

0 commit comments

Comments
 (0)