Skip to content

Commit 8ce8b5d

Browse files
authored
feat: support passing multiple arguments positionally to get_native_namespace (#2178)
1 parent ea213c4 commit 8ce8b5d

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

narwhals/translate.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@
4040
from narwhals.utils import Version
4141

4242
if TYPE_CHECKING:
43-
import pandas as pd
44-
import polars as pl
45-
import pyarrow as pa
46-
47-
from narwhals._arrow.typing import ArrowChunkedArray
4843
from narwhals.dataframe import DataFrame
4944
from narwhals.dataframe import LazyFrame
5045
from narwhals.series import Series
@@ -790,21 +785,14 @@ def _from_native_impl( # noqa: PLR0915
790785

791786

792787
def get_native_namespace(
793-
obj: DataFrame[Any]
794-
| LazyFrame[Any]
795-
| Series[Any]
796-
| pd.DataFrame
797-
| pd.Series[Any]
798-
| pl.DataFrame
799-
| pl.LazyFrame
800-
| pl.Series
801-
| pa.Table
802-
| ArrowChunkedArray,
788+
*obj: DataFrame[Any] | LazyFrame[Any] | Series[Any] | IntoFrame | IntoSeries,
803789
) -> Any:
804790
"""Get native namespace from object.
805791
806792
Arguments:
807-
obj: Dataframe, Lazyframe, or Series.
793+
obj: Dataframe, Lazyframe, or Series. Multiple objects can be
794+
passed positionally, in which case they must all have the
795+
same native namespace (else an error is raised).
808796
809797
Returns:
810798
Native module.
@@ -820,6 +808,19 @@ def get_native_namespace(
820808
>>> nw.get_native_namespace(df)
821809
<module 'polars'...>
822810
"""
811+
if not obj:
812+
msg = "At least one object must be passed to `get_native_namespace`."
813+
raise ValueError(msg)
814+
result = {_get_native_namespace_single_obj(x) for x in obj}
815+
if len(result) != 1:
816+
msg = f"Found objects with different native namespaces: {result}."
817+
raise ValueError(msg)
818+
return result.pop()
819+
820+
821+
def _get_native_namespace_single_obj(
822+
obj: DataFrame[Any] | LazyFrame[Any] | Series[Any] | IntoFrame | IntoSeries,
823+
) -> Any:
823824
from narwhals.utils import has_native_namespace
824825

825826
if has_native_namespace(obj):

tests/translate/get_native_namespace_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,22 @@ def test_native_namespace() -> None:
1919
assert nw.get_native_namespace(df.to_native()) is pl
2020
assert nw.get_native_namespace(df.lazy().to_native()) is pl
2121
assert nw.get_native_namespace(df["a"].to_native()) is pl
22+
assert nw.get_native_namespace(df, df["a"].to_native()) is pl
2223
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}), eager_only=True)
2324
assert nw.get_native_namespace(df) is pd
2425
assert nw.get_native_namespace(df.to_native()) is pd
2526
assert nw.get_native_namespace(df["a"].to_native()) is pd
27+
assert nw.get_native_namespace(df, df["a"].to_native()) is pd
2628
df = nw.from_native(pa.table({"a": [1, 2, 3]}), eager_only=True)
2729
assert nw.get_native_namespace(df) is pa
2830
assert nw.get_native_namespace(df.to_native()) is pa
29-
assert nw.get_native_namespace(df["a"].to_native()) is pa
31+
assert nw.get_native_namespace(df, df["a"].to_native()) is pa
3032

3133

3234
def test_get_native_namespace_invalid() -> None:
3335
with pytest.raises(TypeError, match="Could not get native namespace"):
3436
nw.get_native_namespace(1) # type: ignore[arg-type]
37+
with pytest.raises(ValueError, match="At least one object"):
38+
nw.get_native_namespace()
39+
with pytest.raises(ValueError, match="Found objects with different"):
40+
nw.get_native_namespace(pd.Series([1]), pl.Series([2]))

0 commit comments

Comments
 (0)