Skip to content

Commit c9b57a2

Browse files
committed
feat(typing): CompliantDataFrame knows about CompliantExpr
- Resolves #2223 (comment)
- Will help with `EagerDataFrame` as well
1 parent 4dc0514 commit c9b57a2

File tree

11 files changed: +41 -34 lines changed

narwhals/_arrow/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@
7373
from narwhals.typing import CompliantLazyFrame
7474

7575

76-
class ArrowDataFrame(EagerDataFrame["ArrowSeries"], CompliantLazyFrame):
76+
class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr"], CompliantLazyFrame):
7777
# --- not in the spec ---
7878
def __init__(
7979
self: Self,
@@ -616,7 +616,7 @@ def collect(
616616
self: Self,
617617
backend: Implementation | None,
618618
**kwargs: Any,
619-
) -> CompliantDataFrame[Any]:
619+
) -> CompliantDataFrame[Any, Any]:
620620
if backend is Implementation.PYARROW or backend is None:
621621
from narwhals._arrow.dataframe import ArrowDataFrame
622622

narwhals/_compliant/dataframe.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from typing import TypeVar
1313
from typing import overload
1414

15+
from narwhals._compliant.typing import CompliantExprT_contra
1516
from narwhals._compliant.typing import CompliantSeriesT
17+
from narwhals._compliant.typing import EagerExprT_contra
1618
from narwhals._compliant.typing import EagerSeriesT
1719
from narwhals._expression_parsing import evaluate_output_names_and_aliases
1820

@@ -37,7 +39,7 @@
3739
T = TypeVar("T")
3840

3941

40-
class CompliantDataFrame(Sized, Protocol[CompliantSeriesT]):
42+
class CompliantDataFrame(Sized, Protocol[CompliantSeriesT, CompliantExprT_contra]):
4143
def __narwhals_dataframe__(self) -> Self: ...
4244
def __narwhals_namespace__(self) -> Any: ...
4345
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
@@ -46,7 +48,7 @@ def simple_select(self, *column_names: str) -> Self:
4648
"""`select` where all args are column names."""
4749
...
4850

49-
def aggregate(self, *exprs: Any) -> Self: # pragma: no cover
51+
def aggregate(self, *exprs: CompliantExprT_contra) -> Self: # pragma: no cover
5052
"""`select` where all args are aggregations or literals.
5153
5254
(so, no broadcasting is necessary).
@@ -62,12 +64,12 @@ def shape(self) -> tuple[int, int]: ...
6264
def clone(self) -> Self: ...
6365
def collect(
6466
self, backend: Implementation | None, **kwargs: Any
65-
) -> CompliantDataFrame[Any]: ...
67+
) -> CompliantDataFrame[Any, Any]: ...
6668
def collect_schema(self) -> Mapping[str, DType]: ...
6769
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
6870
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
6971
def estimated_size(self, unit: SizeUnit) -> int | float: ...
70-
def filter(self, predicate: Any) -> Self: ...
72+
def filter(self, predicate: CompliantExprT_contra | Any) -> Self: ...
7173
def gather_every(self, n: int, offset: int) -> Self: ...
7274
def get_column(self, name: str) -> CompliantSeriesT: ...
7375
def group_by(self, *keys: str, drop_null_keys: bool) -> Any: ...
@@ -112,7 +114,7 @@ def sample(
112114
with_replacement: bool,
113115
seed: int | None,
114116
) -> Self: ...
115-
def select(self, *exprs: Any) -> Self: ...
117+
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
116118
def sort(
117119
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
118120
) -> Self: ...
@@ -142,7 +144,7 @@ def unpivot(
142144
variable_name: str,
143145
value_name: str,
144146
) -> Self: ...
145-
def with_columns(self, *exprs: Any) -> Self: ...
147+
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
146148
def with_row_index(self, name: str) -> Self: ...
147149
@overload
148150
def write_csv(self, file: None) -> str: ...
@@ -169,10 +171,11 @@ def schema(self) -> Mapping[str, DType]: ...
169171
def _iter_columns(self) -> Iterator[Any]: ...
170172

171173

172-
class EagerDataFrame(CompliantDataFrame[EagerSeriesT], Protocol[EagerSeriesT]):
173-
def _maybe_evaluate_expr(
174-
self, expr: EagerExpr[Self, EagerSeriesT] | T, /
175-
) -> EagerSeriesT | T:
174+
class EagerDataFrame(
175+
CompliantDataFrame[EagerSeriesT, EagerExprT_contra],
176+
Protocol[EagerSeriesT, EagerExprT_contra],
177+
):
178+
def _maybe_evaluate_expr(self, expr: EagerExprT_contra | T, /) -> EagerSeriesT | T:
176179
if is_eager_expr(expr):
177180
result: Sequence[EagerSeriesT] = expr(self)
178181
if len(result) > 1:
@@ -184,14 +187,10 @@ def _maybe_evaluate_expr(
184187
return result[0]
185188
return expr
186189

187-
def _evaluate_into_exprs(
188-
self, *exprs: EagerExpr[Self, EagerSeriesT]
189-
) -> Sequence[EagerSeriesT]:
190+
def _evaluate_into_exprs(self, *exprs: EagerExprT_contra) -> Sequence[EagerSeriesT]:
190191
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))
191192

192-
def _evaluate_into_expr(
193-
self, expr: EagerExpr[Self, EagerSeriesT], /
194-
) -> Sequence[EagerSeriesT]:
193+
def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSeriesT]:
195194
"""Return list of raw columns.
196195
197196
For eager backends we alias operations at each step.

narwhals/_compliant/selectors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@
6565
SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeries | NativeExpr")
6666
SeriesT = TypeVar("SeriesT", bound="CompliantSeries")
6767
ExprT = TypeVar("ExprT", bound="NativeExpr")
68-
FrameT = TypeVar("FrameT", bound="CompliantDataFrame[Any] | CompliantLazyFrame")
69-
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrame[Any]")
68+
FrameT = TypeVar("FrameT", bound="CompliantDataFrame[Any, Any] | CompliantLazyFrame")
69+
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrame[Any, Any]")
7070
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrame")
7171
SelectorOrExpr: TypeAlias = (
7272
"CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
@@ -309,7 +309,7 @@ def __repr__(self: Self) -> str: # pragma: no cover
309309

310310

311311
def _eval_lhs_rhs(
312-
df: CompliantDataFrame[Any] | CompliantLazyFrame,
312+
df: CompliantDataFrame[Any, Any] | CompliantLazyFrame,
313313
lhs: CompliantExpr[Any, Any],
314314
rhs: CompliantExpr[Any, Any],
315315
) -> tuple[Sequence[str], Sequence[str]]:

narwhals/_compliant/typing.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,25 @@
3636
covariant=True,
3737
)
3838
CompliantFrameT = TypeVar(
39-
"CompliantFrameT", bound="CompliantDataFrame[Any] | CompliantLazyFrame"
39+
"CompliantFrameT", bound="CompliantDataFrame[Any, Any] | CompliantLazyFrame"
4040
)
41-
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any]")
41+
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any, Any]")
4242
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame")
4343
IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"
4444
CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]")
45+
CompliantExprT_contra = TypeVar(
46+
"CompliantExprT_contra", bound="CompliantExpr[Any, Any]", contravariant=True
47+
)
4548

46-
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any]")
49+
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any]")
4750
EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]")
4851
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound="EagerSeries[Any]", covariant=True)
4952
EagerExprT = TypeVar("EagerExprT", bound="EagerExpr[Any, Any]")
53+
EagerExprT_contra = TypeVar(
54+
"EagerExprT_contra", bound="EagerExpr[Any, Any]", contravariant=True
55+
)
5056
EagerNamespaceAny: TypeAlias = (
51-
"EagerNamespace[EagerDataFrame[Any], EagerSeries[Any], EagerExpr[Any, Any]]"
57+
"EagerNamespace[EagerDataFrame[Any, Any], EagerSeries[Any], EagerExpr[Any, Any]]"
5258
)
5359

5460
AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]

narwhals/_dask/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def collect(
9393
self: Self,
9494
backend: Implementation | None,
9595
**kwargs: Any,
96-
) -> CompliantDataFrame[Any]:
96+
) -> CompliantDataFrame[Any, Any]:
9797
result = self._native_frame.compute(**kwargs)
9898

9999
if backend is None or backend is Implementation.PANDAS:

narwhals/_duckdb/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def collect(
9292
self: Self,
9393
backend: ModuleType | Implementation | str | None,
9494
**kwargs: Any,
95-
) -> CompliantDataFrame[Any]:
95+
) -> CompliantDataFrame[Any, Any]:
9696
if backend is None or backend is Implementation.PYARROW:
9797
import pyarrow as pa # ignore-banned-import
9898

narwhals/_expression_parsing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def extract_compliant(
109109

110110
def evaluate_output_names_and_aliases(
111111
expr: CompliantExpr[Any, Any],
112-
df: CompliantDataFrame[Any] | CompliantLazyFrame,
112+
df: CompliantDataFrame[Any, Any] | CompliantLazyFrame,
113113
exclude: Sequence[str],
114114
) -> tuple[Sequence[str], Sequence[str]]:
115115
output_names = expr._evaluate_output_names(df)

narwhals/_pandas_like/dataframe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@
8383
)
8484

8585

86-
class PandasLikeDataFrame(EagerDataFrame["PandasLikeSeries"], CompliantLazyFrame):
86+
class PandasLikeDataFrame(
87+
EagerDataFrame["PandasLikeSeries", "PandasLikeExpr"], CompliantLazyFrame
88+
):
8789
# --- not in the spec ---
8890
def __init__(
8991
self: Self,
@@ -529,7 +531,7 @@ def collect(
529531
self: Self,
530532
backend: Implementation | None,
531533
**kwargs: Any,
532-
) -> CompliantDataFrame[Any]:
534+
) -> CompliantDataFrame[Any, Any]:
533535
if backend is None:
534536
return PandasLikeDataFrame(
535537
self._native_frame,

narwhals/_polars/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def collect(
502502
self: Self,
503503
backend: Implementation | None,
504504
**kwargs: Any,
505-
) -> CompliantDataFrame[Any]:
505+
) -> CompliantDataFrame[Any, Any]:
506506
try:
507507
result = self._native_frame.collect(**kwargs)
508508
except Exception as e: # noqa: BLE001

narwhals/_spark_like/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def collect(
172172
self: Self,
173173
backend: ModuleType | Implementation | str | None,
174174
**kwargs: Any,
175-
) -> CompliantDataFrame[Any]:
175+
) -> CompliantDataFrame[Any, Any]:
176176
if backend is Implementation.PANDAS:
177177
import pandas as pd # ignore-banned-import
178178

0 commit comments

Comments
 (0)