Skip to content

Commit 5838dde

Browse files
authored
fix: correctly preserve arrow dtypes for pandas-like, improve concat_str performance (#3193)
1 parent e3fd995 commit 5838dde

File tree

5 files changed

+110
-67
lines changed

5 files changed

+110
-67
lines changed

narwhals/_pandas_like/namespace.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,8 @@ def concat_str(
332332

333333
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
334334
expr_results = [s for _expr in exprs for s in _expr(df)]
335-
align = self._series._align_full_broadcast
336-
series = align(*(s.cast(string) for s in expr_results))
337-
null_mask = align(*(s.is_null() for s in expr_results))
335+
series = [s.cast(string) for s in expr_results]
336+
null_mask = [s.is_null() for s in expr_results]
338337

339338
if not ignore_nulls:
340339
null_mask_result = reduce(operator.or_, null_mask)
@@ -345,15 +344,16 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
345344
# NOTE: Trying to help `mypy` later
346345
# error: Cannot determine type of "values" [has-type]
347346
values: list[PandasLikeSeries]
348-
init_value, *values = [
347+
init_value, *values = (
349348
s.zip_with(~nm, "") for s, nm in zip_strict(series, null_mask)
350-
]
351-
array_funcs = series[0]._array_funcs
352-
sep_array = init_value.from_iterable(
353-
data=array_funcs.repeat(separator, len(init_value)),
354-
name="sep",
355-
index=init_value.native.index,
356-
context=self,
349+
)
350+
sep_array = init_value._with_native(
351+
init_value.__native_namespace__().Series(
352+
separator,
353+
name="sep",
354+
index=init_value.native.index,
355+
dtype=init_value.native.dtype,
356+
)
357357
)
358358
separators = (sep_array.zip_with(~nm, "") for nm in null_mask[:-1])
359359
result = reduce(

narwhals/_pandas_like/series.py

Lines changed: 41 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import operator
34
import warnings
4-
from typing import TYPE_CHECKING, Any, Literal, cast
5+
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
56

67
import numpy as np
78

@@ -304,6 +305,11 @@ def _scatter_in_place(self, indices: Self, values: Self) -> None:
304305
self.native.iloc[indices.native] = values_native
305306

306307
def cast(self, dtype: IntoDType) -> Self:
308+
if self.dtype == dtype and self.native.dtype != "object":
309+
# Avoid dealing with pandas' type-system if we can. Note that it's only
310+
# safe to do this if we're not starting with object dtype, see tests/expr_and_series/cast_test.py::test_cast_object_pandas
311+
# for an example of why.
312+
return self._with_native(self.native, preserve_broadcast=True)
307313
pd_dtype = narwhals_to_native_dtype(
308314
dtype,
309315
dtype_backend=get_dtype_backend(self.native.dtype, self._implementation),
@@ -387,103 +393,87 @@ def first(self) -> PythonLiteral:
387393
def last(self) -> PythonLiteral:
388394
return self.native.iloc[-1] if len(self.native) else None
389395

396+
def _with_binary(self, op: Callable[..., PandasLikeSeries], other: Any) -> Self:
397+
ser, other_native = align_and_extract_native(self, other)
398+
preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True)
399+
return self._with_native(
400+
op(ser, other_native), preserve_broadcast=preserve_broadcast
401+
).alias(self.name)
402+
403+
def _with_binary_right(self, op: Callable[..., PandasLikeSeries], other: Any) -> Self:
404+
return self._with_binary(lambda x, y: op(y, x), other).alias(self.name)
405+
390406
def __eq__(self, other: object) -> Self: # type: ignore[override]
391-
ser, other = align_and_extract_native(self, other)
392-
return self._with_native(ser == other).alias(self.name)
407+
return self._with_binary(operator.eq, other)
393408

394409
def __ne__(self, other: object) -> Self: # type: ignore[override]
395-
ser, other = align_and_extract_native(self, other)
396-
return self._with_native(ser != other).alias(self.name)
410+
return self._with_binary(operator.ne, other)
397411

398412
def __ge__(self, other: Any) -> Self:
399-
ser, other = align_and_extract_native(self, other)
400-
return self._with_native(ser >= other).alias(self.name)
413+
return self._with_binary(operator.ge, other)
401414

402415
def __gt__(self, other: Any) -> Self:
403-
ser, other = align_and_extract_native(self, other)
404-
return self._with_native(ser > other).alias(self.name)
416+
return self._with_binary(operator.gt, other)
405417

406418
def __le__(self, other: Any) -> Self:
407-
ser, other = align_and_extract_native(self, other)
408-
return self._with_native(ser <= other).alias(self.name)
419+
return self._with_binary(operator.le, other)
409420

410421
def __lt__(self, other: Any) -> Self:
411-
ser, other = align_and_extract_native(self, other)
412-
return self._with_native(ser < other).alias(self.name)
422+
return self._with_binary(operator.lt, other)
413423

414424
def __and__(self, other: Any) -> Self:
415-
ser, other = align_and_extract_native(self, other)
416-
return self._with_native(ser & other).alias(self.name)
425+
return self._with_binary(operator.and_, other)
417426

418427
def __rand__(self, other: Any) -> Self:
419-
ser, other = align_and_extract_native(self, other)
420-
ser = cast("pd.Series[Any]", ser)
421-
return self._with_native(ser.__and__(other)).alias(self.name)
428+
return self._with_binary_right(operator.and_, other)
422429

423430
def __or__(self, other: Any) -> Self:
424-
ser, other = align_and_extract_native(self, other)
425-
return self._with_native(ser | other).alias(self.name)
431+
return self._with_binary(operator.or_, other)
426432

427433
def __ror__(self, other: Any) -> Self:
428-
ser, other = align_and_extract_native(self, other)
429-
ser = cast("pd.Series[Any]", ser)
430-
return self._with_native(ser.__or__(other)).alias(self.name)
434+
return self._with_binary_right(operator.or_, other)
431435

432436
def __add__(self, other: Any) -> Self:
433-
ser, other = align_and_extract_native(self, other)
434-
return self._with_native(ser + other).alias(self.name)
437+
return self._with_binary(operator.add, other)
435438

436439
def __radd__(self, other: Any) -> Self:
437-
_, other_native = align_and_extract_native(self, other)
438-
return self._with_native(self.native.__radd__(other_native)).alias(self.name)
440+
return self._with_binary_right(operator.add, other)
439441

440442
def __sub__(self, other: Any) -> Self:
441-
ser, other = align_and_extract_native(self, other)
442-
return self._with_native(ser - other).alias(self.name)
443+
return self._with_binary(operator.sub, other)
443444

444445
def __rsub__(self, other: Any) -> Self:
445-
_, other_native = align_and_extract_native(self, other)
446-
return self._with_native(self.native.__rsub__(other_native)).alias(self.name)
446+
return self._with_binary_right(operator.sub, other)
447447

448448
def __mul__(self, other: Any) -> Self:
449-
ser, other = align_and_extract_native(self, other)
450-
return self._with_native(ser * other).alias(self.name)
449+
return self._with_binary(operator.mul, other)
451450

452451
def __rmul__(self, other: Any) -> Self:
453-
_, other_native = align_and_extract_native(self, other)
454-
return self._with_native(self.native.__rmul__(other_native)).alias(self.name)
452+
return self._with_binary_right(operator.mul, other)
455453

456454
def __truediv__(self, other: Any) -> Self:
457-
ser, other = align_and_extract_native(self, other)
458-
return self._with_native(ser / other).alias(self.name)
455+
return self._with_binary(operator.truediv, other)
459456

460457
def __rtruediv__(self, other: Any) -> Self:
461-
_, other_native = align_and_extract_native(self, other)
462-
return self._with_native(self.native.__rtruediv__(other_native)).alias(self.name)
458+
return self._with_binary_right(operator.truediv, other)
463459

464460
def __floordiv__(self, other: Any) -> Self:
465-
ser, other = align_and_extract_native(self, other)
466-
return self._with_native(ser // other).alias(self.name)
461+
return self._with_binary(operator.floordiv, other)
467462

468463
def __rfloordiv__(self, other: Any) -> Self:
469-
_, other_native = align_and_extract_native(self, other)
470-
return self._with_native(self.native.__rfloordiv__(other_native)).alias(self.name)
464+
return self._with_binary_right(operator.floordiv, other)
471465

472466
def __pow__(self, other: Any) -> Self:
473-
ser, other = align_and_extract_native(self, other)
474-
return self._with_native(ser**other).alias(self.name)
467+
return self._with_binary(operator.pow, other)
475468

476469
def __rpow__(self, other: Any) -> Self:
477-
_, other_native = align_and_extract_native(self, other)
478-
return self._with_native(self.native.__rpow__(other_native)).alias(self.name)
470+
return self._with_binary_right(operator.pow, other)
479471

480472
def __mod__(self, other: Any) -> Self:
481-
ser, other = align_and_extract_native(self, other)
482-
return self._with_native(ser % other).alias(self.name)
473+
return self._with_binary(operator.mod, other)
483474

484475
def __rmod__(self, other: Any) -> Self:
485-
_, other_native = align_and_extract_native(self, other)
486-
return self._with_native(self.native.__rmod__(other_native)).alias(self.name)
476+
return self._with_binary_right(operator.mod, other)
487477

488478
# Unary
489479

narwhals/_pandas_like/utils.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,6 @@ def is_dtype_pyarrow(dtype: Any) -> TypeIs[pd.ArrowDtype]:
469469
None: "uint16",
470470
},
471471
dtypes.UInt8: {"pyarrow": "UInt8[pyarrow]", "numpy_nullable": "UInt8", None: "uint8"},
472-
dtypes.String: {"pyarrow": "string[pyarrow]", "numpy_nullable": "string", None: str},
473472
dtypes.Boolean: {
474473
"pyarrow": "boolean[pyarrow]",
475474
"numpy_nullable": "boolean",
@@ -494,6 +493,25 @@ def narwhals_to_native_dtype( # noqa: C901, PLR0912
494493
return pd_type
495494
if into_pd_type := NW_TO_PD_DTYPES_BACKEND.get(base_type):
496495
return into_pd_type[dtype_backend]
496+
if issubclass(base_type, dtypes.String):
497+
if dtype_backend == "pyarrow":
498+
import pyarrow as pa # ignore-banned-import
499+
500+
# Note: this is different from `string[pyarrow]`, even though the repr
501+
# looks the same.
502+
# >>> pd.DataFrame({'a':['foo']}, dtype='string[pyarrow]')['a'].str.len()
503+
# 0 3
504+
# Name: a, dtype: Int64
505+
# >>> pd.DataFrame({'a':['foo']}, dtype=pd.ArrowDtype(pa.string()))['a'].str.len()
506+
# 0 3
507+
# Name: a, dtype: int32[pyarrow]
508+
#
509+
# `ArrowDType(pa.string())` is what `.convert_dtypes(dtype_backend='pyarrow')` converts to,
510+
# so we use that here.
511+
return pd.ArrowDtype(pa.string())
512+
if dtype_backend == "numpy_nullable":
513+
return "string"
514+
return str
497515
if isinstance_or_issubclass(dtype, dtypes.Datetime):
498516
if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
499517
2,
@@ -533,7 +551,7 @@ def narwhals_to_native_dtype( # noqa: C901, PLR0912
533551
)
534552
if isinstance_or_issubclass(dtype, dtypes.Date):
535553
try:
536-
import pyarrow as pa # ignore-banned-import # noqa: F401
554+
import pyarrow as pa # ignore-banned-import
537555
except ModuleNotFoundError as exc: # pragma: no cover
538556
# BUG: Never re-raised?
539557
msg = "'pyarrow>=13.0.0' is required for `Date` dtype."

tests/expr_and_series/cast_test.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Constructor,
1414
ConstructorEager,
1515
assert_equal_data,
16-
is_windows,
16+
is_pyarrow_windows_no_tzdata,
1717
time_unit_compat,
1818
)
1919

@@ -209,11 +209,13 @@ def test_cast_datetime_tz_aware(
209209
"dask" in str(constructor)
210210
or "duckdb" in str(constructor)
211211
or "cudf" in str(constructor) # https://github.com/rapidsai/cudf/issues/16973
212-
or ("pyarrow_table" in str(constructor) and is_windows())
213212
or "pyspark" in str(constructor)
214213
or "ibis" in str(constructor)
215214
):
216215
request.applymarker(pytest.mark.xfail)
216+
request.applymarker(
217+
pytest.mark.xfail(is_pyarrow_windows_no_tzdata(constructor), reason="no tzdata")
218+
)
217219

218220
data = {
219221
"date": [
@@ -239,9 +241,11 @@ def test_cast_datetime_utc(
239241
"dask" in str(constructor)
240242
# https://github.com/eakmanrq/sqlframe/issues/406
241243
or "sqlframe" in str(constructor)
242-
or ("pyarrow_table" in str(constructor) and is_windows())
243244
):
244245
request.applymarker(pytest.mark.xfail)
246+
request.applymarker(
247+
pytest.mark.xfail(is_pyarrow_windows_no_tzdata(constructor), reason="no tzdata")
248+
)
245249

246250
data = {
247251
"date": [
@@ -394,3 +398,29 @@ def test_cast_typing_invalid() -> None:
394398

395399
with pytest.raises(AttributeError):
396400
df.select(a.cast(nw.Array(nw.List, 2))) # type: ignore[arg-type]
401+
402+
403+
@pytest.mark.skipif(PANDAS_VERSION < (2,), reason="too old for pyarrow")
404+
def test_pandas_pyarrow_dtypes() -> None:
405+
s = nw.from_native(
406+
pd.Series([123, None]).convert_dtypes(dtype_backend="pyarrow"), series_only=True
407+
).cast(nw.String)
408+
result = s.str.len_chars().to_native()
409+
assert result.dtype == "Int32[pyarrow]"
410+
411+
s = nw.from_native(
412+
pd.Series([123, None], dtype="string[pyarrow]"), series_only=True
413+
).cast(nw.String)
414+
result = s.str.len_chars().to_native()
415+
assert result.dtype == "Int64"
416+
417+
s = nw.from_native(
418+
pd.DataFrame({"a": ["foo", "bar"]}, dtype="string[pyarrow]")
419+
).select(nw.col("a").cast(nw.String))["a"]
420+
assert s.to_native().dtype == "string[pyarrow]"
421+
422+
423+
def test_cast_object_pandas() -> None:
424+
s = nw.from_native(pd.DataFrame({"a": [2, 3, None]}, dtype=object))["a"]
425+
assert s[0] == 2
426+
assert s.cast(nw.String)[0] == "2"

tests/frame/schema_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,11 @@ def test_all_nulls_pandas() -> None:
387387
def test_schema_to_pandas(
388388
dtype_backend: DTypeBackend | Sequence[DTypeBackend] | None, expected: dict[str, Any]
389389
) -> None:
390+
if (
391+
dtype_backend == "pyarrow"
392+
or (isinstance(dtype_backend, list) and "pyarrow" in dtype_backend)
393+
) and PANDAS_VERSION < (1, 5):
394+
pytest.skip()
390395
schema = nw.Schema(
391396
{
392397
"a": nw.Int64(),

0 commit comments

Comments
 (0)