Skip to content

Commit 1f3ada9

Browse files
committed
feat(typing): Make CompliantDataFrame generic
- Originally in #2064 - Will eventually support #2104 (comment)
1 parent 77a0150 commit 1f3ada9

File tree

3 files changed

+38
-36
lines changed

3 files changed

+38
-36
lines changed

narwhals/_expression_parsing.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from narwhals.expr import Expr
3131
from narwhals.typing import CompliantDataFrame
3232
from narwhals.typing import CompliantExpr
33-
from narwhals.typing import CompliantFrameT_contra
33+
from narwhals.typing import CompliantFrameT
3434
from narwhals.typing import CompliantLazyFrame
3535
from narwhals.typing import CompliantNamespace
3636
from narwhals.typing import CompliantSeries
@@ -52,8 +52,8 @@ def is_expr(obj: Any) -> TypeIs[Expr]:
5252

5353

5454
def evaluate_into_expr(
55-
df: CompliantFrameT_contra,
56-
expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
55+
df: CompliantFrameT,
56+
expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
5757
) -> Sequence[CompliantSeriesT_co]:
5858
"""Return list of raw columns.
5959
@@ -73,9 +73,9 @@ def evaluate_into_expr(
7373

7474

7575
def evaluate_into_exprs(
76-
df: CompliantFrameT_contra,
76+
df: CompliantFrameT,
7777
/,
78-
*exprs: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
78+
*exprs: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
7979
) -> list[CompliantSeriesT_co]:
8080
"""Evaluate each expr into Series."""
8181
return [
@@ -87,13 +87,13 @@ def evaluate_into_exprs(
8787

8888
@overload
8989
def maybe_evaluate_expr(
90-
df: CompliantFrameT_contra,
91-
expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
90+
df: CompliantFrameT,
91+
expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
9292
) -> CompliantSeriesT_co: ...
9393

9494

9595
@overload
96-
def maybe_evaluate_expr(df: CompliantDataFrame, expr: T) -> T: ...
96+
def maybe_evaluate_expr(df: CompliantDataFrame[Any], expr: T) -> T: ...
9797

9898

9999
def maybe_evaluate_expr(
@@ -155,7 +155,7 @@ def reuse_series_implementation(
155155
"""
156156
plx = expr.__narwhals_namespace__()
157157

158-
def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]:
158+
def func(df: CompliantDataFrame[Any]) -> Sequence[CompliantSeries]:
159159
_kwargs = {
160160
**(call_kwargs or {}),
161161
**{
@@ -258,15 +258,15 @@ def is_simple_aggregation(expr: CompliantExpr[Any, Any]) -> bool:
258258

259259

260260
def combine_evaluate_output_names(
261-
*exprs: CompliantExpr[CompliantFrameT_contra, Any],
262-
) -> Callable[[CompliantFrameT_contra], Sequence[str]]:
261+
*exprs: CompliantExpr[CompliantFrameT, Any],
262+
) -> Callable[[CompliantFrameT], Sequence[str]]:
263263
# Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the
264264
# first name of `expr1`.
265265
if not is_compliant_expr(exprs[0]): # pragma: no cover
266266
msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug."
267267
raise AssertionError(msg)
268268

269-
def evaluate_output_names(df: CompliantFrameT_contra) -> Sequence[str]:
269+
def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
270270
return exprs[0]._evaluate_output_names(df)[:1]
271271

272272
return evaluate_output_names
@@ -287,11 +287,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:
287287

288288

289289
def extract_compliant(
290-
plx: CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co],
290+
plx: CompliantNamespace[CompliantFrameT, CompliantSeriesT_co],
291291
other: Any,
292292
*,
293293
str_as_lit: bool,
294-
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | object:
294+
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | object:
295295
if is_expr(other):
296296
return other._to_compliant_expr(plx)
297297
if isinstance(other, str) and not str_as_lit:
@@ -306,7 +306,7 @@ def extract_compliant(
306306

307307
def evaluate_output_names_and_aliases(
308308
expr: CompliantExpr[Any, Any],
309-
df: CompliantDataFrame | CompliantLazyFrame,
309+
df: CompliantDataFrame[Any] | CompliantLazyFrame,
310310
exclude: Sequence[str],
311311
) -> tuple[Sequence[str], Sequence[str]]:
312312
output_names = expr._evaluate_output_names(df)

narwhals/typing.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import TYPE_CHECKING
44
from typing import Any
55
from typing import Callable
6-
from typing import Generic
76
from typing import Literal
87
from typing import Protocol
98
from typing import Sequence
@@ -52,7 +51,12 @@ def __narwhals_series__(self) -> CompliantSeries: ...
5251
def alias(self, name: str) -> Self: ...
5352

5453

55-
class CompliantDataFrame(Protocol):
54+
CompliantSeriesT_co = TypeVar(
55+
"CompliantSeriesT_co", bound=CompliantSeries, covariant=True
56+
)
57+
58+
59+
class CompliantDataFrame(Protocol[CompliantSeriesT_co]):
5660
def __narwhals_dataframe__(self) -> Self: ...
5761
def __narwhals_namespace__(self) -> Any: ...
5862
def simple_select(
@@ -64,6 +68,7 @@ def aggregate(self, *exprs: Any) -> Self:
6468

6569
@property
6670
def columns(self) -> Sequence[str]: ...
71+
def get_column(self, name: str) -> CompliantSeriesT_co: ...
6772

6873

6974
class CompliantLazyFrame(Protocol):
@@ -80,30 +85,25 @@ def aggregate(self, *exprs: Any) -> Self:
8085
def columns(self) -> Sequence[str]: ...
8186

8287

83-
CompliantFrameT_contra = TypeVar(
84-
"CompliantFrameT_contra",
85-
bound="CompliantDataFrame | CompliantLazyFrame",
86-
contravariant=True,
87-
)
88-
CompliantSeriesT_co = TypeVar(
89-
"CompliantSeriesT_co", bound=CompliantSeries, covariant=True
88+
CompliantFrameT = TypeVar(
89+
"CompliantFrameT", bound="CompliantDataFrame[Any] | CompliantLazyFrame"
9090
)
9191

9292

93-
class CompliantExpr(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]):
93+
class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesT_co]):
9494
_implementation: Implementation
9595
_backend_version: tuple[int, ...]
9696
_version: Version
97-
_evaluate_output_names: Callable[[CompliantFrameT_contra], Sequence[str]]
97+
_evaluate_output_names: Callable[[CompliantFrameT], Sequence[str]]
9898
_alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None
9999
_depth: int
100100
_function_name: str
101101

102-
def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ...
102+
def __call__(self, df: CompliantFrameT) -> Sequence[CompliantSeriesT_co]: ...
103103
def __narwhals_expr__(self) -> None: ...
104104
def __narwhals_namespace__(
105105
self,
106-
) -> CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co]: ...
106+
) -> CompliantNamespace[CompliantFrameT, CompliantSeriesT_co]: ...
107107
def is_null(self) -> Self: ...
108108
def alias(self, name: str) -> Self: ...
109109
def cast(self, dtype: DType) -> Self: ...
@@ -125,21 +125,21 @@ def broadcast(
125125
) -> Self: ...
126126

127127

128-
class CompliantNamespace(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]):
128+
class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesT_co]):
129129
def col(
130130
self, *column_names: str
131-
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ...
131+
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ...
132132
def lit(
133133
self, value: Any, dtype: DType | None
134-
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ...
134+
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ...
135135

136136

137137
class SupportsNativeNamespace(Protocol):
138138
def __native_namespace__(self) -> ModuleType: ...
139139

140140

141141
IntoCompliantExpr: TypeAlias = (
142-
"CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | CompliantSeriesT_co"
142+
"CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | CompliantSeriesT_co"
143143
)
144144

145145
IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"]

narwhals/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from narwhals.series import Series
5555
from narwhals.typing import CompliantDataFrame
5656
from narwhals.typing import CompliantExpr
57-
from narwhals.typing import CompliantFrameT_contra
57+
from narwhals.typing import CompliantFrameT
5858
from narwhals.typing import CompliantLazyFrame
5959
from narwhals.typing import CompliantSeries
6060
from narwhals.typing import CompliantSeriesT_co
@@ -1357,7 +1357,9 @@ def _hasattr_static(obj: Any, attr: str) -> bool:
13571357
return getattr_static(obj, attr, sentinel) is not sentinel
13581358

13591359

1360-
def is_compliant_dataframe(obj: Any) -> TypeIs[CompliantDataFrame]:
1360+
def is_compliant_dataframe(
1361+
obj: CompliantDataFrame[CompliantSeriesT_co] | Any,
1362+
) -> TypeIs[CompliantDataFrame[CompliantSeriesT_co]]:
13611363
return _hasattr_static(obj, "__narwhals_dataframe__")
13621364

13631365

@@ -1370,8 +1372,8 @@ def is_compliant_series(obj: Any) -> TypeIs[CompliantSeries]:
13701372

13711373

13721374
def is_compliant_expr(
1373-
obj: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | Any,
1374-
) -> TypeIs[CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]]:
1375+
obj: CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | Any,
1376+
) -> TypeIs[CompliantExpr[CompliantFrameT, CompliantSeriesT_co]]:
13751377
return hasattr(obj, "__narwhals_expr__")
13761378

13771379

0 commit comments

Comments
 (0)