Skip to content

Commit aa7a7a7

Browse files
chore(typing): enable reportIncompatibleMethodOverride in pyright (#3096)
* chore: enable reportIncompatibleMethodOverride in pyright * chore(suggestion): Try turning ibis issue into a positive? Part 1 of (#3096 (comment)) Isolates the problematic typing and then reuses it in other places it causes issues * chore(typing): Ignore with context Seems to appear at random and vanish Error message is nonsense * refactor: Move common `not_implemented` up * refactor(typing): Do this ignore instead actually * be reasonable #3096 (comment) --------- Co-authored-by: dangotbanned <[email protected]>
1 parent cdbd1bc commit aa7a7a7

File tree

16 files changed

+105
-94
lines changed

16 files changed

+105
-94
lines changed

narwhals/_arrow/series.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,12 @@ def shift(self, n: int) -> Self:
373373
return self._with_native(self.native)
374374
return self._with_native(pa.concat_arrays(arrays))
375375

376-
def std(self, ddof: int, *, _return_py_scalar: bool = True) -> float:
376+
def std(self, *, ddof: int, _return_py_scalar: bool = True) -> float:
377377
return maybe_extract_py_scalar(
378378
pc.stddev(self.native, ddof=ddof), _return_py_scalar
379379
)
380380

381-
def var(self, ddof: int, *, _return_py_scalar: bool = True) -> float:
381+
def var(self, *, ddof: int, _return_py_scalar: bool = True) -> float:
382382
return maybe_extract_py_scalar(
383383
pc.variance(self.native, ddof=ddof), _return_py_scalar
384384
)

narwhals/_compliant/expr.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
LazyExprT,
2929
NativeExprT,
3030
)
31-
from narwhals._utils import _StoresCompliant, qualified_type_name, zip_strict
31+
from narwhals._utils import (
32+
_StoresCompliant,
33+
not_implemented,
34+
qualified_type_name,
35+
zip_strict,
36+
)
3237
from narwhals.dependencies import is_numpy_array, is_numpy_scalar
3338

3439
if TYPE_CHECKING:
@@ -168,8 +173,14 @@ class DepthTrackingExpr(
168173
_depth: int
169174
_function_name: str
170175

176+
# NOTE: pyright bug?
177+
# Method "from_column_names" overrides class "CompliantExpr" in an incompatible manner
178+
# Parameter 2 type mismatch: base parameter is type "EvalNames[CompliantFrameT@DepthTrackingExpr]", override parameter is type "EvalNames[CompliantFrameT@DepthTrackingExpr]"
179+
# Type "EvalNames[CompliantFrameT@DepthTrackingExpr]" is not assignable to type "EvalNames[CompliantFrameT@DepthTrackingExpr]"
180+
# Parameter 1: type "CompliantFrameT@DepthTrackingExpr" is incompatible with type "CompliantFrameT@DepthTrackingExpr"
181+
# Type "CompliantFrameT@DepthTrackingExpr" is not assignable to type "CompliantFrameT@DepthTrackingExpr"
171182
@classmethod
172-
def from_column_names(
183+
def from_column_names( # pyright: ignore[reportIncompatibleMethodOverride]
173184
cls: type[Self],
174185
evaluate_column_names: EvalNames[CompliantFrameT],
175186
/,
@@ -899,6 +910,12 @@ def fn(names: Sequence[str]) -> Sequence[str]:
899910
def name(self) -> LazyExprNameNamespace[Self]:
900911
return LazyExprNameNamespace(self)
901912

913+
ewm_mean = not_implemented() # type: ignore[misc]
914+
map_batches = not_implemented() # type: ignore[misc]
915+
replace_strict = not_implemented() # type: ignore[misc]
916+
917+
cat: not_implemented = not_implemented() # type: ignore[assignment]
918+
902919

903920
class _ExprNamespace( # type: ignore[misc]
904921
_StoresCompliant[CompliantExprT_co], Protocol[CompliantExprT_co]

narwhals/_dask/expr.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,14 @@ def min(self) -> Self:
298298
def max(self) -> Self:
299299
return self._with_callable(lambda expr: expr.max().to_series(), "max")
300300

301-
def std(self, ddof: int) -> Self:
301+
def std(self, *, ddof: int) -> Self:
302302
return self._with_callable(
303303
lambda expr: expr.std(ddof=ddof).to_series(),
304304
"std",
305305
scalar_kwargs={"ddof": ddof},
306306
)
307307

308-
def var(self, ddof: int) -> Self:
308+
def var(self, *, ddof: int) -> Self:
309309
return self._with_callable(
310310
lambda expr: expr.var(ddof=ddof).to_series(),
311311
"var",
@@ -682,20 +682,8 @@ def str(self) -> DaskExprStringNamespace:
682682
def dt(self) -> DaskExprDateTimeNamespace:
683683
return DaskExprDateTimeNamespace(self)
684684

685-
arg_max: not_implemented = not_implemented()
686-
arg_min: not_implemented = not_implemented()
687-
arg_true: not_implemented = not_implemented()
688-
ewm_mean: not_implemented = not_implemented()
689-
gather_every: not_implemented = not_implemented()
690-
head: not_implemented = not_implemented()
691-
map_batches: not_implemented = not_implemented()
692-
sample: not_implemented = not_implemented()
693-
rank: not_implemented = not_implemented()
694-
replace_strict: not_implemented = not_implemented()
695-
sort: not_implemented = not_implemented()
696-
tail: not_implemented = not_implemented()
685+
rank = not_implemented()
697686

698687
# namespaces
699688
list: not_implemented = not_implemented() # type: ignore[assignment]
700-
cat: not_implemented = not_implemented() # type: ignore[assignment]
701689
struct: not_implemented = not_implemented() # type: ignore[assignment]

narwhals/_duckdb/expr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def func(expr: Expression) -> Expression:
188188
def len(self) -> Self:
189189
return self._with_callable(lambda _expr: F("count"))
190190

191-
def std(self, ddof: int) -> Self:
191+
def std(self, *, ddof: int) -> Self:
192192
if ddof == 0:
193193
return self._with_callable(lambda expr: F("stddev_pop", expr))
194194
if ddof == 1:
@@ -204,7 +204,7 @@ def _std(expr: Expression) -> Expression:
204204

205205
return self._with_callable(_std)
206206

207-
def var(self, ddof: int) -> Self:
207+
def var(self, *, ddof: int) -> Self:
208208
if ddof == 0:
209209
return self._with_callable(lambda expr: F("var_pop", expr))
210210
if ddof == 1:

narwhals/_ibis/dataframe.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -336,29 +336,17 @@ def unique(
336336
return self._with_native(self.native.distinct(on=subset))
337337

338338
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
339-
if isinstance(descending, bool):
340-
descending = [descending for _ in range(len(by))]
339+
from narwhals._ibis.expr import IbisExpr
341340

342-
sort_cols: list[Any] = []
343-
344-
for i in range(len(by)):
345-
direction_fn = ibis.desc if descending[i] else ibis.asc
346-
col = direction_fn(by[i], nulls_first=not nulls_last)
347-
sort_cols.append(col)
348-
349-
return self._with_native(self.native.order_by(*sort_cols))
341+
cols = IbisExpr._sort(*by, descending=descending, nulls_last=nulls_last)
342+
return self._with_native(self.native.order_by(*cols))
350343

351344
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
352-
if isinstance(reverse, bool):
353-
reverse = [reverse] * len(list(by))
354-
sort_cols = []
355-
356-
for is_reverse, by_col in zip_strict(reverse, by):
357-
direction_fn = ibis.asc if is_reverse else ibis.desc
358-
col = direction_fn(by_col, nulls_first=False)
359-
sort_cols.append(cast("ir.Column", col))
345+
from narwhals._ibis.expr import IbisExpr
360346

361-
return self._with_native(self.native.order_by(*sort_cols).head(k))
347+
desc = not reverse if isinstance(reverse, bool) else [not el for el in reverse]
348+
cols = IbisExpr._sort(*by, descending=desc, nulls_last=True)
349+
return self._with_native(self.native.order_by(*cols).head(k))
362350

363351
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
364352
subset_ = subset if subset is not None else self.columns

narwhals/_ibis/expr.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import operator
4-
from functools import partial
54
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast
65

76
import ibis
@@ -10,7 +9,17 @@
109
from narwhals._ibis.expr_list import IbisExprListNamespace
1110
from narwhals._ibis.expr_str import IbisExprStringNamespace
1211
from narwhals._ibis.expr_struct import IbisExprStructNamespace
13-
from narwhals._ibis.utils import is_floating, lit, narwhals_to_native_dtype
12+
from narwhals._ibis.utils import (
13+
IntoColumn,
14+
asc_nulls_first,
15+
asc_nulls_last,
16+
desc_nulls_first,
17+
desc_nulls_last,
18+
extend_bool,
19+
is_floating,
20+
lit,
21+
narwhals_to_native_dtype,
22+
)
1423
from narwhals._sql.expr import SQLExpr
1524
from narwhals._utils import Implementation, Version, not_implemented, zip_strict
1625

@@ -79,7 +88,7 @@ def _window_expression(
7988
self,
8089
expr: ir.Value,
8190
partition_by: Sequence[str | ir.Value] = (),
82-
order_by: Sequence[str | ir.Column] = (),
91+
order_by: Sequence[IntoColumn] = (),
8392
rows_start: int | None = None,
8493
rows_end: int | None = None,
8594
*,
@@ -94,9 +103,11 @@ def _window_expression(
94103
rows_between = {"preceding": -rows_start}
95104
else:
96105
rows_between = {}
106+
desc = descending or False
107+
last = nulls_last or False
97108
window = ibis.window(
98109
group_by=partition_by,
99-
order_by=self._sort(*order_by, descending=descending, nulls_last=nulls_last),
110+
order_by=self._sort(*order_by, descending=desc, nulls_last=last),
100111
**rows_between,
101112
)
102113
return expr.over(window)
@@ -110,24 +121,23 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se
110121
# Ibis does its own broadcasting.
111122
return self
112123

124+
@staticmethod
113125
def _sort(
114-
self,
115-
*cols: ir.Column | str,
116-
descending: Sequence[bool] | None = None,
117-
nulls_last: Sequence[bool] | None = None,
126+
*cols: IntoColumn,
127+
descending: Sequence[bool] | bool = False,
128+
nulls_last: Sequence[bool] | bool = False,
118129
) -> Iterator[ir.Column]:
119-
descending = descending or [False] * len(cols)
120-
nulls_last = nulls_last or [False] * len(cols)
130+
n = len(cols)
131+
descending = extend_bool(descending, n)
132+
nulls_last = extend_bool(nulls_last, n)
121133
mapping = {
122-
(False, False): partial(ibis.asc, nulls_first=True),
123-
(False, True): partial(ibis.asc, nulls_first=False),
124-
(True, False): partial(ibis.desc, nulls_first=True),
125-
(True, True): partial(ibis.desc, nulls_first=False),
134+
(False, False): asc_nulls_first,
135+
(False, True): asc_nulls_last,
136+
(True, False): desc_nulls_first,
137+
(True, True): desc_nulls_last,
126138
}
127-
yield from (
128-
cast("ir.Column", mapping[(_desc, _nulls_last)](col))
129-
for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last)
130-
)
139+
for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last):
140+
yield mapping[(_desc, _nulls_last)](col)
131141

132142
@classmethod
133143
def from_column_names(
@@ -217,7 +227,7 @@ def func(df: IbisLazyFrame) -> Sequence[ir.IntegerScalar]:
217227
version=self._version,
218228
)
219229

220-
def std(self, ddof: int) -> Self:
230+
def std(self, *, ddof: int) -> Self:
221231
def _std(expr: ir.NumericColumn, ddof: int) -> ir.Value:
222232
if ddof == 0:
223233
return expr.std(how="pop")
@@ -230,7 +240,7 @@ def _std(expr: ir.NumericColumn, ddof: int) -> ir.Value:
230240

231241
return self._with_callable(lambda expr: _std(expr, ddof))
232242

233-
def var(self, ddof: int) -> Self:
243+
def var(self, *, ddof: int) -> Self:
234244
def _var(expr: ir.NumericColumn, ddof: int) -> ir.Value:
235245
if ddof == 0:
236246
return expr.var(how="pop")
@@ -289,7 +299,7 @@ def is_unique(self) -> Self:
289299

290300
def rank(self, method: RankMethod, *, descending: bool) -> Self:
291301
def _rank(expr: ir.Column) -> ir.Value:
292-
order_by = next(self._sort(expr, descending=[descending], nulls_last=[True]))
302+
order_by = next(self._sort(expr, descending=descending, nulls_last=True))
293303
window = ibis.window(order_by=order_by)
294304

295305
if method == "dense":

narwhals/_ibis/expr_dt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def truncate(self, every: str) -> IbisExpr:
5858
fn = self._truncate(UNITS_DICT_TRUNCATE[unit])
5959
return self.compliant._with_callable(fn)
6060

61-
def offset_by(self, every: str) -> IbisExpr:
62-
interval = Interval.parse_no_constraints(every)
61+
def offset_by(self, by: str) -> IbisExpr:
62+
interval = Interval.parse_no_constraints(by)
6363
unit = interval.unit
6464
if unit in {"y", "q", "mo", "d", "ns"}:
6565
msg = f"Offsetting by {unit} is not yet supported for ibis."

narwhals/_ibis/utils.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from functools import lru_cache
3+
from functools import lru_cache, partial
44
from typing import TYPE_CHECKING, Any, Literal, cast, overload
55

66
import ibis
@@ -9,7 +9,7 @@
99
from narwhals._utils import Version, isinstance_or_issubclass
1010

1111
if TYPE_CHECKING:
12-
from collections.abc import Mapping
12+
from collections.abc import Callable, Iterable, Mapping, Sequence
1313
from datetime import timedelta
1414

1515
import ibis.expr.types as ir
@@ -23,6 +23,8 @@
2323
from narwhals.dtypes import DType
2424
from narwhals.typing import IntoDType, PythonLiteral
2525

26+
IntoColumn: TypeAlias = "str | ir.Value | ir.Column"
27+
SortFn: TypeAlias = "Callable[[IntoColumn], ir.Column]"
2628
Incomplete: TypeAlias = Any
2729
"""Marker for upstream issues."""
2830

@@ -45,6 +47,23 @@ def lit(value: Any, dtype: Any | None = None) -> Incomplete:
4547
return literal(value, dtype)
4648

4749

50+
asc_nulls_first = cast("SortFn", partial(ibis.asc, nulls_first=True))
51+
asc_nulls_last = cast("SortFn", partial(ibis.asc, nulls_first=False))
52+
desc_nulls_first = cast("SortFn", partial(ibis.desc, nulls_first=True))
53+
desc_nulls_last = cast("SortFn", partial(ibis.desc, nulls_first=False))
54+
55+
56+
def extend_bool(
57+
value: bool | Iterable[bool], # noqa: FBT001
58+
n_match: int,
59+
) -> Sequence[bool]:
60+
"""Ensure the given bool or sequence of bools is the correct length.
61+
62+
Stolen from https://github.com/pola-rs/polars/blob/b8bfb07a4a37a8d449d6d1841e345817431142df/py-polars/polars/_utils/various.py#L580-L594
63+
"""
64+
return [value] * n_match if isinstance(value, bool) else list(value)
65+
66+
4867
BucketUnit: TypeAlias = Literal[
4968
"years",
5069
"quarters",

narwhals/_polars/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def schema(self) -> dict[str, DType]:
181181

182182
def join(
183183
self,
184-
other: Self,
184+
other: PolarsBaseFrame[NativePolarsFrame],
185185
*,
186186
how: JoinStrategy,
187187
left_on: Sequence[str] | None,
@@ -568,7 +568,7 @@ def to_polars(self) -> pl.DataFrame:
568568

569569
def join(
570570
self,
571-
other: Self,
571+
other: PolarsBaseFrame[pl.DataFrame],
572572
*,
573573
how: JoinStrategy,
574574
left_on: Sequence[str] | None,

narwhals/_spark_like/dataframe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def unique(
385385
def join(
386386
self,
387387
other: Self,
388+
*,
388389
how: JoinStrategy,
389390
left_on: Sequence[str] | None,
390391
right_on: Sequence[str] | None,

0 commit comments

Comments
 (0)