Skip to content

Commit 0828bb9

Browse files
committed
fix(typing): Mostly resolve invariance issues
The problem is that type checkers special-case `__init__` to allow covariance ... despite the same thing not being allowed on a classmethod (`from_<...>`)
1 parent 16cefb3 commit 0828bb9

File tree

7 files changed

+55
-58
lines changed

7 files changed

+55
-58
lines changed

narwhals/_plan/arrow/dataframe.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@
2424

2525
from typing_extensions import Self
2626

27-
from narwhals._arrow.typing import ChunkedArrayAny
27+
from narwhals._arrow.typing import ChunkedArrayAny # noqa: F401
2828
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
2929
from narwhals._plan.arrow.namespace import ArrowNamespace
30-
from narwhals._plan.dataframe import DataFrame as NwDataFrame
3130
from narwhals._plan.expressions import ExprIR, NamedIR
3231
from narwhals._plan.options import SortMultipleOptions
3332
from narwhals._plan.typing import Seq
@@ -63,11 +62,6 @@ def schema(self) -> dict[str, DType]:
6362
def __len__(self) -> int:
6463
return self.native.num_rows
6564

66-
def to_narwhals(self) -> NwDataFrame[pa.Table, ChunkedArrayAny]:
67-
from narwhals._plan.dataframe import DataFrame
68-
69-
return DataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self)
70-
7165
@classmethod
7266
def from_dict(
7367
cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None

narwhals/_plan/compliant/dataframe.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@ def to_dict(
114114
def to_dict(
115115
self, *, as_series: bool
116116
) -> dict[str, SeriesT] | dict[str, list[Any]]: ...
117-
def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ...
117+
def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]:
118+
from narwhals._plan.dataframe import DataFrame
119+
120+
return DataFrame[NativeDataFrameT, NativeSeriesT](self)
121+
118122
def with_row_index(self, name: str) -> Self: ...
119123

120124

narwhals/_plan/compliant/series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,6 @@ def to_list(self) -> list[Any]: ...
7070
def to_narwhals(self) -> Series[NativeSeriesT]:
7171
from narwhals._plan.series import Series
7272

73-
return Series[NativeSeriesT]._from_compliant(self)
73+
return Series[NativeSeriesT](self)
7474

7575
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ...

narwhals/_plan/dataframe.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from narwhals._plan.typing import (
1111
IntoExpr,
1212
NativeDataFrameT,
13+
NativeDataFrameT_co,
1314
NativeFrameT,
15+
NativeFrameT_co,
1416
NativeSeriesT,
1517
OneOrIterable,
1618
)
@@ -21,15 +23,13 @@
2123
if TYPE_CHECKING:
2224
from collections.abc import Sequence
2325

24-
import pyarrow as pa
2526
from typing_extensions import Self
2627

2728
from narwhals._plan.compliant.dataframe import CompliantDataFrame, CompliantFrame
28-
from narwhals.typing import NativeFrame
2929

3030

31-
class BaseFrame(Generic[NativeFrameT]):
32-
_compliant: CompliantFrame[Any, NativeFrameT]
31+
class BaseFrame(Generic[NativeFrameT_co]):
32+
_compliant: CompliantFrame[Any, NativeFrameT_co]
3333
_version: ClassVar[Version] = Version.MAIN
3434

3535
@property
@@ -47,30 +47,26 @@ def columns(self) -> list[str]:
4747
def __repr__(self) -> str: # pragma: no cover
4848
return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__())
4949

50-
@classmethod
51-
def from_native(cls, native: Any, /) -> Self:
52-
raise NotImplementedError
50+
def __init__(self, compliant: Any, /) -> None:
51+
self._compliant = compliant
5352

54-
@classmethod
55-
def _from_compliant(cls, compliant: CompliantFrame[Any, NativeFrameT], /) -> Self:
56-
obj = cls.__new__(cls)
57-
obj._compliant = compliant
58-
return obj
53+
def _with_compliant(self, compliant: CompliantFrame[Any, NativeFrameT], /) -> Self:
54+
return type(self)(compliant)
5955

60-
def to_native(self) -> NativeFrameT:
56+
def to_native(self) -> NativeFrameT_co:
6157
return self._compliant.native
6258

6359
def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self:
6460
named_irs, schema = prepare_projection(
6561
_parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self
6662
)
67-
return self._from_compliant(self._compliant.select(schema.select_irs(named_irs)))
63+
return self._with_compliant(self._compliant.select(schema.select_irs(named_irs)))
6864

6965
def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self:
7066
named_irs, schema = prepare_projection(
7167
_parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self
7268
)
73-
return self._from_compliant(
69+
return self._with_compliant(
7470
self._compliant.with_columns(schema.with_columns_irs(named_irs))
7571
)
7672

@@ -85,32 +81,33 @@ def sort(
8581
by, *more_by, descending=descending, nulls_last=nulls_last
8682
)
8783
named_irs, _ = prepare_projection(sort, schema=self)
88-
return self._from_compliant(self._compliant.sort(named_irs, opts))
84+
return self._with_compliant(self._compliant.sort(named_irs, opts))
8985

9086
def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self:
91-
return self._from_compliant(self._compliant.drop(columns, strict=strict))
87+
return self._with_compliant(self._compliant.drop(columns, strict=strict))
9288

9389
def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self:
9490
subset = [subset] if isinstance(subset, str) else subset
95-
return self._from_compliant(self._compliant.drop_nulls(subset))
91+
return self._with_compliant(self._compliant.drop_nulls(subset))
9692

9793

98-
class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]):
99-
_compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT]
94+
class DataFrame(
95+
BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT]
96+
):
97+
_compliant: CompliantDataFrame[Any, NativeDataFrameT_co, NativeSeriesT]
10098

10199
@property
102100
def _series(self) -> type[Series[NativeSeriesT]]:
103101
return Series[NativeSeriesT]
104102

105-
# NOTE: Gave up on trying to get typing working for now
106103
@classmethod
107-
def from_native( # type: ignore[override]
108-
cls, native: NativeFrame, /
109-
) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]:
104+
def from_native(
105+
cls: type[DataFrame[Any, Any]], native: NativeDataFrameT, /
106+
) -> DataFrame[NativeDataFrameT]:
110107
if is_pyarrow_table(native):
111108
from narwhals._plan.arrow.dataframe import ArrowDataFrame
112109

113-
return ArrowDataFrame.from_native(native, cls._version).to_narwhals()
110+
return cls(ArrowDataFrame.from_native(native, cls._version))
114111

115112
raise NotImplementedError(type(native))
116113

@@ -129,7 +126,7 @@ def to_dict(
129126
) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]:
130127
if as_series:
131128
return {
132-
key: self._series._from_compliant(value)
129+
key: self._series(value)
133130
for key, value in self._compliant.to_dict(as_series=as_series).items()
134131
}
135132
return self._compliant.to_dict(as_series=as_series)

narwhals/_plan/group_by.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, frame: DataFrameT, grouper: Grouped, /) -> None:
2525

2626
def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT:
2727
frame = self._frame
28-
return frame._from_compliant(
28+
return frame._with_compliant(
2929
self._grouper.agg(*aggs, **named_aggs)
3030
.resolve(frame)
3131
.evaluate(frame._compliant)
@@ -35,7 +35,7 @@ def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]:
3535
frame = self._frame
3636
resolver = self._grouper.agg().resolve(frame)
3737
for key, df in frame._compliant.group_by_resolver(resolver):
38-
yield key, frame._from_compliant(df)
38+
yield key, frame._with_compliant(df)
3939

4040

4141
class Grouped(Grouper["Resolved"]):

narwhals/_plan/series.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,19 @@
22

33
from typing import TYPE_CHECKING, Any, ClassVar, Generic
44

5-
from narwhals._plan.typing import NativeSeriesT
5+
from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co
66
from narwhals._utils import Version
77
from narwhals.dependencies import is_pyarrow_chunked_array
88

99
if TYPE_CHECKING:
1010
from collections.abc import Iterator
1111

12-
import pyarrow as pa
13-
from typing_extensions import Self
14-
1512
from narwhals._plan.compliant.series import CompliantSeries
1613
from narwhals.dtypes import DType
17-
from narwhals.typing import NativeSeries
1814

1915

20-
class Series(Generic[NativeSeriesT]):
21-
_compliant: CompliantSeries[NativeSeriesT]
16+
class Series(Generic[NativeSeriesT_co]):
17+
_compliant: CompliantSeries[NativeSeriesT_co]
2218
_version: ClassVar[Version] = Version.MAIN
2319

2420
@property
@@ -33,27 +29,21 @@ def dtype(self) -> DType:
3329
def name(self) -> str:
3430
return self._compliant.name
3531

36-
# NOTE: Gave up on trying to get typing working for now
32+
def __init__(self, compliant: CompliantSeries[NativeSeriesT_co], /) -> None:
33+
self._compliant = compliant
34+
3735
@classmethod
3836
def from_native(
39-
cls, native: NativeSeries, name: str = "", /
40-
) -> Series[pa.ChunkedArray[Any]]:
37+
cls: type[Series[Any]], native: NativeSeriesT, name: str = "", /
38+
) -> Series[NativeSeriesT]:
4139
if is_pyarrow_chunked_array(native):
4240
from narwhals._plan.arrow.series import ArrowSeries
4341

44-
return ArrowSeries.from_native(
45-
native, name, version=cls._version
46-
).to_narwhals()
42+
return cls(ArrowSeries.from_native(native, name, version=cls._version))
4743

4844
raise NotImplementedError(type(native))
4945

50-
@classmethod
51-
def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self:
52-
obj = cls.__new__(cls)
53-
obj._compliant = compliant
54-
return obj
55-
56-
def to_native(self) -> NativeSeriesT:
46+
def to_native(self) -> NativeSeriesT_co:
5747
return self._compliant.native
5848

5949
def to_list(self) -> list[Any]:
@@ -63,5 +53,5 @@ def __iter__(self) -> Iterator[Any]:
6353
yield from self.to_native()
6454

6555

66-
class SeriesV1(Series[NativeSeriesT]):
56+
class SeriesV1(Series[NativeSeriesT_co]):
6757
_version: ClassVar[Version] = Version.V1

narwhals/_plan/typing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,22 @@
8787
"NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral"
8888
)
8989
NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries")
90+
NativeSeriesT_co = TypeVar(
91+
"NativeSeriesT_co", bound="NativeSeries", covariant=True, default="NativeSeries"
92+
)
9093
NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame", default="NativeFrame")
94+
NativeFrameT_co = TypeVar(
95+
"NativeFrameT_co", bound="NativeFrame", covariant=True, default="NativeFrame"
96+
)
9197
NativeDataFrameT = TypeVar(
9298
"NativeDataFrameT", bound="NativeDataFrame", default="NativeDataFrame"
9399
)
100+
NativeDataFrameT_co = TypeVar(
101+
"NativeDataFrameT_co",
102+
bound="NativeDataFrame",
103+
covariant=True,
104+
default="NativeDataFrame",
105+
)
94106
LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | Series[t.Any]", default=t.Any)
95107
MapIR: TypeAlias = "t.Callable[[ExprIR], ExprIR]"
96108
"""A function to apply to all nodes in this tree."""

0 commit comments

Comments
 (0)