Skip to content

Commit d5a126d

Browse files
committed
fix(typing): Huuuuuuuge progress
Will close #1928
1 parent aef26ce commit d5a126d

File tree

5 files changed

+136
-116
lines changed

5 files changed

+136
-116
lines changed

narwhals/stable/v2/typing.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, Union
3+
from typing import TYPE_CHECKING, Any, TypeVar, Union
44

55
if TYPE_CHECKING:
66
import sys
7-
from collections.abc import Iterable, Sized
87

98
from narwhals.stable.v2 import DataFrame, LazyFrame
109

@@ -14,23 +13,7 @@
1413
from typing_extensions import TypeAlias
1514

1615
from narwhals.stable.v2 import Expr, Series
17-
18-
# All dataframes supported by Narwhals have a
19-
# `columns` property. Their similarities don't extend
20-
# _that_ much further unfortunately...
21-
class NativeFrame(Protocol):
22-
@property
23-
def columns(self) -> Any: ...
24-
25-
def join(self, *args: Any, **kwargs: Any) -> Any: ...
26-
27-
class NativeDataFrame(Sized, NativeFrame, Protocol): ...
28-
29-
class NativeLazyFrame(NativeFrame, Protocol):
30-
def explain(self, *args: Any, **kwargs: Any) -> Any: ...
31-
32-
class NativeSeries(Sized, Iterable[Any], Protocol):
33-
def filter(self, *args: Any, **kwargs: Any) -> Any: ...
16+
from narwhals.typing import NativeDataFrame, NativeLazyFrame, NativeSeries
3417

3518

3619
IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"]

narwhals/translate.py

Lines changed: 76 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import datetime as dt
44
from decimal import Decimal
55
from functools import wraps
6-
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, overload
6+
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar, overload
77

88
from narwhals._constants import EPOCH, MS_PER_SECOND
99
from narwhals._namespace import (
@@ -30,6 +30,8 @@
3030
)
3131

3232
if TYPE_CHECKING:
33+
from typing_extensions import NotRequired, Unpack
34+
3335
from narwhals.dataframe import DataFrame, LazyFrame
3436
from narwhals.series import Series
3537
from narwhals.typing import (
@@ -99,148 +101,144 @@ def to_native(
99101
return narwhals_object
100102

101103

102-
@overload
103-
def from_native(native_object: SeriesT, **kwds: Any) -> SeriesT: ...
104+
# Upper bound
105+
class FromNativeFlags(TypedDict, total=False):
106+
pass_through: bool
107+
eager_only: bool
108+
series_only: bool
109+
allow_series: bool | None
104110

105111

106-
@overload
107-
def from_native(native_object: DataFrameT, **kwds: Any) -> DataFrameT: ...
112+
class DefaultFlags(TypedDict, total=False):
113+
pass_through: Literal[False]
114+
eager_only: Literal[False]
115+
series_only: Literal[False]
116+
allow_series: None
108117

109118

110-
@overload
111-
def from_native(native_object: LazyFrameT, **kwds: Any) -> LazyFrameT: ...
119+
class SeriesNever(TypedDict, total=False):
120+
pass_through: bool
121+
eager_only: bool
122+
series_only: Literal[False]
123+
allow_series: Literal[False] | None
124+
125+
126+
class PassThrough(TypedDict):
127+
pass_through: Literal[True]
128+
129+
130+
class SeriesAllow(TypedDict):
131+
pass_through: NotRequired[bool]
132+
eager_only: NotRequired[bool]
133+
series_only: NotRequired[Literal[False]]
134+
allow_series: Literal[True]
135+
136+
137+
class SeriesOnly(TypedDict):
138+
pass_through: NotRequired[bool]
139+
eager_only: NotRequired[bool]
140+
series_only: Literal[True]
141+
allow_series: NotRequired[bool | None]
142+
143+
144+
class EagerOnly(TypedDict):
145+
pass_through: NotRequired[bool]
146+
eager_only: Literal[True]
147+
series_only: NotRequired[bool]
148+
allow_series: NotRequired[bool | None]
149+
150+
151+
class LazyAllow(TypedDict):
152+
pass_through: NotRequired[bool]
153+
eager_only: NotRequired[Literal[False]]
154+
series_only: NotRequired[Literal[False]]
155+
allow_series: NotRequired[bool | None]
112156

113157

114158
@overload
115-
def from_native(
116-
native_object: IntoDataFrameT | IntoSeriesT,
117-
*,
118-
pass_through: Literal[True],
119-
eager_only: Literal[True],
120-
series_only: Literal[False] = ...,
121-
allow_series: Literal[True],
122-
) -> DataFrame[IntoDataFrameT] | Series[IntoSeriesT]: ...
159+
def from_native(native_object: SeriesT, **kwds: Any) -> SeriesT: ...
123160

124161

125162
@overload
126-
def from_native(
127-
native_object: IntoDataFrameT,
128-
*,
129-
pass_through: Literal[True],
130-
eager_only: Literal[False] = ...,
131-
series_only: Literal[False] = ...,
132-
allow_series: None = ...,
133-
) -> DataFrame[IntoDataFrameT]: ...
163+
def from_native(native_object: DataFrameT, **kwds: Any) -> DataFrameT: ...
134164

135165

136166
@overload
137-
def from_native(
138-
native_object: T,
139-
*,
140-
pass_through: Literal[True],
141-
eager_only: Literal[False] = ...,
142-
series_only: Literal[False] = ...,
143-
allow_series: None = ...,
144-
) -> T: ...
167+
def from_native(native_object: LazyFrameT, **kwds: Any) -> LazyFrameT: ...
145168

146169

147170
@overload
148171
def from_native(
149-
native_object: IntoDataFrameT,
150-
*,
151-
pass_through: Literal[True],
152-
eager_only: Literal[True],
153-
series_only: Literal[False] = ...,
154-
allow_series: None = ...,
172+
native_object: IntoDataFrameT, **kwds: Unpack[SeriesNever]
155173
) -> DataFrame[IntoDataFrameT]: ...
156174

157175

158176
@overload
159177
def from_native(
160-
native_object: T,
161-
*,
162-
pass_through: Literal[True],
163-
eager_only: Literal[True],
164-
series_only: Literal[False] = ...,
165-
allow_series: None = ...,
166-
) -> T: ...
178+
native_object: IntoSeriesT, **kwds: Unpack[SeriesOnly]
179+
) -> Series[IntoSeriesT]: ...
167180

168181

169182
@overload
170183
def from_native(
171-
native_object: IntoDataFrameT | IntoLazyFrameT | IntoSeriesT,
172-
*,
173-
pass_through: Literal[True],
174-
eager_only: Literal[False] = ...,
175-
series_only: Literal[False] = ...,
176-
allow_series: Literal[True],
177-
) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoLazyFrameT] | Series[IntoSeriesT]: ...
184+
native_object: IntoSeriesT, **kwds: Unpack[SeriesAllow]
185+
) -> Series[IntoSeriesT]: ...
178186

179187

180188
@overload
181189
def from_native(
182-
native_object: IntoSeriesT,
183-
*,
184-
pass_through: Literal[True],
185-
eager_only: Literal[False] = ...,
186-
series_only: Literal[True],
187-
allow_series: None = ...,
188-
) -> Series[IntoSeriesT]: ...
190+
native_object: IntoLazyFrameT, **kwds: Unpack[LazyAllow]
191+
) -> LazyFrame[IntoLazyFrameT]: ...
189192

190193

191194
@overload
192195
def from_native(
193-
native_object: IntoLazyFrameT,
194-
*,
195-
pass_through: Literal[False] = ...,
196-
eager_only: Literal[False] = ...,
197-
series_only: Literal[False] = ...,
198-
allow_series: None = ...,
199-
) -> LazyFrame[IntoLazyFrameT]: ...
196+
native_object: IntoDataFrameT | IntoSeriesT, **kwds: Unpack[SeriesAllow]
197+
) -> DataFrame[IntoDataFrameT] | Series[IntoSeriesT]: ...
200198

201199

202200
@overload
203201
def from_native(
204-
native_object: IntoDataFrameT,
202+
native_object: T,
205203
*,
206-
pass_through: Literal[False] = ...,
204+
pass_through: Literal[True],
207205
eager_only: Literal[False] = ...,
208206
series_only: Literal[False] = ...,
209207
allow_series: None = ...,
210-
) -> DataFrame[IntoDataFrameT]: ...
208+
) -> T: ...
211209

212210

213211
@overload
214212
def from_native(
215-
native_object: IntoDataFrameT,
213+
native_object: T,
216214
*,
217-
pass_through: Literal[False] = ...,
215+
pass_through: Literal[True],
218216
eager_only: Literal[True],
219217
series_only: Literal[False] = ...,
220218
allow_series: None = ...,
221-
) -> DataFrame[IntoDataFrameT]: ...
219+
) -> T: ...
222220

223221

224222
@overload
225223
def from_native(
226-
native_object: IntoFrame | IntoSeries,
224+
native_object: IntoDataFrameT | IntoLazyFrameT | IntoSeriesT,
227225
*,
228-
pass_through: Literal[False] = ...,
226+
pass_through: Literal[True],
229227
eager_only: Literal[False] = ...,
230228
series_only: Literal[False] = ...,
231229
allow_series: Literal[True],
232-
) -> DataFrame[Any] | LazyFrame[Any] | Series[Any]: ...
230+
) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoLazyFrameT] | Series[IntoSeriesT]: ...
233231

234232

235233
@overload
236234
def from_native(
237-
native_object: IntoSeriesT,
235+
native_object: IntoFrame | IntoSeries,
238236
*,
239237
pass_through: Literal[False] = ...,
240238
eager_only: Literal[False] = ...,
241-
series_only: Literal[True],
242-
allow_series: None = ...,
243-
) -> Series[IntoSeriesT]: ...
239+
series_only: Literal[False] = ...,
240+
allow_series: Literal[True],
241+
) -> DataFrame[Any] | LazyFrame[Any] | Series[Any]: ...
244242

245243

246244
# All params passed in as variables

narwhals/typing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,19 @@ def columns(self) -> Any: ...
3535

3636
def join(self, *args: Any, **kwargs: Any) -> Any: ...
3737

38-
class NativeDataFrame(Sized, NativeFrame, Protocol): ...
38+
class NativeDataFrame(Sized, NativeFrame, Protocol):
39+
def drop(self, *args: Any, **kwargs: Any) -> Any: ...
3940

4041
class NativeLazyFrame(NativeFrame, Protocol):
4142
def explain(self, *args: Any, **kwargs: Any) -> Any: ...
4243

44+
# Needs to have something `NativeDataFrame` doesn't?
4345
class NativeSeries(Sized, Iterable[Any], Protocol):
4446
def filter(self, *args: Any, **kwargs: Any) -> Any: ...
47+
# `pd.DataFrame` has this - the others don't
48+
def value_counts(self, *args: Any, **kwargs: Any) -> Any: ...
49+
# `pl.DataFrame` has this - the others don't
50+
def unique(self, *args: Any, **kwargs: Any) -> Any: ...
4551

4652
class SupportsNativeNamespace(Protocol):
4753
def __native_namespace__(self) -> ModuleType: ...

tests/translate/from_native_test.py

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -435,26 +435,57 @@ def test_pyspark_connect_deps_2517(constructor: Constructor) -> None: # pragma:
435435
nw.from_native(spark.createDataFrame([(1,)], ["a"]))
436436

437437

438-
@pytest.mark.parametrize(
439-
("eager_only", "pass_through", "context"),
440-
[
441-
(False, False, does_not_raise()),
442-
(False, True, does_not_raise()),
443-
(True, True, does_not_raise()),
444-
(True, False, pytest.raises(TypeError, match="Cannot only use")),
445-
],
446-
)
447-
def test_eager_only_pass_through_main(
448-
constructor: Constructor, *, eager_only: bool, pass_through: bool, context: Any
449-
) -> None:
438+
def test_eager_only_pass_through_main(constructor: Constructor) -> None:
450439
if not any(s in str(constructor) for s in ("pyspark", "dask", "ibis", "duckdb")):
451440
pytest.skip(reason="Non lazy or polars")
452441

453442
df = constructor(data)
454443

455-
with context:
456-
res = nw.from_native(df, eager_only=eager_only, pass_through=pass_through) # type: ignore[call-overload]
457-
if eager_only and pass_through:
458-
assert not isinstance(res, nw.LazyFrame)
459-
else:
460-
assert isinstance(res, nw.LazyFrame)
444+
r1 = nw.from_native(df, eager_only=False, pass_through=False)
445+
r2 = nw.from_native(df, eager_only=False, pass_through=True)
446+
r3 = nw.from_native(df, eager_only=True, pass_through=True)
447+
448+
assert isinstance(r1, nw.LazyFrame)
449+
assert isinstance(r2, nw.LazyFrame)
450+
assert not isinstance(r3, nw.LazyFrame)
451+
452+
with pytest.raises(TypeError, match=r"Cannot.+use.+eager_only"):
453+
nw.from_native(df, eager_only=True, pass_through=False) # type: ignore[type-var]
454+
455+
456+
def test_from_native_eager_only_series_only_allow() -> None:
457+
pytest.importorskip("polars")
458+
pytest.importorskip("pandas")
459+
import pandas as pd
460+
import polars as pl
461+
462+
pl_ser = pl.Series([1, 2, 3])
463+
pd_ser = pd.Series([1, 2, 3])
464+
465+
s01 = nw.from_native(pl_ser, series_only=True)
466+
s02 = nw.from_native(pl_ser, allow_series=True)
467+
s03 = nw.from_native(pl_ser, eager_only=True, series_only=True)
468+
s04 = nw.from_native(pl_ser, eager_only=True, series_only=True, allow_series=True)
469+
s05 = nw.from_native(pl_ser, eager_only=True, allow_series=True)
470+
s06 = nw.from_native(pl_ser, series_only=True, allow_series=True)
471+
472+
assert isinstance(s01, nw.Series)
473+
assert isinstance(s02, nw.Series)
474+
assert isinstance(s03, nw.Series)
475+
assert isinstance(s04, nw.Series)
476+
assert isinstance(s05, nw.Series)
477+
assert isinstance(s06, nw.Series)
478+
479+
s11 = nw.from_native(pd_ser, series_only=True)
480+
s12 = nw.from_native(pd_ser, allow_series=True)
481+
s13 = nw.from_native(pd_ser, eager_only=True, series_only=True)
482+
s14 = nw.from_native(pd_ser, eager_only=True, series_only=True, allow_series=True)
483+
s15 = nw.from_native(pd_ser, eager_only=True, allow_series=True)
484+
s16 = nw.from_native(pd_ser, series_only=True, allow_series=True)
485+
486+
assert isinstance(s11, nw.Series)
487+
assert isinstance(s12, nw.Series)
488+
assert isinstance(s13, nw.Series)
489+
assert isinstance(s14, nw.Series)
490+
assert isinstance(s15, nw.Series)
491+
assert isinstance(s16, nw.Series)

tests/utils_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,11 @@ def test_maybe_set_index_pandas_direct_index(
132132
df = nw.from_native(native_df_or_series, allow_series=True)
133133
result = nw.maybe_set_index(df, index=narwhals_index)
134134
if isinstance(native_df_or_series, pd.Series):
135+
assert isinstance(result, nw.Series)
135136
native_df_or_series.index = pandas_index # type: ignore[assignment]
136137
assert_series_equal(nw.to_native(result), native_df_or_series)
137138
else:
139+
assert isinstance(result, nw.DataFrame)
138140
expected = native_df_or_series.set_index(pandas_index) # type: ignore[arg-type]
139141
assert_frame_equal(nw.to_native(result), expected)
140142

0 commit comments

Comments
 (0)