Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
40e1737
temp fix to type problem
ym-pett Aug 1, 2025
a8af821
notes on what should happen
ym-pett Aug 1, 2025
f8e853c
corrected base class
ym-pett Aug 1, 2025
fda9a7d
wip experimenting with methods
ym-pett Aug 1, 2025
86ed14a
wip improved types
ym-pett Aug 1, 2025
ec08949
wip operator
ym-pett Aug 3, 2025
a4de0bb
wip operator: added thoughts
ym-pett Aug 3, 2025
d372db6
wip operator: modified Self import
ym-pett Aug 4, 2025
3a97dcb
wip operator: added thoughts & only using in sqrt
ym-pett Aug 4, 2025
9a544e8
wip operator: more comments, going back to reading
ym-pett Aug 4, 2025
266cab9
wip operator: fixed NativeSQLExprT bound
ym-pett Aug 4, 2025
4aae069
wip operator: adding NativeSQLExprT to get rid of operator issue
ym-pett Aug 4, 2025
7ad112a
wip operator: added all changes from pair session, 1 problem remaining
ym-pett Aug 4, 2025
937e2a1
checkpoint
ym-pett Aug 4, 2025
854e2f3
joy: only 1 type error to go
ym-pett Aug 4, 2025
0fa44af
Merge branch 'main' into fix_sql_operator_problem
ym-pett Aug 4, 2025
a2deceb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2025
a9cabad
silencing ibisexpr type error
ym-pett Aug 4, 2025
a2c4f31
correcting error silencing
ym-pett Aug 4, 2025
dea1fe2
adding extra dunder methods
ym-pett Aug 5, 2025
27e748f
removed leftover comment
ym-pett Aug 5, 2025
495fb11
wip: trying to move operators into NativeExpr
ym-pett Aug 5, 2025
9383492
move to NativeExpr works
ym-pett Aug 5, 2025
df82744
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2025
c65df29
Merge branch 'main' into fix_sql_operator_problem
ym-pett Aug 5, 2025
dd39727
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2025
09731da
removing commented out
ym-pett Aug 5, 2025
b4b48d9
trying to silence ibis type errors
ym-pett Aug 5, 2025
a131d51
annoying stray comment
ym-pett Aug 5, 2025
3521336
Merge branch 'main' into fix_sql_operator_problem
ym-pett Aug 5, 2025
dcf1b80
fix(suggestion): Remove problematic protocol members
dangotbanned Aug 7, 2025
e6f7640
took out wip comment
ym-pett Aug 8, 2025
5347b9f
Merge branch 'main' into fix_sql_operator_problem
ym-pett Aug 8, 2025
b17d241
Merge branch 'main' into fix_sql_operator_problem
ym-pett Aug 8, 2025
462cf8a
remove `isin`
MarcoGorelli Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,17 @@ class NativeExpr(Protocol):
"""

def between(self, *args: Any, **kwds: Any) -> Any: ...
def isin(self, *args: Any, **kwds: Any) -> Any: ...

# NOTE: None of these are annotated for `dx.Series`, but are added imperatively
# Probably better to define a sub-protocol for `NativeSQLExpr`
# - match `dx.Series` to `NativeExpr`
# - match the others to `NativeSQLExpr`
def __gt__(self, value: Any, /) -> Self: ...
def __lt__(self, value: Any, /) -> Self: ...
def __ge__(self, value: Any, /) -> Self: ...
def __le__(self, value: Any, /) -> Self: ...
def __eq__(self, value: Any, /) -> Self: ... # type: ignore[override]
def __ne__(self, value: Any, /) -> Self: ... # type: ignore[override]
Comment on lines +66 to +75
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh my 🀦

Comment on lines +66 to +75
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, thanks - if anything, I like it better if we match on these than if we match on between and isin (which, for example, Daft doesn't have)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the important bit

class NativeExpr(Protocol):
"""An `Expr`-like object from a package with [Lazy-only support](https://narwhals-dev.github.io/narwhals/extending/#levels-of-support).
Protocol members are chosen *purely* for matching statically - as they
are common to all currently supported packages.
"""

which, for example, Daft doesn't have

I had a similar issue for ibis.Table in (#2944)

IntoLazyFrame: TypeAlias = Union["NativeLazyFrame", "_NativeIbis"]

It's okay to have multiple protocols/sub-protocols/aliases if we need that now

I started with 1 because we had a common denominator, but we grow πŸ˜„

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

had a common denominator

right, so if we add these comparison operators, we can remove isin and between?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep between if possible

Kind of like how we have filter for NativeSeries - it's something to avoid false-positives

Waaaaay more than just the types we're interested in will have comparison dunders, but fewer will also have between as well



class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
Expand Down
6 changes: 3 additions & 3 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@


class DaskExpr(
LazyExpr["DaskLazyFrame", "dx.Series"],
DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
LazyExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments]
DepthTrackingExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments]
):
_implementation: Implementation = Implementation.DASK

def __init__(
self,
call: EvalSeries[DaskLazyFrame, dx.Series],
call: EvalSeries[DaskLazyFrame, dx.Series], # pyright: ignore[reportInvalidTypeForm]
*,
depth: int,
function_name: str,
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
)


class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):
class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): # pyright: ignore[reportInvalidTypeArguments]
@property
def _then(self) -> type[DaskThen]:
return DaskThen
Expand Down Expand Up @@ -344,7 +344,7 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]


class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr):
class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr): # pyright: ignore[reportInvalidTypeArguments]
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "whenthen"
2 changes: 1 addition & 1 deletion narwhals/_dask/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from narwhals._dask.dataframe import DaskLazyFrame # noqa: F401


class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]):
class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments]
@property
def _selector(self) -> type[DaskSelector]:
return DaskSelector
Expand Down
89 changes: 44 additions & 45 deletions narwhals/_sql/expr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, cast
# ruff: noqa: N806
import operator as op
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol

from narwhals._compliant.expr import LazyExpr
from narwhals._compliant.typing import (
Expand Down Expand Up @@ -276,7 +278,7 @@ def func(
}
return [
self._when(
self._window_expression( # type: ignore[operator]
self._window_expression(
self._function("count", expr), **window_kwargs
)
>= self._lit(min_samples),
Expand Down Expand Up @@ -501,9 +503,7 @@ def round(self, decimals: int) -> Self:
def sqrt(self) -> Self:
def _sqrt(expr: NativeExprT) -> NativeExprT:
return self._when(
expr < self._lit(0), # type: ignore[operator]
self._lit(float("nan")),
self._function("sqrt", expr),
expr < self._lit(0), self._lit(float("nan")), self._function("sqrt", expr)
)

return self._with_elementwise(_sqrt)
Expand All @@ -513,13 +513,14 @@ def exp(self) -> Self:

def log(self, base: float) -> Self:
def _log(expr: NativeExprT) -> NativeExprT:
F = self._function
return self._when(
expr < self._lit(0), # type: ignore[operator]
expr < self._lit(0),
self._lit(float("nan")),
self._when(
cast("NativeExprT", expr == self._lit(0)),
expr == self._lit(0),
self._lit(float("-inf")),
self._function("log", expr) / self._function("log", self._lit(base)), # type: ignore[operator]
op.truediv(F("log", expr), F("log", self._lit(base))),
),
)

Expand Down Expand Up @@ -577,11 +578,10 @@ def diff(self) -> Self:
def func(
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
) -> Sequence[NativeExprT]:
F = self._function
window = self._window_expression
return [
expr # type: ignore[operator]
- self._window_expression(
self._function("lag", expr), inputs.partition_by, inputs.order_by
)
op.sub(expr, window(F("lag", expr), inputs.partition_by, inputs.order_by))
for expr in self(df)
]

Expand All @@ -606,15 +606,12 @@ def func(
) -> Sequence[NativeExprT]:
# pyright checkers think the return type is `list[bool]` because of `==`
return [
cast(
"NativeExprT",
self._window_expression(
self._function("row_number"),
(*inputs.partition_by, expr),
inputs.order_by,
)
== self._lit(1),
self._window_expression(
self._function("row_number"),
(*inputs.partition_by, expr),
inputs.order_by,
)
== self._lit(1)
for expr in self(df)
]

Expand All @@ -625,17 +622,14 @@ def func(
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
) -> Sequence[NativeExprT]:
return [
cast(
"NativeExprT",
self._window_expression(
self._function("row_number"),
(*inputs.partition_by, expr),
inputs.order_by,
descending=[True] * len(inputs.order_by),
nulls_last=[True] * len(inputs.order_by),
)
== self._lit(1),
self._window_expression(
self._function("row_number"),
(*inputs.partition_by, expr),
inputs.order_by,
descending=[True] * len(inputs.order_by),
nulls_last=[True] * len(inputs.order_by),
)
== self._lit(1)
for expr in self(df)
]

Expand Down Expand Up @@ -665,20 +659,27 @@ def _rank(
"nulls_last": nulls_last,
}
count_window_kwargs: dict[str, Any] = {"partition_by": (*partition_by, expr)}
window = self._window_expression
F = self._function
if method == "max":
rank_expr = (
self._window_expression(func, **window_kwargs) # type: ignore[operator]
+ self._window_expression(count_expr, **count_window_kwargs)
- self._lit(1)
rank_expr = op.sub(
op.add(
window(func, **window_kwargs),
window(count_expr, **count_window_kwargs),
),
self._lit(1),
)
elif method == "average":
rank_expr = self._window_expression(func, **window_kwargs) + (
self._window_expression(count_expr, **count_window_kwargs) # type: ignore[operator]
- self._lit(1)
) / self._lit(2.0)
rank_expr = op.add(
window(func, **window_kwargs),
op.truediv(
op.sub(window(count_expr, **count_window_kwargs), self._lit(1)),
self._lit(2.0),
),
)
else:
rank_expr = self._window_expression(func, **window_kwargs)
return self._when(~self._function("isnull", expr), rank_expr) # type: ignore[operator]
rank_expr = window(func, **window_kwargs)
return self._when(~F("isnull", expr), rank_expr) # type: ignore[operator]

def _unpartitioned_rank(expr: NativeExprT) -> NativeExprT:
return _rank(expr, descending=[descending], nulls_last=[True])
Expand Down Expand Up @@ -707,11 +708,9 @@ def is_unique(self) -> Self:
def _is_unique(
expr: NativeExprT, *partition_by: str | NativeExprT
) -> NativeExprT:
return cast(
"NativeExprT",
self._window_expression(self._count_star(), (expr, *partition_by))
== self._lit(1),
)
return self._window_expression(
self._count_star(), (expr, *partition_by)
) == self._lit(1)

def _unpartitioned_is_unique(expr: NativeExprT) -> NativeExprT:
return _is_unique(expr)
Expand Down
Loading