Skip to content

Commit 4aae069

Browse files
committed
wip operator: adding NativeSQLExprT to get rid of operator issue
1 parent 266cab9 commit 4aae069

File tree

2 files changed

+27
-24
lines changed

2 files changed

+27
-24
lines changed

narwhals/_sql/expr.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def window_f(
102102
return window_f
103103

104104
def _with_window_function(
105-
self, window_function: WindowFunction[SQLLazyFrameT, NativeExprT]
105+
self, window_function: WindowFunction[SQLLazyFrameT, NativeSQLExprT]
106106
) -> Self:
107107
return self.__class__(
108108
self._call,
@@ -165,10 +165,10 @@ def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
165165
)
166166

167167
@property
168-
def window_function(self) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
168+
def window_function(self) -> WindowFunction[SQLLazyFrameT, NativeSQLExprT]:
169169
def default_window_func(
170-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
171-
) -> Sequence[NativeExprT]:
170+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
171+
) -> Sequence[NativeSQLExprT]:
172172
assert not inputs.order_by # noqa: S101
173173
return [
174174
self._window_expression(expr, inputs.partition_by, inputs.order_by)
@@ -277,7 +277,7 @@ def func(
277277
}
278278
return [
279279
self._when(
280-
self._window_expression( # type: ignore[operator]
280+
self._window_expression(
281281
self._function("count", expr), **window_kwargs
282282
)
283283
>= self._lit(min_samples),
@@ -308,8 +308,8 @@ def call(df: SQLLazyFrameT) -> Sequence[NativeExprT]:
308308
return [func(cols)]
309309

310310
def window_function(
311-
df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
312-
) -> Sequence[NativeExprT]:
311+
df: SQLLazyFrameT, window_inputs: WindowInputs[NativeSQLExprT]
312+
) -> Sequence[NativeSQLExprT]:
313313
cols = (
314314
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
315315
)
@@ -513,14 +513,14 @@ def exp(self) -> Self:
513513
return self._with_elementwise(lambda expr: self._function("exp", expr))
514514

515515
def log(self, base: float) -> Self:
516-
def _log(expr: NativeExprT) -> NativeExprT:
516+
def _log(expr: NativeSQLExprT) -> NativeSQLExprT:
517517
return self._when(
518-
expr < self._lit(0), # type: ignore[operator]
518+
expr < self._lit(0),
519519
self._lit(float("nan")),
520520
self._when(
521-
cast("NativeExprT", expr == self._lit(0)),
521+
cast("NativeSQLExprT", expr == self._lit(0)),
522522
self._lit(float("-inf")),
523-
self._function("log", expr) / self._function("log", self._lit(base)), # type: ignore[operator]
523+
self._function("log", expr) / self._function("log", self._lit(base)),
524524
),
525525
)
526526

@@ -576,10 +576,10 @@ def rolling_std(
576576
# Other window functions
577577
def diff(self) -> Self:
578578
def func(
579-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
580-
) -> Sequence[NativeExprT]:
579+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
580+
) -> Sequence[NativeSQLExprT]:
581581
return [
582-
expr # type: ignore[operator]
582+
expr
583583
- self._window_expression(
584584
self._function("lag", expr), inputs.partition_by, inputs.order_by
585585
)
@@ -651,13 +651,13 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
651651
func = self._function("row_number")
652652

653653
def _rank(
654-
expr: NativeExprT,
655-
partition_by: Sequence[str | NativeExprT] = (),
656-
order_by: Sequence[str | NativeExprT] = (),
654+
expr: NativeSQLExprT,
655+
partition_by: Sequence[str | NativeSQLExprT] = (),
656+
order_by: Sequence[str | NativeSQLExprT] = (),
657657
*,
658658
descending: Sequence[bool],
659659
nulls_last: Sequence[bool],
660-
) -> NativeExprT:
660+
) -> NativeSQLExprT:
661661
count_expr = self._count_star()
662662
window_kwargs: dict[str, Any] = {
663663
"partition_by": partition_by,
@@ -668,25 +668,26 @@ def _rank(
668668
count_window_kwargs: dict[str, Any] = {"partition_by": (*partition_by, expr)}
669669
if method == "max":
670670
rank_expr = (
671-
self._window_expression(func, **window_kwargs) # type: ignore[operator]
671+
self._window_expression(func, **window_kwargs)
672672
+ self._window_expression(count_expr, **count_window_kwargs)
673673
- self._lit(1)
674674
)
675675
elif method == "average":
676676
rank_expr = self._window_expression(func, **window_kwargs) + (
677-
self._window_expression(count_expr, **count_window_kwargs) # type: ignore[operator]
677+
self._window_expression(count_expr, **count_window_kwargs)
678678
- self._lit(1)
679679
) / self._lit(2.0)
680680
else:
681681
rank_expr = self._window_expression(func, **window_kwargs)
682-
return self._when(~self._function("isnull", expr), rank_expr) # type: ignore[operator]
682+
# TODO: @mp, thought I added this to NativeSQLExprT but not working?
683+
return self._when(~self._function("isnull", expr), rank_expr) # type: ignore[operator]
683684

684-
def _unpartitioned_rank(expr: NativeExprT) -> NativeExprT:
685+
def _unpartitioned_rank(expr: NativeSQLExprT) -> NativeSQLExprT:
685686
return _rank(expr, descending=[descending], nulls_last=[True])
686687

687688
def _partitioned_rank(
688-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
689-
) -> Sequence[NativeExprT]:
689+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
690+
) -> Sequence[NativeSQLExprT]:
690691
# node: when `descending` / `nulls_last` are supported in `.over`, they should be respected here
691692
# https://github.com/narwhals-dev/narwhals/issues/2790
692693
return [

narwhals/_sql/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def __truediv__(self, value: Self) -> Self: ...
2626

2727
def __mul__(self, value: Self) -> Self: ...
2828

29+
def __invert__(self, value: Self) -> Self: ...
30+
2931
if TYPE_CHECKING:
3032
from narwhals._sql.dataframe import SQLLazyFrame
3133
from narwhals._sql.expr import SQLExpr

0 commit comments

Comments
 (0)