diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 12ef9180f5..fe5bdd777b 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -62,7 +62,17 @@ class NativeExpr(Protocol): """ def between(self, *args: Any, **kwds: Any) -> Any: ... - def isin(self, *args: Any, **kwds: Any) -> Any: ... + + # NOTE: None of these are annotated for `dx.Series`, but are added imperatively + # Probably better to define a sub-protocol for `NativeSQLExpr` + # - match `dx.Series` to `NativeExpr` + # - match the others to `NativeSQLExpr` + def __gt__(self, value: Any, /) -> Self: ... + def __lt__(self, value: Any, /) -> Self: ... + def __ge__(self, value: Any, /) -> Self: ... + def __le__(self, value: Any, /) -> Self: ... + def __eq__(self, value: Any, /) -> Self: ... # type: ignore[override] + def __ne__(self, value: Any, /) -> Self: ... # type: ignore[override] class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]): diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 5f6952ba1a..a4f5e54820 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -42,14 +42,14 @@ class DaskExpr( - LazyExpr["DaskLazyFrame", "dx.Series"], - DepthTrackingExpr["DaskLazyFrame", "dx.Series"], + LazyExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments] + DepthTrackingExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments] ): _implementation: Implementation = Implementation.DASK def __init__( self, - call: EvalSeries[DaskLazyFrame, dx.Series], + call: EvalSeries[DaskLazyFrame, dx.Series], # pyright: ignore[reportInvalidTypeForm] *, depth: int, function_name: str, diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 0e08101253..518ab36b7e 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -304,7 +304,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) -class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): +class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): # pyright: ignore[reportInvalidTypeArguments] @property def _then(self) -> type[DaskThen]: return DaskThen @@ -344,7 +344,7 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType] -class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr): +class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr): # pyright: ignore[reportInvalidTypeArguments] _depth: int = 0 _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 _function_name: str = "whenthen" diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 155276e06a..9fb6eeecb8 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -12,7 +12,7 @@ from narwhals._dask.dataframe import DaskLazyFrame # noqa: F401 -class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]): +class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments] @property def _selector(self) -> type[DaskSelector]: return DaskSelector diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index c925944c89..9e29868a16 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, cast +# ruff: noqa: N806 +import operator as op +from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol from narwhals._compliant.expr import LazyExpr from narwhals._compliant.typing import ( @@ -276,7 +278,7 @@ def func( } return [ self._when( - self._window_expression( # type: ignore[operator] + self._window_expression( self._function("count", expr), **window_kwargs ) >= self._lit(min_samples), @@ -501,9 +503,7 @@ def round(self, decimals: int) -> Self: def sqrt(self) -> Self: def _sqrt(expr: NativeExprT) -> NativeExprT: return self._when( - expr < self._lit(0), # type: ignore[operator] - self._lit(float("nan")), - self._function("sqrt", expr), + expr < self._lit(0), self._lit(float("nan")), self._function("sqrt", expr) ) return self._with_elementwise(_sqrt) @@ -513,13 +513,14 @@ def exp(self) -> Self: def log(self, base: float) -> Self: def _log(expr: NativeExprT) -> NativeExprT: + F = self._function return self._when( - expr < self._lit(0), # type: ignore[operator] + expr < self._lit(0), self._lit(float("nan")), self._when( - cast("NativeExprT", expr == self._lit(0)), + expr == self._lit(0), self._lit(float("-inf")), - self._function("log", expr) / self._function("log", self._lit(base)), # type: ignore[operator] + op.truediv(F("log", expr), F("log", self._lit(base))), ), ) @@ -577,11 +578,10 @@ def diff(self) -> Self: def func( df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT] ) -> Sequence[NativeExprT]: + F = self._function + window = self._window_expression return [ - expr # type: ignore[operator] - - self._window_expression( - self._function("lag", expr), inputs.partition_by, inputs.order_by - ) + op.sub(expr, window(F("lag", expr), inputs.partition_by, inputs.order_by)) for expr in self(df) ] @@ -606,15 +606,12 @@ def func( ) -> Sequence[NativeExprT]: # pyright checkers think the return type is `list[bool]` because of `==` return [ - cast( - "NativeExprT", - self._window_expression( - self._function("row_number"), - (*inputs.partition_by, expr), - inputs.order_by, - ) - == self._lit(1), + self._window_expression( + self._function("row_number"), + (*inputs.partition_by, expr), + inputs.order_by, ) + == self._lit(1) for expr in self(df) ] @@ -625,17 +622,14 @@ def func( df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT] ) -> Sequence[NativeExprT]: return [ - cast( - "NativeExprT", - self._window_expression( - self._function("row_number"), - (*inputs.partition_by, expr), - inputs.order_by, - descending=[True] * len(inputs.order_by), - nulls_last=[True] * len(inputs.order_by), - ) - == self._lit(1), + self._window_expression( + self._function("row_number"), + (*inputs.partition_by, expr), + inputs.order_by, + descending=[True] * len(inputs.order_by), + nulls_last=[True] * len(inputs.order_by), ) + == self._lit(1) for expr in self(df) ] @@ -665,20 +659,27 @@ def _rank( "nulls_last": nulls_last, } count_window_kwargs: dict[str, Any] = {"partition_by": (*partition_by, expr)} + window = self._window_expression + F = self._function if method == "max": - rank_expr = ( - self._window_expression(func, **window_kwargs) # type: ignore[operator] - + self._window_expression(count_expr, **count_window_kwargs) - - self._lit(1) + rank_expr = op.sub( + op.add( + window(func, **window_kwargs), + window(count_expr, **count_window_kwargs), + ), + self._lit(1), ) elif method == "average": - rank_expr = self._window_expression(func, **window_kwargs) + ( - self._window_expression(count_expr, **count_window_kwargs) # type: ignore[operator] - - self._lit(1) - ) / self._lit(2.0) + rank_expr = op.add( + window(func, **window_kwargs), + op.truediv( + op.sub(window(count_expr, **count_window_kwargs), self._lit(1)), + self._lit(2.0), + ), + ) else: - rank_expr = self._window_expression(func, **window_kwargs) - return self._when(~self._function("isnull", expr), rank_expr) # type: ignore[operator] + rank_expr = window(func, **window_kwargs) + return self._when(~F("isnull", expr), rank_expr) # type: ignore[operator] def _unpartitioned_rank(expr: NativeExprT) -> NativeExprT: return _rank(expr, descending=[descending], nulls_last=[True]) @@ -707,11 +708,9 @@ def is_unique(self) -> Self: def _is_unique( expr: NativeExprT, *partition_by: str | NativeExprT ) -> NativeExprT: - return cast( - "NativeExprT", - self._window_expression(self._count_star(), (expr, *partition_by)) - == self._lit(1), - ) + return self._window_expression( + self._count_star(), (expr, *partition_by) + ) == self._lit(1) def _unpartitioned_is_unique(expr: NativeExprT) -> NativeExprT: return _is_unique(expr)