Skip to content

Commit 3d0ec2f

Browse files
fix: More aligning modin, pandas dtypes (#2958)
Co-authored-by: FBruzzesi <[email protected]>
1 parent 5bcdb6e commit 3d0ec2f

File tree

2 files changed

+39
-37
lines changed

2 files changed

+39
-37
lines changed

narwhals/_pandas_like/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,16 @@
129129
"ns": "nanosecond",
130130
}
131131

132+
PANDAS_VERSION = Implementation.PANDAS._backend_version()
133+
"""Static backend version for `pandas`.
134+
135+
Always available if we reached here, due to a module-level import.
136+
"""
137+
138+
139+
def is_pandas_or_modin(implementation: Implementation) -> bool:
140+
return implementation in {Implementation.PANDAS, Implementation.MODIN}
141+
132142

133143
def align_and_extract_native(
134144
lhs: PandasLikeSeries, rhs: PandasLikeSeries | object
@@ -499,10 +509,9 @@ def narwhals_to_native_dtype( # noqa: C901, PLR0912, PLR0915
499509
# or at least, convert_dtypes(dtype_backend='pyarrow') doesn't
500510
# convert to it?
501511
return "category"
502-
backend_version = implementation._backend_version()
503512
if isinstance_or_issubclass(dtype, dtypes.Datetime):
504513
# Pandas does not support "ms" or "us" time units before version 2.0
505-
if implementation is Implementation.PANDAS and backend_version < (
514+
if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
506515
2,
507516
): # pragma: no cover
508517
dt_time_unit = "ns"
@@ -515,7 +524,7 @@ def narwhals_to_native_dtype( # noqa: C901, PLR0912, PLR0915
515524
tz_part = f", {tz}" if (tz := dtype.time_zone) else ""
516525
return f"datetime64[{dt_time_unit}{tz_part}]"
517526
if isinstance_or_issubclass(dtype, dtypes.Duration):
518-
if implementation is Implementation.PANDAS and backend_version < (
527+
if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
519528
2,
520529
): # pragma: no cover
521530
du_time_unit = "ns"
@@ -545,10 +554,7 @@ def narwhals_to_native_dtype( # noqa: C901, PLR0912, PLR0915
545554
if isinstance_or_issubclass(
546555
dtype, (dtypes.Struct, dtypes.Array, dtypes.List, dtypes.Time, dtypes.Binary)
547556
):
548-
if implementation in {
549-
Implementation.PANDAS,
550-
Implementation.MODIN,
551-
} and Implementation.PANDAS._backend_version() >= (2, 2):
557+
if is_pandas_or_modin(implementation) and PANDAS_VERSION >= (2, 2):
552558
try:
553559
import pandas as pd
554560
import pyarrow as pa # ignore-banned-import # noqa: F401

tests/expr_and_series/cast_test.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,15 @@
5959
SPARK_LIKE_INCOMPATIBLE_COLUMNS = {"e", "f", "g", "h", "o", "p"}
6060
DUCKDB_INCOMPATIBLE_COLUMNS = {"o"}
6161
IBIS_INCOMPATIBLE_COLUMNS = {"o"}
62+
MODIN_XFAIL_COLUMNS = {"o", "k"}
6263

6364

6465
@pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
65-
def test_cast(constructor: Constructor, request: pytest.FixtureRequest) -> None:
66+
def test_cast(constructor: Constructor) -> None:
6667
if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
6768
15,
6869
): # pragma: no cover
6970
pytest.skip()
70-
if "modin_constructor" in str(constructor):
71-
# TODO(unassigned): in modin, we end up with `'<U0'` dtype
72-
request.applymarker(pytest.mark.xfail)
7371

7472
if "pyspark" in str(constructor):
7573
incompatible_columns = SPARK_LIKE_INCOMPATIBLE_COLUMNS # pragma: no cover
@@ -107,8 +105,18 @@ def test_cast(constructor: Constructor, request: pytest.FixtureRequest) -> None:
107105
}
108106
cast_map = {c: t for c, t in cast_map.items() if c not in incompatible_columns}
109107

110-
result = df.select(*[nw.col(col_).cast(dtype) for col_, dtype in cast_map.items()])
111-
assert dict(result.collect_schema()) == cast_map
108+
result = df.select(
109+
*[nw.col(col_).cast(dtype) for col_, dtype in cast_map.items()]
110+
).collect_schema()
111+
112+
for (key, ltype), rtype in zip(result.items(), cast_map.values()):
113+
if "modin_constructor" in str(constructor) and key in MODIN_XFAIL_COLUMNS:
114+
# TODO(unassigned): in modin we end up with `'<U0'` dtype
115+
# This block will act similarly to an xfail i.e. if we fix the issue, the
116+
# assert will fail
117+
assert ltype != rtype
118+
else:
119+
assert ltype == rtype, f"types differ for column {key}: {ltype}!={rtype}"
112120

113121

114122
def test_cast_series(
@@ -118,17 +126,14 @@ def test_cast_series(
118126
15,
119127
): # pragma: no cover
120128
request.applymarker(pytest.mark.xfail)
121-
if "modin_constructor" in str(constructor_eager):
122-
# TODO(unassigned): in modin, we end up with `'<U0'` dtype
123-
request.applymarker(pytest.mark.xfail)
129+
124130
df = (
125131
nw.from_native(constructor_eager(DATA))
126132
.select(nw.col(key).cast(value) for key, value in SCHEMA.items())
127133
.lazy()
128134
.collect()
129135
)
130-
131-
expected = {
136+
cast_map = {
132137
"a": nw.Int32,
133138
"b": nw.Int16,
134139
"c": nw.Int8,
@@ -146,25 +151,16 @@ def test_cast_series(
146151
"o": nw.String,
147152
"p": nw.Duration,
148153
}
149-
result = df.select(
150-
df["a"].cast(nw.Int32),
151-
df["b"].cast(nw.Int16),
152-
df["c"].cast(nw.Int8),
153-
df["d"].cast(nw.Int64),
154-
df["e"].cast(nw.UInt32),
155-
df["f"].cast(nw.UInt16),
156-
df["g"].cast(nw.UInt8),
157-
df["h"].cast(nw.UInt64),
158-
df["i"].cast(nw.Float32),
159-
df["j"].cast(nw.Float64),
160-
df["k"].cast(nw.String),
161-
df["l"].cast(nw.Datetime),
162-
df["m"].cast(nw.Int8),
163-
df["n"].cast(nw.Int8),
164-
df["o"].cast(nw.String),
165-
df["p"].cast(nw.Duration),
166-
)
167-
assert result.schema == expected
154+
result = df.select(df[col_].cast(dtype) for col_, dtype in cast_map.items()).schema
155+
156+
for (key, ltype), rtype in zip(result.items(), cast_map.values()):
157+
if "modin_constructor" in str(constructor_eager) and key in MODIN_XFAIL_COLUMNS:
158+
# TODO(unassigned): in modin we end up with `'<U0'` dtype
159+
# This block will act similarly to an xfail i.e. if we fix the issue, the
160+
# assert will fail
161+
assert ltype != rtype
162+
else:
163+
assert ltype == rtype, f"types differ for column {key}: {ltype}!={rtype}"
168164

169165

170166
def test_cast_string() -> None:

0 commit comments

Comments
 (0)