Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
infer_dtype_from_scalar,
)
from pandas.core.dtypes.common import (
CategoricalDtype,
is_array_like,
is_bool_dtype,
is_float_dtype,
Expand Down Expand Up @@ -730,9 +729,7 @@ def __setstate__(self, state) -> None:

def _cmp_method(self, other, op) -> ArrowExtensionArray:
pc_func = ARROW_CMP_FUNCS[op.__name__]
if isinstance(
other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)
) or isinstance(getattr(other, "dtype", None), CategoricalDtype):
if isinstance(other, (ExtensionArray, np.ndarray, list)):
try:
result = pc_func(self._pa_array, self._box_pa(other))
except pa.ArrowNotImplementedError:
Expand Down
25 changes: 24 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,30 @@ def searchsorted(
return super().searchsorted(value=value, side=side, sorter=sorter)

def _cmp_method(self, other, op):
from pandas.arrays import BooleanArray
from pandas.arrays import (
ArrowExtensionArray,
BooleanArray,
)

if (
isinstance(other, BaseStringArray)
and self.dtype.na_value is not libmissing.NA
and other.dtype.na_value is libmissing.NA
):
# NA has priority of NaN semantics
return NotImplemented

if isinstance(other, ArrowExtensionArray):
if isinstance(other, BaseStringArray):
# pyarrow storage has priority over python storage
# (except if we have NA semantics and other not)
if not (
self.dtype.na_value is libmissing.NA
and other.dtype.na_value is not libmissing.NA
):
return NotImplemented
else:
return NotImplemented

if isinstance(other, StringArray):
other = other._ndarray
Expand Down
8 changes: 8 additions & 0 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,14 @@ def value_counts(self, dropna: bool = True) -> Series:
return result

def _cmp_method(self, other, op):
if (
isinstance(other, BaseStringArray)
and self.dtype.na_value is not libmissing.NA
and other.dtype.na_value is libmissing.NA
):
# NA has priority of NaN semantics
return NotImplemented

result = super()._cmp_method(other, op)
if self.dtype.na_value is np.nan:
if op == operator.ne:
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/ops/invalid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

def invalid_comparison(
left: ArrayLike,
right: ArrayLike | Scalar,
right: ArrayLike | list | Scalar,
op: Callable[[Any, Any], bool],
) -> npt.NDArray[np.bool_]:
"""
Expand Down
66 changes: 56 additions & 10 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW
from pandas.compat.pyarrow import (
pa_version_under12p0,
pa_version_under19p0,
Expand Down Expand Up @@ -45,6 +46,25 @@ def cls(dtype):
return dtype.construct_array_type()


def string_dtype_highest_priority(dtype1, dtype2):
if HAS_PYARROW:
DTYPE_HIERARCHY = [
pd.StringDtype("python", na_value=np.nan),
pd.StringDtype("pyarrow", na_value=np.nan),
pd.StringDtype("python", na_value=pd.NA),
pd.StringDtype("pyarrow", na_value=pd.NA),
]
else:
DTYPE_HIERARCHY = [
pd.StringDtype("python", na_value=np.nan),
pd.StringDtype("python", na_value=pd.NA),
]

h1 = DTYPE_HIERARCHY.index(dtype1)
h2 = DTYPE_HIERARCHY.index(dtype2)
return DTYPE_HIERARCHY[max(h1, h2)]


def test_dtype_constructor():
pytest.importorskip("pyarrow")

Expand Down Expand Up @@ -319,25 +339,55 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
tm.assert_extension_array_equal(result, expected)


def test_comparison_methods_array(comparison_op, dtype):
def test_comparison_methods_array(comparison_op, dtype, dtype2):
op_name = f"__{comparison_op.__name__}__"

a = pd.array(["a", None, "c"], dtype=dtype)
other = [None, None, "c"]
result = getattr(a, op_name)(other)
if dtype.na_value is np.nan:
other = pd.array([None, None, "c"], dtype=dtype2)
result = comparison_op(a, other)

# ensure operation is commutative
result2 = comparison_op(other, a)
tm.assert_equal(result, result2)

if dtype.na_value is np.nan and dtype2.na_value is np.nan:
if operator.ne == comparison_op:
expected = np.array([True, True, False])
else:
expected = np.array([False, False, False])
expected[-1] = getattr(other[-1], op_name)(a[-1])
tm.assert_numpy_array_equal(result, expected)

result = getattr(a, op_name)(pd.NA)
else:
max_dtype = string_dtype_highest_priority(dtype, dtype2)
if max_dtype.storage == "python":
expected_dtype = "boolean"
else:
expected_dtype = "bool[pyarrow]"

expected = np.full(len(a), fill_value=None, dtype="object")
expected[-1] = getattr(other[-1], op_name)(a[-1])
expected = pd.array(expected, dtype=expected_dtype)
tm.assert_extension_array_equal(result, expected)


def test_comparison_methods_list(comparison_op, dtype):
op_name = f"__{comparison_op.__name__}__"

a = pd.array(["a", None, "c"], dtype=dtype)
other = [None, None, "c"]
result = comparison_op(a, other)

# ensure operation is commutative
result2 = comparison_op(other, a)
tm.assert_equal(result, result2)

if dtype.na_value is np.nan:
if operator.ne == comparison_op:
expected = np.array([True, True, True])
expected = np.array([True, True, False])
else:
expected = np.array([False, False, False])
expected[-1] = getattr(other[-1], op_name)(a[-1])
tm.assert_numpy_array_equal(result, expected)

else:
Expand All @@ -347,10 +397,6 @@ def test_comparison_methods_array(comparison_op, dtype):
expected = pd.array(expected, dtype=expected_dtype)
tm.assert_extension_array_equal(result, expected)

result = getattr(a, op_name)(pd.NA)
expected = pd.array([None, None, None], dtype=expected_dtype)
tm.assert_extension_array_equal(result, expected)
Comment on lines -362 to -364
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this case of comparing with NA, we already have a dedicated test just above, so removing it here



def test_constructor_raises(cls):
if cls is pd.arrays.StringArray:
Expand Down
10 changes: 8 additions & 2 deletions pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pandas.api.types import is_string_dtype
from pandas.core.arrays import ArrowStringArray
from pandas.core.arrays.string_ import StringDtype
from pandas.tests.arrays.string_.test_string import string_dtype_highest_priority
from pandas.tests.extension import base


Expand Down Expand Up @@ -202,10 +203,13 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
dtype = cast(StringDtype, tm.get_dtype(obj))
if op_name in ["__add__", "__radd__"]:
cast_to = dtype
dtype_other = tm.get_dtype(other) if not isinstance(other, str) else None
if isinstance(dtype_other, StringDtype):
cast_to = string_dtype_highest_priority(dtype, dtype_other)
elif dtype.na_value is np.nan:
cast_to = np.bool_ # type: ignore[assignment]
elif dtype.storage == "pyarrow":
cast_to = "boolean[pyarrow]" # type: ignore[assignment]
cast_to = "bool[pyarrow]" # type: ignore[assignment]
else:
cast_to = "boolean" # type: ignore[assignment]
return pointwise_result.astype(cast_to)
Expand Down Expand Up @@ -237,9 +241,11 @@ def test_arith_series_with_array(
using_infer_string
and all_arithmetic_operators == "__radd__"
and (
(dtype.na_value is pd.NA) or (dtype.storage == "python" and HAS_PYARROW)
dtype.na_value is pd.NA
and not (not HAS_PYARROW and dtype.storage == "python")
)
):
# TODO(infer_string)
mark = pytest.mark.xfail(
reason="The pointwise operation result will be inferred to "
"string[nan, pyarrow], which does not match the input dtype"
Expand Down
Loading