Skip to content

Commit b987e1a

Browse files
authored
chore(typing): Generic CompliantDataFrame (#2115)
* feat(typing): Make `CompliantDataFrame` generic
  - Originally in #2064
  - Will eventually support #2104 (comment)
* chore(typing): Update eager backends
* fix: Make `PolarsSeries` compliant
  - Now for `PolarsDataFrame` to be compliant, `PolarsSeries.alias` needs to be present
  - Since `PolarsDataFrame` can be returned in all of these places, they all required the update to `_polars`
* fix(typing): Resolve new `mypy` errors
  Originally c11dc95
1 parent dc9fcaa commit b987e1a

File tree

10 files changed

+77 -62 lines changed

narwhals/_arrow/dataframe.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@
7171
from narwhals.typing import CompliantLazyFrame
7272

7373

74-
class ArrowDataFrame(CompliantDataFrame, CompliantLazyFrame):
74+
class ArrowDataFrame(CompliantDataFrame["ArrowSeries"], CompliantLazyFrame):
7575
# --- not in the spec ---
7676
def __init__(
7777
self: Self,
@@ -354,24 +354,24 @@ def simple_select(self, *column_names: str) -> Self:
354354
self._native_frame.select(list(column_names)), validate_column_names=False
355355
)
356356

357-
def aggregate(self: Self, *exprs: ArrowExpr) -> Self:
357+
def aggregate(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
358358
return self.select(*exprs)
359359

360-
def select(self: Self, *exprs: ArrowExpr) -> Self:
361-
new_series: Sequence[ArrowSeries] = evaluate_into_exprs(self, *exprs)
360+
def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
361+
new_series = evaluate_into_exprs(self, *exprs)
362362
if not new_series:
363363
# return empty dataframe, like Polars does
364364
return self._from_native_frame(
365365
self._native_frame.__class__.from_arrays([]), validate_column_names=False
366366
)
367367
names = [s.name for s in new_series]
368-
new_series = align_series_full_broadcast(*new_series)
369-
df = pa.Table.from_arrays([s._native_series for s in new_series], names=names)
368+
reshaped = align_series_full_broadcast(*new_series)
369+
df = pa.Table.from_arrays([s._native_series for s in reshaped], names=names)
370370
return self._from_native_frame(df, validate_column_names=True)
371371

372-
def with_columns(self: Self, *exprs: ArrowExpr) -> Self:
372+
def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
373373
native_frame = self._native_frame
374-
new_columns: list[ArrowSeries] = evaluate_into_exprs(self, *exprs)
374+
new_columns = evaluate_into_exprs(self, *exprs)
375375

376376
length = len(self)
377377
columns = self.columns
@@ -469,7 +469,7 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
469469
self._native_frame.drop(to_drop), validate_column_names=False
470470
)
471471

472-
def drop_nulls(self: Self, subset: list[str] | None) -> Self:
472+
def drop_nulls(self: ArrowDataFrame, subset: list[str] | None) -> ArrowDataFrame:
473473
if subset is None:
474474
return self._from_native_frame(
475475
self._native_frame.drop_null(), validate_column_names=False
@@ -551,7 +551,9 @@ def with_row_index(self: Self, name: str) -> Self:
551551
df.append_column(name, row_indices).select([name, *cols])
552552
)
553553

554-
def filter(self: Self, predicate: ArrowExpr | list[bool | None]) -> Self:
554+
def filter(
555+
self: ArrowDataFrame, predicate: ArrowExpr | list[bool | None]
556+
) -> ArrowDataFrame:
555557
if isinstance(predicate, list):
556558
mask_native: Mask | ArrowChunkedArray = predicate
557559
else:
@@ -627,7 +629,7 @@ def collect(
627629
self: Self,
628630
backend: Implementation | None,
629631
**kwargs: Any,
630-
) -> CompliantDataFrame:
632+
) -> CompliantDataFrame[Any]:
631633
if backend is Implementation.PYARROW or backend is None:
632634
from narwhals._arrow.dataframe import ArrowDataFrame
633635

@@ -743,12 +745,12 @@ def is_unique(self: Self) -> ArrowSeries:
743745
)
744746

745747
def unique(
746-
self: Self,
748+
self: ArrowDataFrame,
747749
subset: list[str] | None,
748750
*,
749751
keep: Literal["any", "first", "last", "none"],
750752
maintain_order: bool | None = None,
751-
) -> Self:
753+
) -> ArrowDataFrame:
752754
# The param `maintain_order` is only here for compatibility with the Polars API
753755
# and has no effect on the output.
754756
import numpy as np # ignore-banned-import

narwhals/_dask/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def collect(
8989
self: Self,
9090
backend: Implementation | None,
9191
**kwargs: Any,
92-
) -> CompliantDataFrame:
92+
) -> CompliantDataFrame[Any]:
9393
import pandas as pd
9494

9595
result = self._native_frame.compute(**kwargs)

narwhals/_duckdb/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def collect(
8989
self: Self,
9090
backend: ModuleType | Implementation | str | None,
9191
**kwargs: Any,
92-
) -> CompliantDataFrame:
92+
) -> CompliantDataFrame[Any]:
9393
if backend is None or backend is Implementation.PYARROW:
9494
import pyarrow as pa # ignore-banned-import
9595

narwhals/_expression_parsing.py

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030
from narwhals.expr import Expr
3131
from narwhals.typing import CompliantDataFrame
3232
from narwhals.typing import CompliantExpr
33-
from narwhals.typing import CompliantFrameT_contra
33+
from narwhals.typing import CompliantFrameT
3434
from narwhals.typing import CompliantLazyFrame
3535
from narwhals.typing import CompliantNamespace
3636
from narwhals.typing import CompliantSeries
@@ -52,8 +52,8 @@ def is_expr(obj: Any) -> TypeIs[Expr]:
5252

5353

5454
def evaluate_into_expr(
55-
df: CompliantFrameT_contra,
56-
expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
55+
df: CompliantFrameT,
56+
expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
5757
) -> Sequence[CompliantSeriesT_co]:
5858
"""Return list of raw columns.
5959
@@ -73,9 +73,9 @@ def evaluate_into_expr(
7373

7474

7575
def evaluate_into_exprs(
76-
df: CompliantFrameT_contra,
76+
df: CompliantFrameT,
7777
/,
78-
*exprs: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
78+
*exprs: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
7979
) -> list[CompliantSeriesT_co]:
8080
"""Evaluate each expr into Series."""
8181
return [
@@ -87,13 +87,13 @@ def evaluate_into_exprs(
8787

8888
@overload
8989
def maybe_evaluate_expr(
90-
df: CompliantFrameT_contra,
91-
expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
90+
df: CompliantFrameT,
91+
expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
9292
) -> CompliantSeriesT_co: ...
9393

9494

9595
@overload
96-
def maybe_evaluate_expr(df: CompliantDataFrame, expr: T) -> T: ...
96+
def maybe_evaluate_expr(df: CompliantDataFrame[Any], expr: T) -> T: ...
9797

9898

9999
def maybe_evaluate_expr(
@@ -155,7 +155,7 @@ def reuse_series_implementation(
155155
"""
156156
plx = expr.__narwhals_namespace__()
157157

158-
def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]:
158+
def func(df: CompliantDataFrame[Any]) -> Sequence[CompliantSeries]:
159159
_kwargs = {
160160
**(call_kwargs or {}),
161161
**{
@@ -258,15 +258,15 @@ def is_simple_aggregation(expr: CompliantExpr[Any, Any]) -> bool:
258258

259259

260260
def combine_evaluate_output_names(
261-
*exprs: CompliantExpr[CompliantFrameT_contra, Any],
262-
) -> Callable[[CompliantFrameT_contra], Sequence[str]]:
261+
*exprs: CompliantExpr[CompliantFrameT, Any],
262+
) -> Callable[[CompliantFrameT], Sequence[str]]:
263263
# Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the
264264
# first name of `expr1`.
265265
if not is_compliant_expr(exprs[0]): # pragma: no cover
266266
msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug."
267267
raise AssertionError(msg)
268268

269-
def evaluate_output_names(df: CompliantFrameT_contra) -> Sequence[str]:
269+
def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
270270
return exprs[0]._evaluate_output_names(df)[:1]
271271

272272
return evaluate_output_names
@@ -287,11 +287,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:
287287

288288

289289
def extract_compliant(
290-
plx: CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co],
290+
plx: CompliantNamespace[CompliantFrameT, CompliantSeriesT_co],
291291
other: Any,
292292
*,
293293
str_as_lit: bool,
294-
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | object:
294+
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | object:
295295
if is_expr(other):
296296
return other._to_compliant_expr(plx)
297297
if isinstance(other, str) and not str_as_lit:
@@ -306,7 +306,7 @@ def extract_compliant(
306306

307307
def evaluate_output_names_and_aliases(
308308
expr: CompliantExpr[Any, Any],
309-
df: CompliantDataFrame | CompliantLazyFrame,
309+
df: CompliantDataFrame[Any] | CompliantLazyFrame,
310310
exclude: Sequence[str],
311311
) -> tuple[Sequence[str], Sequence[str]]:
312312
output_names = expr._evaluate_output_names(df)

narwhals/_pandas_like/dataframe.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@
8383
)
8484

8585

86-
class PandasLikeDataFrame(CompliantDataFrame, CompliantLazyFrame):
86+
class PandasLikeDataFrame(CompliantDataFrame["PandasLikeSeries"], CompliantLazyFrame):
8787
# --- not in the spec ---
8888
def __init__(
8989
self: Self,
@@ -396,11 +396,13 @@ def simple_select(self: Self, *column_names: str) -> Self:
396396
validate_column_names=False,
397397
)
398398

399-
def aggregate(self: Self, *exprs: PandasLikeExpr) -> Self:
399+
def aggregate(
400+
self: PandasLikeDataFrame, *exprs: PandasLikeExpr
401+
) -> PandasLikeDataFrame:
400402
return self.select(*exprs)
401403

402-
def select(self: Self, *exprs: PandasLikeExpr) -> Self:
403-
new_series: list[PandasLikeSeries] = evaluate_into_exprs(self, *exprs)
404+
def select(self: PandasLikeDataFrame, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:
405+
new_series = evaluate_into_exprs(self, *exprs)
404406
if not new_series:
405407
# return empty dataframe, like Polars does
406408
return self._from_native_frame(
@@ -414,7 +416,9 @@ def select(self: Self, *exprs: PandasLikeExpr) -> Self:
414416
)
415417
return self._from_native_frame(df, validate_column_names=True)
416418

417-
def drop_nulls(self: Self, subset: list[str] | None) -> Self:
419+
def drop_nulls(
420+
self: PandasLikeDataFrame, subset: list[str] | None
421+
) -> PandasLikeDataFrame:
418422
if subset is None:
419423
return self._from_native_frame(
420424
self._native_frame.dropna(axis=0), validate_column_names=False
@@ -445,7 +449,9 @@ def with_row_index(self: Self, name: str) -> Self:
445449
def row(self: Self, row: int) -> tuple[Any, ...]:
446450
return tuple(x for x in self._native_frame.iloc[row])
447451

448-
def filter(self: Self, predicate: PandasLikeExpr | list[bool]) -> Self:
452+
def filter(
453+
self: PandasLikeDataFrame, predicate: PandasLikeExpr | list[bool]
454+
) -> PandasLikeDataFrame:
449455
if isinstance(predicate, list):
450456
mask_native: pd.Series[Any] | list[bool] = predicate
451457
else:
@@ -457,9 +463,11 @@ def filter(self: Self, predicate: PandasLikeExpr | list[bool]) -> Self:
457463
self._native_frame.loc[mask_native], validate_column_names=False
458464
)
459465

460-
def with_columns(self: Self, *exprs: PandasLikeExpr) -> Self:
466+
def with_columns(
467+
self: PandasLikeDataFrame, *exprs: PandasLikeExpr
468+
) -> PandasLikeDataFrame:
461469
index = self._native_frame.index
462-
new_columns: list[PandasLikeSeries] = evaluate_into_exprs(self, *exprs)
470+
new_columns = evaluate_into_exprs(self, *exprs)
463471
if not new_columns and len(self) == 0:
464472
return self
465473

@@ -528,7 +536,7 @@ def collect(
528536
self: Self,
529537
backend: Implementation | None,
530538
**kwargs: Any,
531-
) -> CompliantDataFrame:
539+
) -> CompliantDataFrame[Any]:
532540
if backend is None:
533541
return PandasLikeDataFrame(
534542
self._native_frame,

narwhals/_polars/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -474,7 +474,7 @@ def collect(
474474
self: Self,
475475
backend: Implementation | None,
476476
**kwargs: Any,
477-
) -> CompliantDataFrame:
477+
) -> CompliantDataFrame[Any]:
478478
try:
479479
result = self._native_frame.collect(**kwargs)
480480
except Exception as e: # noqa: BLE001

narwhals/_polars/series.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,9 @@ def dtype(self: Self) -> DType:
115115
self._native_series.dtype, self._version, self._backend_version
116116
)
117117

118+
def alias(self, name: str) -> Self:
119+
return self._from_native_object(self._native_series.alias(name))
120+
118121
@overload
119122
def __getitem__(self: Self, item: int) -> Any: ...
120123

narwhals/_spark_like/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,7 @@ def collect(
203203
self: Self,
204204
backend: ModuleType | Implementation | str | None,
205205
**kwargs: Any,
206-
) -> CompliantDataFrame:
206+
) -> CompliantDataFrame[Any]:
207207
if backend is Implementation.PANDAS:
208208
import pandas as pd # ignore-banned-import
209209

narwhals/typing.py

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
from typing import TYPE_CHECKING
44
from typing import Any
55
from typing import Callable
6-
from typing import Generic
76
from typing import Literal
87
from typing import Protocol
98
from typing import Sequence
@@ -52,7 +51,12 @@ def __narwhals_series__(self) -> CompliantSeries: ...
5251
def alias(self, name: str) -> Self: ...
5352

5453

55-
class CompliantDataFrame(Protocol):
54+
CompliantSeriesT_co = TypeVar(
55+
"CompliantSeriesT_co", bound=CompliantSeries, covariant=True
56+
)
57+
58+
59+
class CompliantDataFrame(Protocol[CompliantSeriesT_co]):
5660
def __narwhals_dataframe__(self) -> Self: ...
5761
def __narwhals_namespace__(self) -> Any: ...
5862
def simple_select(
@@ -64,6 +68,7 @@ def aggregate(self, *exprs: Any) -> Self:
6468

6569
@property
6670
def columns(self) -> Sequence[str]: ...
71+
def get_column(self, name: str) -> CompliantSeriesT_co: ...
6772

6873

6974
class CompliantLazyFrame(Protocol):
@@ -80,30 +85,25 @@ def aggregate(self, *exprs: Any) -> Self:
8085
def columns(self) -> Sequence[str]: ...
8186

8287

83-
CompliantFrameT_contra = TypeVar(
84-
"CompliantFrameT_contra",
85-
bound="CompliantDataFrame | CompliantLazyFrame",
86-
contravariant=True,
87-
)
88-
CompliantSeriesT_co = TypeVar(
89-
"CompliantSeriesT_co", bound=CompliantSeries, covariant=True
88+
CompliantFrameT = TypeVar(
89+
"CompliantFrameT", bound="CompliantDataFrame[Any] | CompliantLazyFrame"
9090
)
9191

9292

93-
class CompliantExpr(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]):
93+
class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesT_co]):
9494
_implementation: Implementation
9595
_backend_version: tuple[int, ...]
9696
_version: Version
97-
_evaluate_output_names: Callable[[CompliantFrameT_contra], Sequence[str]]
97+
_evaluate_output_names: Callable[[CompliantFrameT], Sequence[str]]
9898
_alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None
9999
_depth: int
100100
_function_name: str
101101

102-
def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ...
102+
def __call__(self, df: CompliantFrameT) -> Sequence[CompliantSeriesT_co]: ...
103103
def __narwhals_expr__(self) -> None: ...
104104
def __narwhals_namespace__(
105105
self,
106-
) -> CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co]: ...
106+
) -> CompliantNamespace[CompliantFrameT, CompliantSeriesT_co]: ...
107107
def is_null(self) -> Self: ...
108108
def alias(self, name: str) -> Self: ...
109109
def cast(self, dtype: DType) -> Self: ...
@@ -125,21 +125,21 @@ def broadcast(
125125
) -> Self: ...
126126

127127

128-
class CompliantNamespace(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]):
128+
class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesT_co]):
129129
def col(
130130
self, *column_names: str
131-
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ...
131+
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ...
132132
def lit(
133133
self, value: Any, dtype: DType | None
134-
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ...
134+
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ...
135135

136136

137137
class SupportsNativeNamespace(Protocol):
138138
def __native_namespace__(self) -> ModuleType: ...
139139

140140

141141
IntoCompliantExpr: TypeAlias = (
142-
"CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | CompliantSeriesT_co"
142+
"CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | CompliantSeriesT_co"
143143
)
144144

145145
IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"]

0 commit comments

Comments
 (0)