Skip to content

Commit 022a3d9

Browse files
chore: Add CompliantNamespace.is_native (#3130)
Co-authored-by: Marco Edward Gorelli <[email protected]>
1 parent 24cf843 commit 022a3d9

File tree

4 files changed

+81
-8
lines changed

4 files changed

+81
-8
lines changed

narwhals/_compliant/namespace.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
EagerSeriesT,
1414
LazyExprT,
1515
NativeFrameT,
16-
NativeFrameT_co,
1716
NativeSeriesT,
1817
)
1918
from narwhals._expression_parsing import is_expr, is_series
@@ -27,7 +26,7 @@
2726
if TYPE_CHECKING:
2827
from collections.abc import Container, Iterable, Sequence
2928

30-
from typing_extensions import TypeAlias
29+
from typing_extensions import TypeAlias, TypeIs
3130

3231
from narwhals._compliant.selectors import CompliantSelectorNamespace
3332
from narwhals._compliant.when_then import CompliantWhen, EagerWhen
@@ -115,6 +114,9 @@ def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
115114
def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...
116115
# NOTE: typing this accurately requires 2x more `TypeVar`s
117116
def from_native(self, data: Any, /) -> Any: ...
117+
def is_native(self, obj: Any, /) -> TypeIs[Any]:
118+
"""Return `True` if `obj` can be passed to `from_native`."""
119+
...
118120

119121

120122
class DepthTrackingNamespace(
@@ -141,16 +143,18 @@ def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
141143

142144
class LazyNamespace(
143145
CompliantNamespace[CompliantLazyFrameT, LazyExprT],
144-
Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
146+
Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT],
145147
):
146148
@property
147149
def _backend_version(self) -> tuple[int, ...]:
148150
return self._implementation._backend_version()
149151

150152
@property
151153
def _lazyframe(self) -> type[CompliantLazyFrameT]: ...
154+
def is_native(self, obj: Any, /) -> TypeIs[NativeFrameT]:
155+
return self._lazyframe._is_native(obj)
152156

153-
def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
157+
def from_native(self, data: NativeFrameT | Any, /) -> CompliantLazyFrameT:
154158
if self._lazyframe._is_native(data):
155159
return self._lazyframe.from_native(data, context=self)
156160
msg = f"Unsupported type: {type(data).__name__!r}" # pragma: no cover
@@ -173,6 +177,9 @@ def when(
173177
self, predicate: EagerExprT
174178
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...
175179

180+
def is_native(self, obj: Any, /) -> TypeIs[NativeFrameT | NativeSeriesT]:
181+
return self._dataframe._is_native(obj) or self._series._is_native(obj)
182+
176183
@overload
177184
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
178185
@overload

narwhals/_polars/namespace.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from collections.abc import Iterable, Sequence
1818
from datetime import timezone
1919

20+
from typing_extensions import TypeIs
21+
2022
from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen
2123
from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame
2224
from narwhals._polars.typing import FrameT
@@ -101,6 +103,9 @@ def parse_into_expr(
101103
return self.col(data)
102104
return self.lit(data.to_native() if is_series(data) else data, None)
103105

106+
def is_native(self, obj: Any) -> TypeIs[pl.DataFrame | pl.LazyFrame | pl.Series]:
107+
return isinstance(obj, (pl.DataFrame, pl.LazyFrame, pl.Series))
108+
104109
@overload
105110
def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ...
106111
@overload

narwhals/_sql/namespace.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import TYPE_CHECKING, Any, Protocol
66

77
from narwhals._compliant import LazyNamespace
8-
from narwhals._compliant.typing import NativeExprT, NativeFrameT_co
8+
from narwhals._compliant.typing import NativeExprT, NativeFrameT
99
from narwhals._sql.typing import SQLExprT, SQLLazyFrameT
1010

1111
if TYPE_CHECKING:
@@ -15,8 +15,8 @@
1515

1616

1717
class SQLNamespace(
18-
LazyNamespace[SQLLazyFrameT, SQLExprT, NativeFrameT_co],
19-
Protocol[SQLLazyFrameT, SQLExprT, NativeFrameT_co, NativeExprT],
18+
LazyNamespace[SQLLazyFrameT, SQLExprT, NativeFrameT],
19+
Protocol[SQLLazyFrameT, SQLExprT, NativeFrameT, NativeExprT],
2020
):
2121
def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExprT: ...
2222
def _lit(self, value: Any) -> NativeExprT: ...

tests/namespace_test.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from narwhals._utils import Version
1313

1414
if TYPE_CHECKING:
15-
from typing_extensions import TypeAlias, assert_type
15+
from typing_extensions import Never, TypeAlias, assert_type # noqa: F401
1616

1717
from narwhals._arrow.namespace import ArrowNamespace # noqa: F401
1818
from narwhals._compliant import CompliantNamespace
@@ -177,3 +177,64 @@ class NamespaceNoVersion(Namespace): ... # type: ignore[call-arg, type-arg]
177177
with pytest.raises(TypeError, match=re.compile(r"Expected.+Version.+but got.+str")):
178178

179179
class NamespaceBadVersion(Namespace, version="invalid version"): ... # type: ignore[arg-type, type-arg]
180+
181+
182+
def test_namespace_is_native() -> None:
183+
pytest.importorskip("polars")
184+
import polars as pl
185+
186+
unrelated: list[int] = [1, 2, 3]
187+
native_1 = pl.Series(unrelated)
188+
native_2 = pl.DataFrame({"a": unrelated})
189+
190+
maybe_native: list[pl.Series | list[int]] = [native_1, unrelated]
191+
always_native = list["pl.DataFrame | pl.Series"]((native_2, native_1))
192+
never_native = [unrelated, 50]
193+
194+
expected_maybe = [True, False]
195+
expected_always = [True, True]
196+
expected_never = [False, False]
197+
198+
ns = Namespace.from_backend("polars").compliant
199+
assert [ns.is_native(el) for el in maybe_native] == expected_maybe
200+
assert [ns.is_native(el) for el in always_native] == expected_always
201+
assert [ns.is_native(el) for el in never_native] == expected_never
202+
203+
if TYPE_CHECKING:
204+
if ns.is_native(native_1):
205+
assert_type(native_1, "pl.Series")
206+
if not ns.is_native(native_1):
207+
assert_type(native_1, "Never")
208+
209+
if ns.is_native(unrelated):
210+
# NOTE: We can't spell intersections *yet* (https://github.com/python/typing/issues/213)
211+
# Would be:
212+
# `<subclass of list[int] and DataFrame> | <subclass of list[int] and LazyFrame> | <subclass of list[int] and Series>``
213+
assert_type(unrelated, "Never") # pyright: ignore[reportAssertTypeFailure]
214+
else:
215+
assert_type(unrelated, "list[int]")
216+
217+
maybe_item = maybe_native[0]
218+
assert_type(maybe_item, "pl.Series | list[int]")
219+
if ns.is_native(maybe_item):
220+
assert_type(maybe_item, "pl.Series")
221+
else:
222+
assert_type(maybe_item, "list[int]")
223+
224+
if ns.is_native(native_2):
225+
assert_type(native_2, "pl.DataFrame")
226+
else:
227+
assert_type(native_2, "Never")
228+
229+
always_item = always_native[1]
230+
assert_type(always_item, "pl.DataFrame | pl.Series")
231+
if ns.is_native(always_item):
232+
assert_type(always_item, "pl.DataFrame | pl.Series")
233+
if ns._dataframe._is_native(always_item):
234+
assert_type(always_item, "pl.DataFrame")
235+
elif ns._series._is_native(always_item):
236+
assert_type(always_item, "pl.Series")
237+
else:
238+
assert_type(always_item, "Never")
239+
else:
240+
assert_type(always_item, "Never")

0 commit comments

Comments
 (0)