Skip to content

Commit de0d274

Browse files
committed
BUG: String[pyarrow] comparison with mixed object
1 parent f6fa9b6 commit de0d274

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -829,12 +829,17 @@ def _cmp_method(self, other, op) -> ArrowExtensionArray:
829829
pc_func = ARROW_CMP_FUNCS[op.__name__]
830830
if isinstance(other, (ExtensionArray, np.ndarray, list)):
831831
try:
832-
result = pc_func(self._pa_array, self._box_pa(other))
833-
except pa.ArrowNotImplementedError:
834-
# TODO: could this be wrong if other is object dtype?
835-
# in which case we need to operate pointwise?
836-
result = ops.invalid_comparison(self, other, op)
837-
result = pa.array(result, type=pa.bool_())
832+
boxed = self._box_pa(other)
833+
except pa.lib.ArrowInvalid:
834+
# e.g. GH#60228 [1, "b"] we have to operate pointwise
835+
res_values = [op(x, y) for x, y in zip(self, other)]
836+
result = pa.array(res_values, type=pa.bool_(), from_pandas=True)
837+
else:
838+
try:
839+
result = pc_func(self._pa_array, boxed)
840+
except pa.ArrowNotImplementedError:
841+
result = ops.invalid_comparison(self, other, op)
842+
result = pa.array(result, type=pa.bool_())
838843
elif is_scalar(other):
839844
try:
840845
result = pc_func(self._pa_array, self._box_pa(other))

pandas/tests/extension/test_string.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,19 @@ def test_searchsorted_with_na_raises(data_for_sorting, as_series):
288288
)
289289
with pytest.raises(ValueError, match=msg):
290290
arr.searchsorted(b)
291+
292+
293+
def test_mixed_object_comparison(dtype):
294+
# GH#60228
295+
ser = pd.Series(["a", "b"], dtype=dtype)
296+
297+
mixed = pd.Series([1, "b"], dtype=object)
298+
299+
result = ser == mixed
300+
expected = pd.Series([False, True], dtype=bool)
301+
if dtype.storage == "python" and dtype.na_value is pd.NA:
302+
expected = expected.astype("boolean")
303+
elif dtype.storage == "pyarrow" and dtype.na_value is pd.NA:
304+
expected = expected.astype("bool[pyarrow]")
305+
306+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)