Skip to content

Commit a0e6b9e

Browse files
authored
fix: Include all backends in get_native_namespace (#2608)
* fix: Include all backends in get_native_namespace * test typing and pyarrow_table specification * no cover cudf and None return in test
1 parent 8e65c33 commit a0e6b9e

File tree

2 files changed

+69
-50
lines changed

2 files changed

+69
-50
lines changed

narwhals/translate.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,18 @@
1212
is_native_spark_like,
1313
)
1414
from narwhals.dependencies import (
15-
get_cudf,
1615
get_dask,
1716
get_dask_expr,
18-
get_modin,
1917
get_numpy,
2018
get_pandas,
21-
get_polars,
22-
get_pyarrow,
23-
is_cudf_dataframe,
24-
is_cudf_series,
2519
is_cupy_scalar,
2620
is_dask_dataframe,
2721
is_duckdb_relation,
2822
is_ibis_table,
29-
is_modin_dataframe,
30-
is_modin_series,
3123
is_numpy_scalar,
32-
is_pandas_dataframe,
3324
is_pandas_like_dataframe,
34-
is_pandas_series,
35-
is_polars_dataframe,
3625
is_polars_lazyframe,
3726
is_polars_series,
38-
is_pyarrow_chunked_array,
3927
is_pyarrow_scalar,
4028
is_pyarrow_table,
4129
)
@@ -602,25 +590,20 @@ def get_native_namespace(
602590
return result.pop()
603591

604592

605-
def _get_native_namespace_single_obj( # noqa: PLR0911
593+
def _get_native_namespace_single_obj(
606594
obj: DataFrame[Any] | LazyFrame[Any] | Series[Any] | IntoFrame | IntoSeries,
607595
) -> Any:
596+
from contextlib import suppress
597+
608598
from narwhals.utils import has_native_namespace
609599

600+
with suppress(TypeError, AssertionError):
601+
return Version.MAIN.namespace.from_native_object(
602+
obj
603+
).implementation.to_native_namespace()
604+
610605
if has_native_namespace(obj):
611606
return obj.__native_namespace__()
612-
if is_pandas_dataframe(obj) or is_pandas_series(obj):
613-
return get_pandas()
614-
if is_modin_dataframe(obj) or is_modin_series(obj): # pragma: no cover
615-
return get_modin()
616-
if is_pyarrow_table(obj) or is_pyarrow_chunked_array(obj):
617-
return get_pyarrow()
618-
if is_cudf_dataframe(obj) or is_cudf_series(obj): # pragma: no cover
619-
return get_cudf()
620-
if is_dask_dataframe(obj): # pragma: no cover
621-
return get_dask()
622-
if is_polars_dataframe(obj) or is_polars_lazyframe(obj) or is_polars_series(obj):
623-
return get_polars()
624607
msg = f"Could not get native namespace from object of type: {type(obj)}"
625608
raise TypeError(msg)
626609

tests/translate/get_native_namespace_test.py

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,82 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Any
44

55
import pytest
66

77
import narwhals as nw
88

99
if TYPE_CHECKING:
1010
from narwhals.typing import Frame
11+
from tests.utils import Constructor
1112

1213

13-
def test_native_namespace_polars() -> None:
14-
pytest.importorskip("polars")
15-
import polars as pl
14+
data = {"a": [1, 2, 3]}
1615

17-
df: Frame = nw.from_native(pl.DataFrame({"a": [1, 2, 3]}))
18-
assert nw.get_native_namespace(df) is pl
19-
assert nw.get_native_namespace(df.to_native()) is pl
20-
assert nw.get_native_namespace(df.lazy().to_native()) is pl
21-
assert nw.get_native_namespace(df["a"].to_native()) is pl
22-
assert nw.get_native_namespace(df, df["a"].to_native()) is pl
2316

17+
def _get_expected_namespace(constructor_name: str) -> Any | None: # noqa: PLR0911
18+
"""Get expected namespace module for a given constructor."""
19+
if "pandas" in constructor_name:
20+
import pandas as pd
21+
22+
return pd
23+
elif "polars" in constructor_name:
24+
import polars as pl
25+
26+
return pl
27+
elif "pyarrow_table" in constructor_name:
28+
import pyarrow as pa
29+
30+
return pa
31+
elif "duckdb" in constructor_name:
32+
import duckdb
33+
34+
return duckdb
35+
elif "cudf" in constructor_name: # pragma: no cover
36+
import cudf
37+
38+
return cudf
39+
elif "modin" in constructor_name:
40+
import modin.pandas as mpd
41+
42+
return mpd
43+
elif "dask" in constructor_name:
44+
import dask.dataframe as dd
45+
46+
return dd
47+
elif "ibis" in constructor_name:
48+
import ibis
49+
50+
return ibis
51+
elif "sqlframe" in constructor_name:
52+
import sqlframe
53+
54+
return sqlframe
55+
return None # pragma: no cover
56+
57+
58+
def test_native_namespace_frame(constructor: Constructor) -> None:
59+
constructor_name = constructor.__name__
60+
if constructor_name == "pyspark_lazy_constructor":
61+
pytest.skip(reason="Requires special handling for spark local vs spark connect")
62+
63+
expected_namespace = _get_expected_namespace(constructor_name=constructor_name)
64+
65+
df: Frame = nw.from_native(constructor(data))
66+
assert nw.get_native_namespace(df) is expected_namespace
67+
assert nw.get_native_namespace(df.to_native()) is expected_namespace
68+
assert nw.get_native_namespace(df.lazy().to_native()) is expected_namespace
2469

25-
def test_native_namespace_pandas() -> None:
26-
pytest.importorskip("pandas")
27-
import pandas as pd
2870

29-
df: Frame = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}), eager_only=True)
30-
assert nw.get_native_namespace(df) is pd
31-
assert nw.get_native_namespace(df.to_native()) is pd
32-
assert nw.get_native_namespace(df["a"].to_native()) is pd
33-
assert nw.get_native_namespace(df, df["a"].to_native()) is pd
71+
def test_native_namespace_series(constructor_eager: Constructor) -> None:
72+
constructor_name = constructor_eager.__name__
3473

74+
expected_namespace = _get_expected_namespace(constructor_name=constructor_name)
3575

36-
def test_native_namespace_pyarrow() -> None:
37-
pytest.importorskip("pyarrow")
38-
import pyarrow as pa
76+
df: Frame = nw.from_native(constructor_eager(data), eager_only=True)
3977

40-
df: Frame = nw.from_native(pa.table({"a": [1, 2, 3]}), eager_only=True)
41-
assert nw.get_native_namespace(df) is pa
42-
assert nw.get_native_namespace(df.to_native()) is pa
43-
assert nw.get_native_namespace(df, df["a"].to_native()) is pa
78+
assert nw.get_native_namespace(df["a"].to_native()) is expected_namespace
79+
assert nw.get_native_namespace(df, df["a"].to_native()) is expected_namespace
4480

4581

4682
def test_get_native_namespace_invalid() -> None:

0 commit comments

Comments
 (0)