Skip to content

Commit de2e99c

Browse files
authored
fix(typing): Narrow NativeSeries Protocol (#2159)
Closes #2111
1 parent dc924f5 commit de2e99c

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

narwhals/typing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727

2828
if TYPE_CHECKING:
2929
from types import ModuleType
30+
from typing import Iterable
3031
from typing import Mapping
32+
from typing import Sized
3133

3234
import numpy as np
3335
from typing_extensions import Self
@@ -53,8 +55,8 @@ def columns(self) -> Any: ...
5355

5456
def join(self, *args: Any, **kwargs: Any) -> Any: ...
5557

56-
class NativeSeries(Protocol):
57-
def __len__(self) -> int: ...
58+
class NativeSeries(Sized, Iterable[Any], Protocol):
59+
def filter(self, *args: Any, **kwargs: Any) -> Any: ...
5860

5961
class DataFrameLike(Protocol):
6062
def __dataframe__(self, *args: Any, **kwargs: Any) -> Any: ...

tests/translate/from_native_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,20 @@ def __dataframe__(self) -> None: # pragma: no cover
292292
assert result is mockdf
293293

294294

295-
def test_from_native_altair_array_like() -> None:
295+
def test_from_native_strict_native_series() -> None:
296296
obj: list[int] = [1, 2, 3, 4]
297297
array_like = cast("Iterable[Any]", obj)
298298
not_array_like: Literal[1] = 1
299+
np_array = pl.Series(obj).to_numpy()
299300

300301
with pytest.raises(TypeError, match="got.+list"):
301-
false_positive_native_series = nw.from_native(obj, series_only=True) # noqa: F841
302+
nw.from_native(obj, series_only=True) # type: ignore[call-overload]
302303

303304
with pytest.raises(TypeError, match="got.+list"):
304-
true_negative_iterable = nw.from_native(array_like, series_only=True) # type: ignore[call-overload] # noqa: F841
305+
nw.from_native(array_like, series_only=True) # type: ignore[call-overload]
305306

306307
with pytest.raises(TypeError, match="got.+int"):
307-
true_negative_not_native_series = nw.from_native(not_array_like, series_only=True) # type: ignore[call-overload] # noqa: F841
308+
nw.from_native(not_array_like, series_only=True) # type: ignore[call-overload]
309+
310+
with pytest.raises(TypeError, match="got.+numpy.ndarray"):
311+
nw.from_native(np_array, series_only=True) # type: ignore[call-overload]

0 commit comments

Comments
 (0)