Skip to content

Commit 45038b0

Browse files
authored
chore(typing): Resolve _dask errors (#2087)
1 parent a43aca0 commit 45038b0

File tree

7 files changed

+38 -52 lines changed

7 files changed

+38 -52 lines changed

narwhals/_dask/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def __init__(
4343
version: Version,
4444
validate_column_names: bool,
4545
) -> None:
46-
self._native_frame = native_dataframe
46+
self._native_frame: dd.DataFrame = native_dataframe
4747
self._backend_version = backend_version
4848
self._implementation = Implementation.DASK
4949
self._version = version
@@ -138,7 +138,7 @@ def collect(
138138

139139
@property
140140
def columns(self: Self) -> list[str]:
141-
return self._native_frame.columns.tolist() # type: ignore[no-any-return]
141+
return self._native_frame.columns.tolist()
142142

143143
def filter(self: Self, predicate: DaskExpr) -> Self:
144144
# `[0]` is safe as the predicate's expression only returns a single column
@@ -426,7 +426,7 @@ def gather_every(self: Self, n: int, offset: int) -> Self:
426426
return (
427427
self.with_row_index(row_index_token)
428428
.filter(
429-
(plx.col(row_index_token) >= offset) # type: ignore[operator]
429+
(plx.col(row_index_token) >= offset)
430430
& ((plx.col(row_index_token) - offset) % n == 0)
431431
)
432432
.drop([row_index_token], strict=False)

narwhals/_dask/expr.py

Lines changed: 12 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@
3535
from narwhals.utils import Version
3636

3737

38-
class DaskExpr(CompliantExpr["DaskLazyFrame", "dx.Series"]):
38+
class DaskExpr(CompliantExpr["DaskLazyFrame", "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments] (#2044)
3939
_implementation: Implementation = Implementation.DASK
4040

4141
def __init__(
@@ -255,7 +255,7 @@ def __ne__(self: Self, other: DaskExpr) -> Self: # type: ignore[override]
255255
lambda _input, other: _input.__ne__(other), "__ne__", other=other
256256
)
257257

258-
def __ge__(self: Self, other: DaskExpr) -> Self:
258+
def __ge__(self: Self, other: DaskExpr | Any) -> Self:
259259
return self._from_call(
260260
lambda _input, other: _input.__ge__(other), "__ge__", other=other
261261
)
@@ -275,7 +275,7 @@ def __lt__(self: Self, other: DaskExpr) -> Self:
275275
lambda _input, other: _input.__lt__(other), "__lt__", other=other
276276
)
277277

278-
def __and__(self: Self, other: DaskExpr) -> Self:
278+
def __and__(self: Self, other: DaskExpr | Any) -> Self:
279279
return self._from_call(
280280
lambda _input, other: _input.__and__(other), "__and__", other=other
281281
)
@@ -454,7 +454,7 @@ def func(_input: dx.Series) -> dx.Series:
454454
_input.dtype, self._version, self._implementation
455455
)
456456
if dtype.is_numeric():
457-
return _input != _input # noqa: PLR0124
457+
return _input != _input # pyright: ignore[reportReturnType] # noqa: PLR0124
458458
msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
459459
raise InvalidOperationError(msg)
460460

@@ -487,31 +487,23 @@ def is_first_distinct(self: Self) -> Self:
487487
def func(_input: dx.Series) -> dx.Series:
488488
_name = _input.name
489489
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
490-
_input = add_row_index(
491-
_input.to_frame(),
492-
col_token,
493-
backend_version=self._backend_version,
494-
implementation=self._implementation,
490+
frame = add_row_index(
491+
_input.to_frame(), col_token, self._backend_version, self._implementation
495492
)
496-
first_distinct_index = _input.groupby(_name).agg({col_token: "min"})[
497-
col_token
498-
]
499-
return _input[col_token].isin(first_distinct_index)
493+
first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
494+
return frame[col_token].isin(first_distinct_index)
500495

501496
return self._from_call(func, "is_first_distinct")
502497

503498
def is_last_distinct(self: Self) -> Self:
504499
def func(_input: dx.Series) -> dx.Series:
505500
_name = _input.name
506501
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
507-
_input = add_row_index(
508-
_input.to_frame(),
509-
col_token,
510-
backend_version=self._backend_version,
511-
implementation=self._implementation,
502+
frame = add_row_index(
503+
_input.to_frame(), col_token, self._backend_version, self._implementation
512504
)
513-
last_distinct_index = _input.groupby(_name).agg({col_token: "max"})[col_token]
514-
return _input[col_token].isin(last_distinct_index)
505+
last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
506+
return frame[col_token].isin(last_distinct_index)
515507

516508
return self._from_call(func, "is_last_distinct")
517509

narwhals/_dask/expr_dt.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -96,9 +96,9 @@ def func(s: dx.Series, time_zone: str) -> dx.Series:
9696
s.dtype, self._compliant_expr._version, Implementation.DASK
9797
)
9898
if dtype.time_zone is None: # type: ignore[attr-defined]
99-
return s.dt.tz_localize("UTC").dt.tz_convert(time_zone)
99+
return s.dt.tz_localize("UTC").dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue]
100100
else:
101-
return s.dt.tz_convert(time_zone)
101+
return s.dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue]
102102

103103
return self._compliant_expr._from_call(func, "tz_convert", time_zone=time_zone)
104104

@@ -125,7 +125,7 @@ def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
125125
else:
126126
msg = "Input should be either of Date or Datetime type"
127127
raise TypeError(msg)
128-
return result.where(~mask_na)
128+
return result.where(~mask_na) # pyright: ignore[reportReturnType]
129129

130130
return self._compliant_expr._from_call(func, "datetime", time_unit=time_unit)
131131

narwhals/_dask/group_by.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626

2727
from narwhals._dask.dataframe import DaskLazyFrame
2828
from narwhals._dask.expr import DaskExpr
29-
from narwhals.typing import CompliantExpr
3029

3130
PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any]
3231
_AggFn: TypeAlias = Callable[..., Any]
@@ -73,7 +72,7 @@ class DaskLazyGroupBy:
7372
def __init__(
7473
self: Self, df: DaskLazyFrame, keys: list[str], *, drop_null_keys: bool
7574
) -> None:
76-
self._df = df
75+
self._df: DaskLazyFrame = df
7776
self._keys = keys
7877
self._grouped = self._df._native_frame.groupby(
7978
list(self._keys),
@@ -93,11 +92,11 @@ def agg(
9392
self._from_native_frame,
9493
)
9594

96-
def _from_native_frame(self: Self, df: DaskLazyFrame) -> DaskLazyFrame:
95+
def _from_native_frame(self: Self, df: dd.DataFrame) -> DaskLazyFrame:
9796
from narwhals._dask.dataframe import DaskLazyFrame
9897

9998
return DaskLazyFrame(
100-
df, # pyright: ignore[reportArgumentType]
99+
df,
101100
backend_version=self._df._backend_version,
102101
version=self._df._version,
103102
validate_column_names=True,
@@ -107,7 +106,7 @@ def _from_native_frame(self: Self, df: DaskLazyFrame) -> DaskLazyFrame:
107106
def agg_dask(
108107
df: DaskLazyFrame,
109108
grouped: Any,
110-
exprs: Sequence[CompliantExpr[DaskLazyFrame, dx.Series]],
109+
exprs: Sequence[DaskExpr],
111110
keys: list[str],
112111
from_dataframe: Callable[[Any], DaskLazyFrame],
113112
) -> DaskLazyFrame:
@@ -148,7 +147,7 @@ def agg_dask(
148147
agg_function = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
149148
# deal with n_unique case in a "lazy" mode to not depend on dask globally
150149
agg_function = (
151-
agg_function(**expr._call_kwargs) # type: ignore[attr-defined]
150+
agg_function(**expr._call_kwargs)
152151
if callable(agg_function)
153152
else agg_function
154153
)

narwhals/_dask/namespace.py

Lines changed: 11 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
from typing import Iterable
99
from typing import Literal
1010
from typing import Sequence
11-
from typing import cast
1211

1312
import dask.dataframe as dd
1413
import pandas as pd
@@ -24,7 +23,6 @@
2423
from narwhals._expression_parsing import combine_alias_output_names
2524
from narwhals._expression_parsing import combine_evaluate_output_names
2625
from narwhals.typing import CompliantNamespace
27-
from narwhals.utils import is_compliant_expr
2826

2927
if TYPE_CHECKING:
3028
from typing_extensions import Self
@@ -38,7 +36,7 @@
3836
import dask_expr as dx
3937

4038

41-
class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]):
39+
class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments] (#2044)
4240
@property
4341
def selectors(self: Self) -> DaskSelectorNamespace:
4442
return DaskSelectorNamespace(self)
@@ -347,17 +345,16 @@ def __init__(
347345
version: Version,
348346
) -> None:
349347
self._backend_version = backend_version
350-
self._condition = condition
351-
self._then_value = then_value
352-
self._otherwise_value = otherwise_value
348+
self._condition: DaskExpr = condition
349+
self._then_value: DaskExpr | Any = then_value
350+
self._otherwise_value: DaskExpr | Any = otherwise_value
353351
self._version = version
354352

355353
def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]:
356354
condition = self._condition(df)[0]
357-
condition = cast("dx.Series", condition)
358355

359-
if is_compliant_expr(self._then_value):
360-
then_value: dx.Series | object = self._then_value(df)[0]
356+
if isinstance(self._then_value, DaskExpr):
357+
then_value = self._then_value(df)[0]
361358
else:
362359
then_value = self._then_value
363360
(then_series,) = align_series_full_broadcast(df, then_value)
@@ -366,13 +363,13 @@ def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]:
366363
if self._otherwise_value is None:
367364
return [then_series.where(condition)]
368365

369-
if is_compliant_expr(self._otherwise_value):
370-
otherwise_value: dx.Series | object = self._otherwise_value(df)[0]
366+
if isinstance(self._otherwise_value, DaskExpr):
367+
otherwise_value = self._otherwise_value(df)[0]
371368
else:
372369
otherwise_value = self._otherwise_value
373370
(otherwise_series,) = align_series_full_broadcast(df, otherwise_value)
374371
validate_comparand(condition, otherwise_series)
375-
return [then_series.where(condition, otherwise_series)]
372+
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]
376373

377374
def then(self: Self, value: DaskExpr | Any) -> DaskThen:
378375
self._then_value = value
@@ -405,17 +402,14 @@ def __init__(
405402
) -> None:
406403
self._backend_version = backend_version
407404
self._version = version
408-
self._call = call
405+
self._call: DaskWhen = call
409406
self._depth = depth
410407
self._function_name = function_name
411408
self._evaluate_output_names = evaluate_output_names
412409
self._alias_output_names = alias_output_names
413410
self._call_kwargs = call_kwargs or {}
414411

415412
def otherwise(self: Self, value: DaskExpr | Any) -> DaskExpr:
416-
# type ignore because we are setting the `_call` attribute to a
417-
# callable object of type `DaskWhen`, base class has the attribute as
418-
# only a `Callable`
419-
self._call._otherwise_value = value # type: ignore[attr-defined]
413+
self._call._otherwise_value = value
420414
self._function_name = "whenotherwise"
421415
return self

narwhals/_dask/utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ def align_series_full_broadcast(
6262
return [
6363
s if isinstance(s, dx.Series) else df._native_frame.assign(_tmp=s)["_tmp"]
6464
for s in series
65-
]
65+
] # pyright: ignore[reportReturnType]
6666

6767

6868
def add_row_index(
@@ -155,8 +155,8 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> An
155155

156156

157157
def name_preserving_sum(s1: dx.Series, s2: dx.Series) -> dx.Series:
158-
return (s1 + s2).rename(s1.name)
158+
return (s1 + s2).rename(s1.name) # pyright: ignore[reportOperatorIssue]
159159

160160

161161
def name_preserving_div(s1: dx.Series, s2: dx.Series) -> dx.Series:
162-
return (s1 / s2).rename(s1.name)
162+
return (s1 / s2).rename(s1.name) # pyright: ignore[reportOperatorIssue]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -269,6 +269,7 @@ module = [
269269
"*._pandas_like.*",
270270
"*._ibis.*",
271271
"*._arrow.*",
272+
"*._dask.*",
272273
]
273274
warn_return_any = false
274275

0 commit comments

Comments
 (0)