Skip to content

Commit 2467f6e

Browse files
committed
Patch ops
1 parent c0bdd67 commit 2467f6e

File tree

4 files changed

+70
-10
lines changed

4 files changed

+70
-10
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,16 @@
5050
pandas_dtype,
5151
)
5252
from pandas.core.dtypes.dtypes import DatetimeTZDtype
53+
from pandas.core.dtypes.generic import (
54+
ABCDataFrame,
55+
ABCIndex,
56+
ABCSeries,
57+
)
5358
from pandas.core.dtypes.missing import isna
5459

5560
from pandas.core import (
5661
algorithms as algos,
62+
arraylike,
5763
missing,
5864
ops,
5965
roperator,
@@ -752,6 +758,39 @@ def __array__(
752758

753759
return self.to_numpy(dtype=dtype, copy=copy)
754760

761+
def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
762+
if any(
763+
isinstance(other, (ABCSeries, ABCIndex, ABCDataFrame)) for other in inputs
764+
):
765+
return NotImplemented
766+
767+
result = arraylike.maybe_dispatch_ufunc_to_dunder_op(
768+
self, ufunc, method, *inputs, **kwargs
769+
)
770+
if result is not NotImplemented:
771+
return result
772+
773+
if "out" in kwargs:
774+
return arraylike.dispatch_ufunc_with_out(
775+
self, ufunc, method, *inputs, **kwargs
776+
)
777+
778+
if method == "reduce":
779+
result = arraylike.dispatch_reduction_ufunc(
780+
self, ufunc, method, *inputs, **kwargs
781+
)
782+
if result is not NotImplemented:
783+
return result
784+
785+
if self.dtype.kind == "f":
786+
# e.g. test_log_arrow_backed_missing_value
787+
new_inputs = [
788+
x if x is not self else x.to_numpy(na_value=np.nan) for x in inputs
789+
]
790+
return getattr(ufunc, method)(*new_inputs, **kwargs)
791+
792+
return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)
793+
755794
def __invert__(self) -> Self:
756795
# This is a bit wise op for integer types
757796
if pa.types.is_integer(self._pa_array.type):
@@ -923,7 +962,13 @@ def _logical_method(self, other, op) -> Self:
923962
return self._evaluate_op_method(other, op, ARROW_LOGICAL_FUNCS)
924963

925964
def _arith_method(self, other, op) -> Self:
926-
return self._evaluate_op_method(other, op, ARROW_ARITHMETIC_FUNCS)
965+
result = self._evaluate_op_method(other, op, ARROW_ARITHMETIC_FUNCS)
966+
if is_nan_na() and result.dtype.kind == "f":
967+
parr = result._pa_array
968+
mask = pc.is_nan(parr).to_numpy()
969+
arr = pc.replace_with_mask(parr, mask, pa.scalar(None, type=parr.type))
970+
result = type(self)(arr)
971+
return result
927972

928973
def equals(self, other) -> bool:
929974
if not isinstance(other, ArrowExtensionArray):

pandas/core/arrays/base.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2539,14 +2539,6 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
25392539
if result is not NotImplemented:
25402540
return result
25412541

2542-
# TODO: putting this here is hacky as heck
2543-
if self.dtype == "float64[pyarrow]":
2544-
# e.g. test_log_arrow_backed_missing_value
2545-
new_inputs = [
2546-
x if x is not self else x.to_numpy(na_value=np.nan) for x in inputs
2547-
]
2548-
return getattr(ufunc, method)(*new_inputs, **kwargs)
2549-
25502542
return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)
25512543

25522544
def map(self, mapper, na_action: Literal["ignore"] | None = None):

pandas/tests/extension/test_arrow.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3589,3 +3589,26 @@ def test_timestamp_dtype_matches_to_datetime():
35893589
expected = pd.Series([ts], dtype=dtype1).convert_dtypes(dtype_backend="pyarrow")
35903590

35913591
tm.assert_series_equal(result, expected)
3592+
3593+
3594+
def test_ops_with_nan_is_na(using_nan_is_na):
3595+
# GH#61732
3596+
ser = pd.Series([-1, 0, 1], dtype="int64[pyarrow]")
3597+
3598+
result = ser - np.nan
3599+
if using_nan_is_na:
3600+
assert result.isna().all()
3601+
else:
3602+
assert not result.isna().any()
3603+
3604+
result = ser * np.nan
3605+
if using_nan_is_na:
3606+
assert result.isna().all()
3607+
else:
3608+
assert not result.isna().any()
3609+
3610+
result = ser / 0
3611+
if using_nan_is_na:
3612+
assert result.isna()[1]
3613+
else:
3614+
assert not result.isna()[1]

pandas/tests/series/test_npfuncs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_numpy_argwhere(index):
3838

3939

4040
@td.skip_if_no("pyarrow")
41-
def test_log_arrow_backed_missing_value():
41+
def test_log_arrow_backed_missing_value(using_nan_is_na):
4242
# GH#56285
4343
ser = Series([1, 2, None], dtype="float64[pyarrow]")
4444
result = np.log(ser)

0 commit comments

Comments
 (0)