Skip to content

Commit b3bdb3b

Browse files
committed
fix(typing): EagerNamespace.all_horizontal
Was not easy getting that to please both `mypy` & `pyright`
1 parent d1dd6ce commit b3bdb3b

File tree

5 files changed

+20
-22
lines changed

5 files changed

+20
-22
lines changed

narwhals/_arrow/namespace.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
_Scalar: TypeAlias = Any
4848

4949

50-
class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries]):
50+
class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr]):
5151
@property
5252
def _expr(self) -> type[ArrowExpr]:
5353
return ArrowExpr
@@ -123,9 +123,7 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
123123
version=self._version,
124124
)
125125

126-
# NOTE: Needs to be resolved in `EagerNamespace`
127-
# Probably, by adding an `EagerExprT` typevar
128-
def all_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr: # type: ignore[override]
126+
def all_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
129127
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
130128
series = chain.from_iterable(expr(df) for expr in exprs)
131129
return [reduce(operator.and_, align_series_full_broadcast(*series))]

narwhals/_compliant/expr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,9 @@ def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]:
285285
def __repr__(self) -> str: # pragma: no cover
286286
return f"{type(self).__name__}(depth={self._depth}, function_name={self._function_name})"
287287

288-
def __narwhals_namespace__(self) -> EagerNamespace[EagerDataFrameT, EagerSeriesT]: ...
288+
def __narwhals_namespace__(
289+
self,
290+
) -> EagerNamespace[EagerDataFrameT, EagerSeriesT, Self]: ...
289291
def __narwhals_expr__(self) -> None: ...
290292

291293
@classmethod

narwhals/_compliant/namespace.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from narwhals._compliant.typing import CompliantFrameT
88
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
99
from narwhals._compliant.typing import EagerDataFrameT
10-
from narwhals._compliant.typing import EagerSeriesT
10+
from narwhals._compliant.typing import EagerExprT
11+
from narwhals._compliant.typing import EagerSeriesT_co
1112
from narwhals.utils import deprecated
1213

1314
if TYPE_CHECKING:
1415
from narwhals._compliant.expr import CompliantExpr
15-
from narwhals._compliant.expr import EagerExpr
1616
from narwhals._compliant.selectors import CompliantSelectorNamespace
1717
from narwhals.dtypes import DType
1818

@@ -31,25 +31,23 @@ def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
3131

3232

3333
class EagerNamespace(
34-
CompliantNamespace[EagerDataFrameT, EagerSeriesT],
35-
Protocol[EagerDataFrameT, EagerSeriesT],
34+
CompliantNamespace[EagerDataFrameT, EagerSeriesT_co],
35+
Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT],
3636
):
3737
# NOTE: Supporting moved ops
3838
# - `self_create_expr_from_callable` -> `self._expr._from_callable`
3939
# - `self_create_expr_from_series` -> `self._expr._from_series`
4040
@property
41-
def _expr(self) -> type[EagerExpr[EagerDataFrameT, EagerSeriesT]]: ...
41+
def _expr(self) -> type[EagerExprT]: ...
4242

4343
# NOTE: Supporting moved ops
4444
# - `self._create_series_from_scalar` -> `EagerSeries()._from_scalar`
4545
# - Was dependent on a `reference_series`, so is now an instance method
4646
# - `<class>._from_iterable` -> `self._series._from_iterable`
4747
@property
48-
def _series(self) -> type[EagerSeriesT]: ...
48+
def _series(self) -> type[EagerSeriesT_co]: ...
4949

50-
def all_horizontal(
51-
self, *exprs: EagerExpr[EagerDataFrameT, EagerSeriesT]
52-
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]: ...
50+
def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ...
5351

5452
@deprecated("ref'd in untyped code")
55-
def _create_compliant_series(self, value: Any) -> EagerSeriesT: ...
53+
def _create_compliant_series(self, value: Any) -> EagerSeriesT_co: ...

narwhals/_compliant/series.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _from_iterable(
5656
cls: type[Self], data: Iterable[Any], name: str, *, context: _FullContext
5757
) -> Self: ...
5858

59-
def __narwhals_namespace__(self) -> EagerNamespace[Any, Self]: ...
59+
def __narwhals_namespace__(self) -> EagerNamespace[Any, Self, Any]: ...
6060

61-
def _to_expr(self) -> EagerExpr[Any, Self]:
62-
return self.__narwhals_namespace__()._expr._from_series(self)
61+
def _to_expr(self) -> EagerExpr[Any, Any]:
62+
return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return]

narwhals/_pandas_like/namespace.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
_Scalar: TypeAlias = Any
4141

4242

43-
class PandasLikeNamespace(EagerNamespace[PandasLikeDataFrame, PandasLikeSeries]):
43+
class PandasLikeNamespace(
44+
EagerNamespace[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr]
45+
):
4446
@property
4547
def _expr(self) -> type[PandasLikeExpr]:
4648
return PandasLikeExpr
@@ -149,9 +151,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
149151
context=self,
150152
)
151153

152-
# NOTE: Needs to be resolved in `EagerNamespace`
153-
# Probably, by adding an `EagerExprT` typevar
154-
def all_horizontal(self: Self, *exprs: PandasLikeExpr) -> PandasLikeExpr: # type: ignore[override]
154+
def all_horizontal(self: Self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
155155
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
156156
series = align_series_full_broadcast(
157157
*(s for _expr in exprs for s in _expr(df))

0 commit comments

Comments
 (0)