Skip to content

Commit 7865972

Browse files
authored
fix(typing): Match lazy @overload in nw.from_native(...) (#2211)
1 parent 84a64ab commit 7865972

File tree

6 files changed

+76
-8
lines changed

6 files changed

+76
-8
lines changed

narwhals/functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
from narwhals.typing import IntoExpr
5454
from narwhals.typing import IntoFrameT
5555
from narwhals.typing import IntoSeriesT
56+
from narwhals.typing import NativeFrame
57+
from narwhals.typing import NativeLazyFrame
5658
from narwhals.typing import _2DArray
5759

5860
class ArrowStreamExportable(Protocol):
@@ -790,6 +792,7 @@ def _read_csv_impl(
790792
) -> DataFrame[Any]:
791793
eager_backend = Implementation.from_backend(backend)
792794
native_namespace = eager_backend.to_native_namespace()
795+
native_frame: NativeFrame
793796
if eager_backend in {
794797
Implementation.POLARS,
795798
Implementation.PANDAS,
@@ -851,6 +854,7 @@ def _scan_csv_impl(
851854
source: str, *, native_namespace: ModuleType, **kwargs: Any
852855
) -> LazyFrame[Any]:
853856
implementation = Implementation.from_native_namespace(native_namespace)
857+
native_frame: NativeFrame | NativeLazyFrame
854858
if implementation is Implementation.POLARS:
855859
native_frame = native_namespace.scan_csv(source, **kwargs)
856860
elif implementation in {
@@ -917,6 +921,7 @@ def _read_parquet_impl(
917921
source: str, *, native_namespace: ModuleType, **kwargs: Any
918922
) -> DataFrame[Any]:
919923
implementation = Implementation.from_native_namespace(native_namespace)
924+
native_frame: NativeFrame
920925
if implementation in {
921926
Implementation.POLARS,
922927
Implementation.PANDAS,
@@ -980,6 +985,7 @@ def _scan_parquet_impl(
980985
source: str, *, native_namespace: ModuleType, **kwargs: Any
981986
) -> LazyFrame[Any]:
982987
implementation = Implementation.from_native_namespace(native_namespace)
988+
native_frame: NativeFrame | NativeLazyFrame
983989
if implementation is Implementation.POLARS:
984990
native_frame = native_namespace.scan_parquet(source, **kwargs)
985991
elif implementation in {

narwhals/stable/v1/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
from narwhals.functions import ArrowStreamExportable
9292
from narwhals.typing import IntoExpr
9393
from narwhals.typing import IntoFrame
94+
from narwhals.typing import IntoLazyFrameT
9495
from narwhals.typing import IntoSeries
9596
from narwhals.typing import _1DArray
9697
from narwhals.typing import _2DArray
@@ -1326,6 +1327,19 @@ def from_native(
13261327
) -> Series[IntoSeriesT]: ...
13271328

13281329

1330+
@overload
1331+
def from_native(
1332+
native_object: IntoLazyFrameT,
1333+
*,
1334+
strict: Literal[True] = ...,
1335+
eager_only: Literal[False] = ...,
1336+
eager_or_interchange_only: Literal[False] = ...,
1337+
series_only: Literal[False] = ...,
1338+
allow_series: None = ...,
1339+
) -> LazyFrame[IntoLazyFrameT]: ...
1340+
1341+
1342+
# NOTE: `pl.LazyFrame` originally matched here
13291343
@overload
13301344
def from_native(
13311345
native_object: IntoFrameT,

narwhals/translate.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from narwhals.typing import IntoDataFrameT
4747
from narwhals.typing import IntoFrame
4848
from narwhals.typing import IntoFrameT
49+
from narwhals.typing import IntoLazyFrameT
4950
from narwhals.typing import IntoSeries
5051
from narwhals.typing import IntoSeriesT
5152

@@ -194,13 +195,13 @@ def from_native(
194195

195196
@overload
196197
def from_native(
197-
native_object: IntoFrameT | IntoSeriesT,
198+
native_object: IntoFrameT | IntoLazyFrameT | IntoSeriesT,
198199
*,
199200
pass_through: Literal[True],
200201
eager_only: Literal[False] = ...,
201202
series_only: Literal[False] = ...,
202203
allow_series: Literal[True],
203-
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series[IntoSeriesT]: ...
204+
) -> DataFrame[IntoFrameT] | LazyFrame[IntoLazyFrameT] | Series[IntoSeriesT]: ...
204205

205206

206207
@overload
@@ -214,6 +215,21 @@ def from_native(
214215
) -> Series[IntoSeriesT]: ...
215216

216217

218+
# NOTE: Seems like `mypy` is giving a false positive
219+
# Following this advice will introduce overlapping overloads?
220+
# > note: Flipping the order of overloads will fix this error
221+
@overload
222+
def from_native( # type: ignore[overload-overlap]
223+
native_object: IntoLazyFrameT,
224+
*,
225+
pass_through: Literal[False] = ...,
226+
eager_only: Literal[False] = ...,
227+
series_only: Literal[False] = ...,
228+
allow_series: None = ...,
229+
) -> LazyFrame[IntoLazyFrameT]: ...
230+
231+
232+
# NOTE: `pl.LazyFrame` originally matched here
217233
@overload
218234
def from_native(
219235
native_object: IntoDataFrameT,
@@ -260,13 +276,13 @@ def from_native(
260276

261277
@overload
262278
def from_native(
263-
native_object: IntoFrameT,
279+
native_object: IntoFrameT | IntoLazyFrameT,
264280
*,
265281
pass_through: Literal[False] = ...,
266282
eager_only: Literal[False] = ...,
267283
series_only: Literal[False] = ...,
268284
allow_series: None = ...,
269-
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT]: ...
285+
) -> DataFrame[IntoFrameT] | LazyFrame[IntoLazyFrameT]: ...
270286

271287

272288
# All params passed in as variables
@@ -282,14 +298,14 @@ def from_native(
282298

283299

284300
def from_native(
285-
native_object: IntoFrameT | IntoSeriesT | IntoFrame | IntoSeries | T,
301+
native_object: IntoLazyFrameT | IntoFrameT | IntoSeriesT | IntoFrame | IntoSeries | T,
286302
*,
287303
strict: bool | None = None,
288304
pass_through: bool | None = None,
289305
eager_only: bool = False,
290306
series_only: bool = False,
291307
allow_series: bool | None = None,
292-
) -> LazyFrame[IntoFrameT] | DataFrame[IntoFrameT] | Series[IntoSeriesT] | T:
308+
) -> LazyFrame[IntoLazyFrameT] | DataFrame[IntoFrameT] | Series[IntoSeriesT] | T:
293309
"""Convert `native_object` to Narwhals Dataframe, Lazyframe, or Series.
294310
295311
Arguments:

narwhals/typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def columns(self) -> Any: ...
3434

3535
def join(self, *args: Any, **kwargs: Any) -> Any: ...
3636

37+
class NativeLazyFrame(NativeFrame, Protocol):
38+
def explain(self, *args: Any, **kwargs: Any) -> Any: ...
39+
3740
class NativeSeries(Sized, Iterable[Any], Protocol):
3841
def filter(self, *args: Any, **kwargs: Any) -> Any: ...
3942

@@ -67,6 +70,8 @@ def __native_namespace__(self) -> ModuleType: ...
6770
... return df.shape
6871
"""
6972

73+
IntoLazyFrame: TypeAlias = "NativeLazyFrame | LazyFrame[Any]"
74+
7075
IntoFrame: TypeAlias = Union[
7176
"NativeFrame", "DataFrame[Any]", "LazyFrame[Any]", "DataFrameLike"
7277
]
@@ -140,6 +145,8 @@ def __native_namespace__(self) -> ModuleType: ...
140145
... return df.with_columns(c=df["a"] + 1).to_native()
141146
"""
142147

148+
IntoLazyFrameT = TypeVar("IntoLazyFrameT", bound="IntoLazyFrame")
149+
143150
FrameT = TypeVar("FrameT", bound="Frame")
144151
"""TypeVar bound to Narwhals DataFrame or Narwhals LazyFrame.
145152
@@ -168,6 +175,8 @@ def __native_namespace__(self) -> ModuleType: ...
168175
... return df.with_columns(c=df["a"] + 1)
169176
"""
170177

178+
LazyFrameT = TypeVar("LazyFrameT", bound="LazyFrame[Any]")
179+
171180
IntoSeriesT = TypeVar("IntoSeriesT", bound="IntoSeries")
172181
"""TypeVar bound to object convertible to Narwhals Series.
173182

tests/frame/invalid_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from typing import TYPE_CHECKING
4+
from typing import TypeVar
45

56
import pandas as pd
67
import polars as pl
@@ -16,6 +17,9 @@
1617
from narwhals.typing import Frame
1718

1819

20+
T = TypeVar("T")
21+
22+
1923
@pytest.mark.skipif(
2024
POLARS_VERSION < (1,), reason="Polars would raise unrecoverable panic."
2125
)
@@ -70,12 +74,18 @@ def test_memmap() -> None:
7074
pytest.importorskip("sklearn")
7175
# the headache this caused me...
7276
from sklearn.utils import check_X_y
73-
from sklearn.utils._testing import create_memmap_backed_data
77+
78+
if TYPE_CHECKING:
79+
80+
def create_memmap_backed_data(data: T) -> T:
81+
return data
82+
else:
83+
from sklearn.utils._testing import create_memmap_backed_data
7484

7585
x_any = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
7686
y_any = create_memmap_backed_data(x_any["b"])
7787

78-
x_any, y_any = create_memmap_backed_data([x_any, y_any])
88+
x_any, y_any = create_memmap_backed_data((x_any, y_any))
7989

8090
x = nw.from_native(x_any)
8191
x = x.with_columns(y=nw.from_native(y_any, series_only=True))

tests/translate/from_native_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,16 @@ def test_from_native_strict_native_series() -> None:
312312

313313
with pytest.raises(TypeError, match="got.+numpy.ndarray"):
314314
nw.from_native(np_array, series_only=True) # type: ignore[call-overload]
315+
316+
317+
def test_from_native_lazyframe() -> None:
318+
stable_lazy = nw.from_native(lf_pl)
319+
unstable_lazy = unstable_nw.from_native(lf_pl)
320+
if TYPE_CHECKING:
321+
from typing_extensions import assert_type
322+
323+
assert_type(stable_lazy, nw.LazyFrame[pl.LazyFrame])
324+
assert_type(unstable_lazy, unstable_nw.LazyFrame[pl.LazyFrame])
325+
326+
assert isinstance(stable_lazy, nw.LazyFrame)
327+
assert isinstance(unstable_lazy, unstable_nw.LazyFrame)

0 commit comments

Comments
 (0)