Skip to content

Commit 787114d

Browse files
ym-pettpre-commit-ci[bot]dangotbannedMarcoGorelli
authored
chore: fixup operator type ignores for SQLExpr (#2920)
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: dangotbanned <[email protected]> Co-authored-by: Marco Edward Gorelli <[email protected]>
1 parent 3d0ec2f commit 787114d

File tree

5 files changed

+61
-52
lines changed

5 files changed

+61
-52
lines changed

narwhals/_compliant/expr.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,17 @@ class NativeExpr(Protocol):
6262
"""
6363

6464
def between(self, *args: Any, **kwds: Any) -> Any: ...
65-
def isin(self, *args: Any, **kwds: Any) -> Any: ...
65+
66+
# NOTE: None of these are annotated for `dx.Series`, but are added imperatively
67+
# Probably better to define a sub-protocol for `NativeSQLExpr`
68+
# - match `dx.Series` to `NativeExpr`
69+
# - match the others to `NativeSQLExpr`
70+
def __gt__(self, value: Any, /) -> Self: ...
71+
def __lt__(self, value: Any, /) -> Self: ...
72+
def __ge__(self, value: Any, /) -> Self: ...
73+
def __le__(self, value: Any, /) -> Self: ...
74+
def __eq__(self, value: Any, /) -> Self: ... # type: ignore[override]
75+
def __ne__(self, value: Any, /) -> Self: ... # type: ignore[override]
6676

6777

6878
class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):

narwhals/_dask/expr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@
4242

4343

4444
class DaskExpr(
45-
LazyExpr["DaskLazyFrame", "dx.Series"],
46-
DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
45+
LazyExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments]
46+
DepthTrackingExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments]
4747
):
4848
_implementation: Implementation = Implementation.DASK
4949

5050
def __init__(
5151
self,
52-
call: EvalSeries[DaskLazyFrame, dx.Series],
52+
call: EvalSeries[DaskLazyFrame, dx.Series], # pyright: ignore[reportInvalidTypeForm]
5353
*,
5454
depth: int,
5555
function_name: str,

narwhals/_dask/namespace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
304304
)
305305

306306

307-
class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):
307+
class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): # pyright: ignore[reportInvalidTypeArguments]
308308
@property
309309
def _then(self) -> type[DaskThen]:
310310
return DaskThen
@@ -344,7 +344,7 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
344344
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]
345345

346346

347-
class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr):
347+
class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr): # pyright: ignore[reportInvalidTypeArguments]
348348
_depth: int = 0
349349
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
350350
_function_name: str = "whenthen"

narwhals/_dask/selectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from narwhals._dask.dataframe import DaskLazyFrame # noqa: F401
1313

1414

15-
class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]):
15+
class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments]
1616
@property
1717
def _selector(self) -> type[DaskSelector]:
1818
return DaskSelector

narwhals/_sql/expr.py

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, cast
3+
# ruff: noqa: N806
4+
import operator as op
5+
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol
46

57
from narwhals._compliant.expr import LazyExpr
68
from narwhals._compliant.typing import (
@@ -276,7 +278,7 @@ def func(
276278
}
277279
return [
278280
self._when(
279-
self._window_expression( # type: ignore[operator]
281+
self._window_expression(
280282
self._function("count", expr), **window_kwargs
281283
)
282284
>= self._lit(min_samples),
@@ -501,9 +503,7 @@ def round(self, decimals: int) -> Self:
501503
def sqrt(self) -> Self:
502504
def _sqrt(expr: NativeExprT) -> NativeExprT:
503505
return self._when(
504-
expr < self._lit(0), # type: ignore[operator]
505-
self._lit(float("nan")),
506-
self._function("sqrt", expr),
506+
expr < self._lit(0), self._lit(float("nan")), self._function("sqrt", expr)
507507
)
508508

509509
return self._with_elementwise(_sqrt)
@@ -513,13 +513,14 @@ def exp(self) -> Self:
513513

514514
def log(self, base: float) -> Self:
515515
def _log(expr: NativeExprT) -> NativeExprT:
516+
F = self._function
516517
return self._when(
517-
expr < self._lit(0), # type: ignore[operator]
518+
expr < self._lit(0),
518519
self._lit(float("nan")),
519520
self._when(
520-
cast("NativeExprT", expr == self._lit(0)),
521+
expr == self._lit(0),
521522
self._lit(float("-inf")),
522-
self._function("log", expr) / self._function("log", self._lit(base)), # type: ignore[operator]
523+
op.truediv(F("log", expr), F("log", self._lit(base))),
523524
),
524525
)
525526

@@ -577,11 +578,10 @@ def diff(self) -> Self:
577578
def func(
578579
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
579580
) -> Sequence[NativeExprT]:
581+
F = self._function
582+
window = self._window_expression
580583
return [
581-
expr # type: ignore[operator]
582-
- self._window_expression(
583-
self._function("lag", expr), inputs.partition_by, inputs.order_by
584-
)
584+
op.sub(expr, window(F("lag", expr), inputs.partition_by, inputs.order_by))
585585
for expr in self(df)
586586
]
587587

@@ -606,15 +606,12 @@ def func(
606606
) -> Sequence[NativeExprT]:
607607
# pyright checkers think the return type is `list[bool]` because of `==`
608608
return [
609-
cast(
610-
"NativeExprT",
611-
self._window_expression(
612-
self._function("row_number"),
613-
(*inputs.partition_by, expr),
614-
inputs.order_by,
615-
)
616-
== self._lit(1),
609+
self._window_expression(
610+
self._function("row_number"),
611+
(*inputs.partition_by, expr),
612+
inputs.order_by,
617613
)
614+
== self._lit(1)
618615
for expr in self(df)
619616
]
620617

@@ -625,17 +622,14 @@ def func(
625622
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
626623
) -> Sequence[NativeExprT]:
627624
return [
628-
cast(
629-
"NativeExprT",
630-
self._window_expression(
631-
self._function("row_number"),
632-
(*inputs.partition_by, expr),
633-
inputs.order_by,
634-
descending=[True] * len(inputs.order_by),
635-
nulls_last=[True] * len(inputs.order_by),
636-
)
637-
== self._lit(1),
625+
self._window_expression(
626+
self._function("row_number"),
627+
(*inputs.partition_by, expr),
628+
inputs.order_by,
629+
descending=[True] * len(inputs.order_by),
630+
nulls_last=[True] * len(inputs.order_by),
638631
)
632+
== self._lit(1)
639633
for expr in self(df)
640634
]
641635

@@ -665,20 +659,27 @@ def _rank(
665659
"nulls_last": nulls_last,
666660
}
667661
count_window_kwargs: dict[str, Any] = {"partition_by": (*partition_by, expr)}
662+
window = self._window_expression
663+
F = self._function
668664
if method == "max":
669-
rank_expr = (
670-
self._window_expression(func, **window_kwargs) # type: ignore[operator]
671-
+ self._window_expression(count_expr, **count_window_kwargs)
672-
- self._lit(1)
665+
rank_expr = op.sub(
666+
op.add(
667+
window(func, **window_kwargs),
668+
window(count_expr, **count_window_kwargs),
669+
),
670+
self._lit(1),
673671
)
674672
elif method == "average":
675-
rank_expr = self._window_expression(func, **window_kwargs) + (
676-
self._window_expression(count_expr, **count_window_kwargs) # type: ignore[operator]
677-
- self._lit(1)
678-
) / self._lit(2.0)
673+
rank_expr = op.add(
674+
window(func, **window_kwargs),
675+
op.truediv(
676+
op.sub(window(count_expr, **count_window_kwargs), self._lit(1)),
677+
self._lit(2.0),
678+
),
679+
)
679680
else:
680-
rank_expr = self._window_expression(func, **window_kwargs)
681-
return self._when(~self._function("isnull", expr), rank_expr) # type: ignore[operator]
681+
rank_expr = window(func, **window_kwargs)
682+
return self._when(~F("isnull", expr), rank_expr) # type: ignore[operator]
682683

683684
def _unpartitioned_rank(expr: NativeExprT) -> NativeExprT:
684685
return _rank(expr, descending=[descending], nulls_last=[True])
@@ -707,11 +708,9 @@ def is_unique(self) -> Self:
707708
def _is_unique(
708709
expr: NativeExprT, *partition_by: str | NativeExprT
709710
) -> NativeExprT:
710-
return cast(
711-
"NativeExprT",
712-
self._window_expression(self._count_star(), (expr, *partition_by))
713-
== self._lit(1),
714-
)
711+
return self._window_expression(
712+
self._count_star(), (expr, *partition_by)
713+
) == self._lit(1)
715714

716715
def _unpartitioned_is_unique(expr: NativeExprT) -> NativeExprT:
717716
return _is_unique(expr)

0 commit comments

Comments
 (0)