|
12 | 12 | from narwhals._utils import Version |
13 | 13 |
|
14 | 14 | if TYPE_CHECKING: |
15 | | - from typing_extensions import TypeAlias, assert_type |
| 15 | + from typing_extensions import Never, TypeAlias, assert_type # noqa: F401 |
16 | 16 |
|
17 | 17 | from narwhals._arrow.namespace import ArrowNamespace # noqa: F401 |
18 | 18 | from narwhals._compliant import CompliantNamespace |
@@ -177,3 +177,64 @@ class NamespaceNoVersion(Namespace): ... # type: ignore[call-arg, type-arg] |
177 | 177 | with pytest.raises(TypeError, match=re.compile(r"Expected.+Version.+but got.+str")): |
178 | 178 |
|
179 | 179 | class NamespaceBadVersion(Namespace, version="invalid version"): ... # type: ignore[arg-type, type-arg] |
| 180 | + |
| 181 | + |
| 182 | +def test_namespace_is_native() -> None: |
| 183 | + pytest.importorskip("polars") |
| 184 | + import polars as pl |
| 185 | + |
| 186 | + unrelated: list[int] = [1, 2, 3] |
| 187 | + native_1 = pl.Series(unrelated) |
| 188 | + native_2 = pl.DataFrame({"a": unrelated}) |
| 189 | + |
| 190 | + maybe_native: list[pl.Series | list[int]] = [native_1, unrelated] |
| 191 | + always_native = list["pl.DataFrame | pl.Series"]((native_2, native_1)) |
| 192 | + never_native = [unrelated, 50] |
| 193 | + |
| 194 | + expected_maybe = [True, False] |
| 195 | + expected_always = [True, True] |
| 196 | + expected_never = [False, False] |
| 197 | + |
| 198 | + ns = Namespace.from_backend("polars").compliant |
| 199 | + assert [ns.is_native(el) for el in maybe_native] == expected_maybe |
| 200 | + assert [ns.is_native(el) for el in always_native] == expected_always |
| 201 | + assert [ns.is_native(el) for el in never_native] == expected_never |
| 202 | + |
| 203 | + if TYPE_CHECKING: |
| 204 | + if ns.is_native(native_1): |
| 205 | + assert_type(native_1, "pl.Series") |
| 206 | + if not ns.is_native(native_1): |
| 207 | + assert_type(native_1, "Never") |
| 208 | + |
| 209 | + if ns.is_native(unrelated): |
| 210 | + # NOTE: We can't spell intersections *yet* (https://github.com/python/typing/issues/213) |
| 211 | + # Would be: |
| 212 | + # `<subclass of list[int] and DataFrame> | <subclass of list[int] and LazyFrame> | <subclass of list[int] and Series>`` |
| 213 | + assert_type(unrelated, "Never") # pyright: ignore[reportAssertTypeFailure] |
| 214 | + else: |
| 215 | + assert_type(unrelated, "list[int]") |
| 216 | + |
| 217 | + maybe_item = maybe_native[0] |
| 218 | + assert_type(maybe_item, "pl.Series | list[int]") |
| 219 | + if ns.is_native(maybe_item): |
| 220 | + assert_type(maybe_item, "pl.Series") |
| 221 | + else: |
| 222 | + assert_type(maybe_item, "list[int]") |
| 223 | + |
| 224 | + if ns.is_native(native_2): |
| 225 | + assert_type(native_2, "pl.DataFrame") |
| 226 | + else: |
| 227 | + assert_type(native_2, "Never") |
| 228 | + |
| 229 | + always_item = always_native[1] |
| 230 | + assert_type(always_item, "pl.DataFrame | pl.Series") |
| 231 | + if ns.is_native(always_item): |
| 232 | + assert_type(always_item, "pl.DataFrame | pl.Series") |
| 233 | + if ns._dataframe._is_native(always_item): |
| 234 | + assert_type(always_item, "pl.DataFrame") |
| 235 | + elif ns._series._is_native(always_item): |
| 236 | + assert_type(always_item, "pl.Series") |
| 237 | + else: |
| 238 | + assert_type(always_item, "Never") |
| 239 | + else: |
| 240 | + assert_type(always_item, "Never") |
0 commit comments