Skip to content

Commit ccc0689

Browse files
committed
POC: consistent NaN treatment for pyarrow dtypes
1 parent 8e86751 commit ccc0689

File tree

7 files changed

+81
-19
lines changed

7 files changed

+81
-19
lines changed

pandas/_libs/parsers.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,7 @@ def _maybe_upcast(
14611461
if isinstance(arr, IntegerArray) and arr.isna().all():
14621462
# use null instead of int64 in pyarrow
14631463
arr = arr.to_numpy(na_value=None)
1464-
arr = ArrowExtensionArray(pa.array(arr, from_pandas=True))
1464+
arr = ArrowExtensionArray(pa.array(arr))
14651465

14661466
return arr
14671467

pandas/core/arrays/arrow/array.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818

1919
from pandas._libs import lib
20+
from pandas._libs.missing import NA
2021
from pandas._libs.tslibs import (
2122
Timedelta,
2223
Timestamp,
@@ -353,7 +354,7 @@ def _from_sequence_of_strings(
353354
# duration to string casting behavior
354355
mask = isna(scalars)
355356
if not isinstance(strings, (pa.Array, pa.ChunkedArray)):
356-
strings = pa.array(strings, type=pa.string(), from_pandas=True)
357+
strings = pa.array(strings, type=pa.string())
357358
strings = pc.if_else(mask, None, strings)
358359
try:
359360
scalars = strings.cast(pa.int64())
@@ -374,7 +375,7 @@ def _from_sequence_of_strings(
374375
if isinstance(strings, (pa.Array, pa.ChunkedArray)):
375376
scalars = strings
376377
else:
377-
scalars = pa.array(strings, type=pa.string(), from_pandas=True)
378+
scalars = pa.array(strings, type=pa.string())
378379
scalars = pc.if_else(pc.equal(scalars, "1.0"), "1", scalars)
379380
scalars = pc.if_else(pc.equal(scalars, "0.0"), "0", scalars)
380381
scalars = scalars.cast(pa.bool_())
@@ -386,6 +387,13 @@ def _from_sequence_of_strings(
386387
from pandas.core.tools.numeric import to_numeric
387388

388389
scalars = to_numeric(strings, errors="raise")
390+
if not pa.types.is_decimal(pa_type):
391+
# TODO: figure out why doing this cast breaks with decimal dtype
392+
# in test_from_sequence_of_strings_pa_array
393+
mask = strings.is_null()
394+
scalars = pa.array(scalars, mask=np.array(mask), type=pa_type)
395+
# TODO: could we just do strings.cast(pa_type)?
396+
389397
else:
390398
raise NotImplementedError(
391399
f"Converting strings to {pa_type} is not implemented."
@@ -428,7 +436,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
428436
"""
429437
if isinstance(value, pa.Scalar):
430438
pa_scalar = value
431-
elif isna(value):
439+
elif isna(value) and not lib.is_float(value):
432440
pa_scalar = pa.scalar(None, type=pa_type)
433441
else:
434442
# Workaround https://github.com/apache/arrow/issues/37291
@@ -445,7 +453,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
445453
value = value.as_unit(pa_type.unit)
446454
value = value._value
447455

448-
pa_scalar = pa.scalar(value, type=pa_type, from_pandas=True)
456+
pa_scalar = pa.scalar(value, type=pa_type)
449457

450458
if pa_type is not None and pa_scalar.type != pa_type:
451459
pa_scalar = pa_scalar.cast(pa_type)
@@ -477,6 +485,13 @@ def _box_pa_array(
477485
if copy:
478486
value = value.copy()
479487
pa_array = value.__arrow_array__()
488+
489+
elif hasattr(value, "__arrow_array__"):
490+
# e.g. StringArray
491+
if copy:
492+
value = value.copy()
493+
pa_array = value.__arrow_array__()
494+
480495
else:
481496
if (
482497
isinstance(value, np.ndarray)
@@ -530,19 +545,32 @@ def _box_pa_array(
530545
pa_array = pa.array(dta._ndarray, type=pa_type, mask=dta_mask)
531546
return pa_array
532547

548+
mask = None
549+
if getattr(value, "dtype", None) is None or value.dtype.kind not in "mfM":
550+
# similar to isna(value) but exclude NaN
551+
# TODO: cythonize!
552+
mask = np.array([x is NA or x is None for x in value], dtype=bool)
553+
554+
from_pandas = False
555+
if pa.types.is_integer(pa_type):
556+
# If user specifically asks to cast a numpy float array with NaNs
557+
# to pyarrow integer, we'll treat those NaNs as NA
558+
from_pandas = True
533559
try:
534-
pa_array = pa.array(value, type=pa_type, from_pandas=True)
560+
pa_array = pa.array(
561+
value, type=pa_type, mask=mask, from_pandas=from_pandas
562+
)
535563
except (pa.ArrowInvalid, pa.ArrowTypeError):
536564
# GH50430: let pyarrow infer type, then cast
537-
pa_array = pa.array(value, from_pandas=True)
565+
pa_array = pa.array(value, mask=mask, from_pandas=from_pandas)
538566

539567
if pa_type is None and pa.types.is_duration(pa_array.type):
540568
# Workaround https://github.com/apache/arrow/issues/37291
541569
from pandas.core.tools.timedeltas import to_timedelta
542570

543571
value = to_timedelta(value)
544572
value = value.to_numpy()
545-
pa_array = pa.array(value, type=pa_type, from_pandas=True)
573+
pa_array = pa.array(value, type=pa_type)
546574

547575
if pa.types.is_duration(pa_array.type) and pa_array.null_count > 0:
548576
# GH52843: upstream bug for duration types when originally
@@ -1208,7 +1236,7 @@ def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
12081236
if not len(values):
12091237
return np.zeros(len(self), dtype=bool)
12101238

1211-
result = pc.is_in(self._pa_array, value_set=pa.array(values, from_pandas=True))
1239+
result = pc.is_in(self._pa_array, value_set=pa.array(values))
12121240
# pyarrow 2.0.0 returned nulls, so we explicitly specify dtype to convert nulls
12131241
# to False
12141242
return np.array(result, dtype=np.bool_)
@@ -2015,7 +2043,7 @@ def __setitem__(self, key, value) -> None:
20152043
raise ValueError("Length of indexer and values mismatch")
20162044
chunks = [
20172045
*self._pa_array[:key].chunks,
2018-
pa.array([value], type=self._pa_array.type, from_pandas=True),
2046+
pa.array([value], type=self._pa_array.type),
20192047
*self._pa_array[key + 1 :].chunks,
20202048
]
20212049
data = pa.chunked_array(chunks).combine_chunks()
@@ -2069,7 +2097,7 @@ def _rank_calc(
20692097
pa_type = pa.float64()
20702098
else:
20712099
pa_type = pa.uint64()
2072-
result = pa.array(ranked, type=pa_type, from_pandas=True)
2100+
result = pa.array(ranked, type=pa_type)
20732101
return result
20742102

20752103
data = self._pa_array.combine_chunks()
@@ -2321,7 +2349,7 @@ def _to_numpy_and_type(value) -> tuple[np.ndarray, pa.DataType | None]:
23212349
right, right_type = _to_numpy_and_type(right)
23222350
pa_type = left_type or right_type
23232351
result = np.where(cond, left, right)
2324-
return pa.array(result, type=pa_type, from_pandas=True)
2352+
return pa.array(result, type=pa_type)
23252353

23262354
@classmethod
23272355
def _replace_with_mask(
@@ -2364,7 +2392,7 @@ def _replace_with_mask(
23642392
replacements = replacements.as_py()
23652393
result = np.array(values, dtype=object)
23662394
result[mask] = replacements
2367-
return pa.array(result, type=values.type, from_pandas=True)
2395+
return pa.array(result, type=values.type)
23682396

23692397
# ------------------------------------------------------------------
23702398
# GroupBy Methods
@@ -2443,7 +2471,7 @@ def _groupby_op(
24432471
return type(self)(pa_result)
24442472
else:
24452473
# DatetimeArray, TimedeltaArray
2446-
pa_result = pa.array(result, from_pandas=True)
2474+
pa_result = pa.array(result)
24472475
return type(self)(pa_result)
24482476

24492477
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:

pandas/core/arrays/string_.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,12 @@ def _str_map_str_or_object(
502502
if self.dtype.storage == "pyarrow":
503503
import pyarrow as pa
504504

505+
# TODO: shouldn't this already be caught by the passed mask?
506+
# it isn't in test_extract_expand_capture_groups_index
507+
# mask = mask | np.array(
508+
# [x is libmissing.NA for x in result], dtype=bool
509+
# )
510+
505511
result = pa.array(
506512
result, mask=mask, type=pa.large_string(), from_pandas=True
507513
)
@@ -754,7 +760,7 @@ def __arrow_array__(self, type=None):
754760

755761
values = self._ndarray.copy()
756762
values[self.isna()] = None
757-
return pa.array(values, type=type, from_pandas=True)
763+
return pa.array(values, type=type)
758764

759765
def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]: # type: ignore[override]
760766
arr = self._ndarray

pandas/core/generic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9921,7 +9921,7 @@ def where(
99219921
def where(
99229922
self,
99239923
cond,
9924-
other=np.nan,
9924+
other=lib.no_default,
99259925
*,
99269926
inplace: bool = False,
99279927
axis: Axis | None = None,
@@ -10079,6 +10079,23 @@ def where(
1007910079
stacklevel=2,
1008010080
)
1008110081

10082+
if other is lib.no_default:
10083+
if self.ndim == 1:
10084+
if isinstance(self.dtype, ExtensionDtype):
10085+
other = self.dtype.na_value
10086+
else:
10087+
other = np.nan
10088+
else:
10089+
if self._mgr.nblocks == 1 and isinstance(
10090+
self._mgr.blocks[0].values.dtype, ExtensionDtype
10091+
):
10092+
# FIXME: checking this is kludgy!
10093+
other = self._mgr.blocks[0].values.dtype.na_value
10094+
else:
10095+
# FIXME: the same problem we had with Series will now
10096+
# show up column-by-column!
10097+
other = np.nan
10098+
1008210099
other = common.apply_if_callable(other, self)
1008310100
return self._where(cond, other, inplace=inplace, axis=axis, level=level)
1008410101

pandas/tests/extension/test_arrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def test_EA_types(self, engine, data, dtype_backend, request):
721721
pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
722722
)
723723
df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
724-
csv_output = df.to_csv(index=False, na_rep=np.nan)
724+
csv_output = df.to_csv(index=False, na_rep=np.nan) # should be NA?
725725
if pa.types.is_binary(pa_dtype):
726726
csv_output = BytesIO(csv_output)
727727
else:

pandas/tests/groupby/test_reductions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,10 @@ def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
381381
df = DataFrame(
382382
{
383383
"a": [2, 1, 1, 2, 3, 3],
384-
"b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
385-
"c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
384+
# TODO: test that has mixed na_value and NaN either working for
385+
# float or raising for int?
386+
"b": [na_value, 3.0, na_value, 4.0, na_value, na_value],
387+
"c": [na_value, 3.0, na_value, 4.0, na_value, na_value],
386388
},
387389
dtype=any_real_nullable_dtype,
388390
)

pandas/tests/series/methods/test_rank.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,13 @@ def test_rank_tie_methods(self, ser, results, dtype, using_infer_string):
275275

276276
ser = ser if dtype is None else ser.astype(dtype)
277277
result = ser.rank(method=method)
278+
if dtype == "float64[pyarrow]":
279+
# the NaNs are not treated as NA
280+
exp = exp.copy()
281+
if method == "average":
282+
exp[np.isnan(ser)] = 9.5
283+
elif method == "dense":
284+
exp[np.isnan(ser)] = 6
278285
tm.assert_series_equal(result, Series(exp, dtype=expected_dtype(dtype, method)))
279286

280287
@pytest.mark.parametrize("na_option", ["top", "bottom", "keep"])
@@ -320,6 +327,8 @@ def test_rank_tie_methods_on_infs_nans(
320327
order = [ranks[1], ranks[0], ranks[2]]
321328
elif na_option == "bottom":
322329
order = [ranks[0], ranks[2], ranks[1]]
330+
elif dtype == "float64[pyarrow]":
331+
order = [ranks[0], [NA] * chunk, ranks[1]]
323332
else:
324333
order = [ranks[0], [np.nan] * chunk, ranks[1]]
325334
expected = order if ascending else order[::-1]

0 commit comments

Comments (0)