88from typing import Iterable
99from typing import Literal
1010from typing import Sequence
11- from typing import cast
1211
1312import dask .dataframe as dd
1413import pandas as pd
2423from narwhals ._expression_parsing import combine_alias_output_names
2524from narwhals ._expression_parsing import combine_evaluate_output_names
2625from narwhals .typing import CompliantNamespace
27- from narwhals .utils import is_compliant_expr
2826
2927if TYPE_CHECKING :
3028 from typing_extensions import Self
3836 import dask_expr as dx
3937
4038
41- class DaskNamespace (CompliantNamespace [DaskLazyFrame , "dx.Series" ]):
39+ class DaskNamespace (CompliantNamespace [DaskLazyFrame , "dx.Series" ]): # pyright: ignore[reportInvalidTypeArguments] (#2044)
4240 @property
4341 def selectors (self : Self ) -> DaskSelectorNamespace :
4442 return DaskSelectorNamespace (self )
@@ -347,17 +345,16 @@ def __init__(
347345 version : Version ,
348346 ) -> None :
349347 self ._backend_version = backend_version
350- self ._condition = condition
351- self ._then_value = then_value
352- self ._otherwise_value = otherwise_value
348+ self ._condition : DaskExpr = condition
349+ self ._then_value : DaskExpr | Any = then_value
350+ self ._otherwise_value : DaskExpr | Any = otherwise_value
353351 self ._version = version
354352
355353 def __call__ (self : Self , df : DaskLazyFrame ) -> Sequence [dx .Series ]:
356354 condition = self ._condition (df )[0 ]
357- condition = cast ("dx.Series" , condition )
358355
359- if is_compliant_expr (self ._then_value ):
360- then_value : dx . Series | object = self ._then_value (df )[0 ]
356+ if isinstance (self ._then_value , DaskExpr ):
357+ then_value = self ._then_value (df )[0 ]
361358 else :
362359 then_value = self ._then_value
363360 (then_series ,) = align_series_full_broadcast (df , then_value )
@@ -366,13 +363,13 @@ def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]:
366363 if self ._otherwise_value is None :
367364 return [then_series .where (condition )]
368365
369- if is_compliant_expr (self ._otherwise_value ):
370- otherwise_value : dx . Series | object = self ._otherwise_value (df )[0 ]
366+ if isinstance (self ._otherwise_value , DaskExpr ):
367+ otherwise_value = self ._otherwise_value (df )[0 ]
371368 else :
372369 otherwise_value = self ._otherwise_value
373370 (otherwise_series ,) = align_series_full_broadcast (df , otherwise_value )
374371 validate_comparand (condition , otherwise_series )
375- return [then_series .where (condition , otherwise_series )]
372+ return [then_series .where (condition , otherwise_series )] # pyright: ignore[reportArgumentType]
376373
377374 def then (self : Self , value : DaskExpr | Any ) -> DaskThen :
378375 self ._then_value = value
@@ -405,17 +402,14 @@ def __init__(
405402 ) -> None :
406403 self ._backend_version = backend_version
407404 self ._version = version
408- self ._call = call
405+ self ._call : DaskWhen = call
409406 self ._depth = depth
410407 self ._function_name = function_name
411408 self ._evaluate_output_names = evaluate_output_names
412409 self ._alias_output_names = alias_output_names
413410 self ._call_kwargs = call_kwargs or {}
414411
415412 def otherwise (self : Self , value : DaskExpr | Any ) -> DaskExpr :
416- # type ignore because we are setting the `_call` attribute to a
417- # callable object of type `DaskWhen`, base class has the attribute as
418- # only a `Callable`
419- self ._call ._otherwise_value = value # type: ignore[attr-defined]
413+ self ._call ._otherwise_value = value
420414 self ._function_name = "whenotherwise"
421415 return self
0 commit comments