Skip to content

Commit ec08949

Browse files
committed
wip operator
1 parent 86ed14a commit ec08949

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

narwhals/_sql/expr.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
combine_alias_output_names,
1616
combine_evaluate_output_names,
1717
)
18-
from narwhals._sql.typing import SQLLazyFrameT
18+
from narwhals._sql.typing import SQLLazyFrameT, NativeSQLExprT
1919
from narwhals._utils import Implementation, Version, not_implemented
2020

2121
if TYPE_CHECKING:
@@ -249,7 +249,7 @@ def _rolling_window_func(
249249
end = 0
250250

251251
def func(
252-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
252+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
253253
) -> Sequence[NativeExprT]:
254254
if func_name in {"sum", "mean"}:
255255
func_: str = func_name
@@ -496,11 +496,11 @@ def round(self, decimals: int) -> Self:
496496
return self._with_elementwise(
497497
lambda expr: self._function("round", expr, self._lit(decimals))
498498
)
499-
499+
# WIP: trying new NativeSQLExprT
500500
def sqrt(self) -> Self:
501-
def _sqrt(expr: NativeExprT) -> NativeExprT:
501+
def _sqrt(expr: NativeSQLExprT) -> NativeSQLExprT:
502502
return self._when(
503-
expr < self._lit(0), # type: ignore[operator]
503+
expr < self._lit(0),
504504
self._lit(float("nan")),
505505
self._function("sqrt", expr),
506506
)

narwhals/_sql/typing.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, TypeVar
3+
from typing import TYPE_CHECKING, Any, TypeVar, Self
44

55
from narwhals._compliant.expr import NativeExpr
66

@@ -9,32 +9,33 @@
99
from narwhals._sql.expr import SQLExpr
1010
from narwhals.dtypes import Boolean
1111

12+
# TODO: @mp, understand why these are here & if we need one for NativeSQLExprT
1213
SQLExprAny = SQLExpr[Any, Any]
1314
SQLLazyFrameAny = SQLLazyFrame[Any, Any, Any]
1415

1516
SQLExprT = TypeVar("SQLExprT", bound="SQLExprAny")
1617
SQLExprT_contra = TypeVar("SQLExprT_contra", bound="SQLExprAny", contravariant=True)
1718
SQLLazyFrameT = TypeVar("SQLLazyFrameT", bound="SQLLazyFrameAny")
18-
19+
NativeSQLExprT = TypeVar("NativeSQLExprT", bound="NativeSQLExpr")
1920

2021
class NativeSQLExpr(NativeExpr):
21-
# TODO @mp: fix input type for all these!
22-
def __gt__(self, value: float) -> Boolean: ...
22+
# both Self because we're comparing an expression with an expression?
23+
def __gt__(self, value: Self) -> Self: ...
2324

24-
def __lt__(self, value: float) -> Boolean: ...
25+
def __lt__(self, value: Self) -> Self: ...
2526

26-
def __ge__(self, value: float) -> Boolean: ...
27+
def __ge__(self, value: Self) -> Self: ...
2728

28-
def __le__(self, value: float) -> Boolean: ...
29+
def __le__(self, value: Self) -> Self: ...
2930

30-
def __eq__(self, value: float) -> Boolean: ...
31+
def __eq__(self, value: Self) -> Self: ...
3132

32-
def __ne__(self, value: float) -> Boolean: ...
33+
def __ne__(self, value: Self) -> Self: ...
3334
# do we want any more of the arithmetic methods? I wasn't sure between lefthan & righthand methods..
34-
def __sub__(self, value: float) -> Any: ...
35+
def __sub__(self, value: Self) -> Self: ...
3536

36-
def __add__(self, value: float) -> Any: ...
37+
def __add__(self, value: Self) -> Self: ...
3738

38-
def __truediv__(self, value: float) -> Any: ...
39+
def __truediv__(self, value: Self) -> Self: ...
3940

40-
def __mul__(self, value: float) -> Any: ...
41+
def __mul__(self, value: Self) -> Self: ...

0 commit comments

Comments
 (0)