diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index f3082ad9464d7..f43d13bf25e9d 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -914,6 +914,8 @@ Datetimelike - Bug in :meth:`to_datetime` reports incorrect index in case of any failure scenario. (:issue:`58298`) - Bug in :meth:`to_datetime` with ``format="ISO8601"`` and ``utc=True`` where naive timestamps incorrectly inherited timezone offset from previous timestamps in a series. (:issue:`61389`) - Bug in :meth:`to_datetime` wrongly converts when ``arg`` is a ``np.datetime64`` object with unit of ``ps``. (:issue:`60341`) +- Bug in comparison between objects with ``np.datetime64`` dtype and ``timestamp[pyarrow]`` dtypes incorrectly raising ``TypeError`` (:issue:`60937`) +- Bug in comparison between objects with pyarrow date dtype and ``timestamp[pyarrow]`` or ``np.datetime64`` dtype failing to consider these as non-comparable (:issue:`62157`) - Bug in constructing arrays with :class:`ArrowDtype` with ``timestamp`` type incorrectly allowing ``Decimal("NaN")`` (:issue:`61773`) - Bug in constructing arrays with a timezone-aware :class:`ArrowDtype` from timezone-naive datetime objects incorrectly treating those as UTC times instead of wall times like :class:`DatetimeTZDtype` (:issue:`61775`) - Bug in setting scalar values with mismatched resolution into arrays with non-nanosecond ``datetime64``, ``timedelta64`` or :class:`DatetimeTZDtype` incorrectly truncating those scalars (:issue:`56410`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 653a900fbfe45..2eed608908440 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1,5 +1,9 @@ from __future__ import annotations +from datetime import ( + date, + datetime, +) import functools import operator from pathlib import Path @@ -827,28 +831,46 @@ def __setstate__(self, state) -> None: def _cmp_method(self, other, op) -> ArrowExtensionArray: pc_func = ARROW_CMP_FUNCS[op.__name__] + ltype = self._pa_array.type + if isinstance(other, (ExtensionArray, np.ndarray, list)): - try: - result = pc_func(self._pa_array, self._box_pa(other)) - except pa.ArrowNotImplementedError: - # TODO: could this be wrong if other is object dtype? - # in which case we need to operate pointwise? + boxed = self._box_pa(other) + rtype = boxed.type + if (pa.types.is_timestamp(ltype) and pa.types.is_date(rtype)) or ( + pa.types.is_timestamp(rtype) and pa.types.is_date(ltype) + ): + # GH#62157 match non-pyarrow behavior result = ops.invalid_comparison(self, other, op) result = pa.array(result, type=pa.bool_()) - elif is_scalar(other): - try: - result = pc_func(self._pa_array, self._box_pa(other)) - except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid): - mask = isna(self) | isna(other) - valid = ~mask - result = np.zeros(len(self), dtype="bool") - np_array = np.array(self) + else: try: - result[valid] = op(np_array[valid], other) - except TypeError: + result = pc_func(self._pa_array, boxed) + except pa.ArrowNotImplementedError: + # TODO: could this be wrong if other is object dtype? + # in which case we need to operate pointwise? result = ops.invalid_comparison(self, other, op) + result = pa.array(result, type=pa.bool_()) + elif is_scalar(other): + if (isinstance(other, datetime) and pa.types.is_date(ltype)) or ( + type(other) is date and pa.types.is_timestamp(ltype) + ): + # GH#62157 match non-pyarrow behavior + result = ops.invalid_comparison(self, other, op) result = pa.array(result, type=pa.bool_()) - result = pc.if_else(valid, result, None) + else: + try: + result = pc_func(self._pa_array, self._box_pa(other)) + except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid): + mask = isna(self) | isna(other) + valid = ~mask + result = np.zeros(len(self), dtype="bool") + np_array = np.array(self) + try: + result[valid] = op(np_array[valid], other) + except TypeError: + result = ops.invalid_comparison(self, other, op) + result = pa.array(result, type=pa.bool_()) + result = pc.if_else(valid, result, None) else: raise NotImplementedError( f"{op.__name__} not implemented for {type(other)}" diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index b93d1ae408400..c68b329b00968 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -971,6 +971,8 @@ def _cmp_method(self, other, op): try: other = self._validate_comparison_value(other) except InvalidComparison: + if hasattr(other, "dtype") and isinstance(other.dtype, ArrowDtype): + return NotImplemented return invalid_comparison(self, other, op) dtype = getattr(other, "dtype", None) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 1863771dff593..6e11b54e3dfee 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -3578,3 +3578,53 @@ def test_timestamp_dtype_matches_to_datetime(): expected = pd.Series([ts], dtype=dtype1).convert_dtypes(dtype_backend="pyarrow") tm.assert_series_equal(result, expected) + + +def test_timestamp_vs_dt64_comparison(): + # GH#60937 + left = pd.Series(["2016-01-01"], dtype="timestamp[ns][pyarrow]") + right = left.astype("datetime64[ns]") + + result = left == right + expected = pd.Series([True], dtype="bool[pyarrow]") + tm.assert_series_equal(result, expected) + + result = right == left + tm.assert_series_equal(result, expected) + + +# TODO: reuse assert_invalid_comparison? +def test_date_vs_timestamp_scalar_comparison(): + # GH#62157 match non-pyarrow behavior + ser = pd.Series(["2016-01-01"], dtype="date32[pyarrow]") + ser2 = ser.astype("timestamp[ns][pyarrow]") + + ts = ser2[0] + dt = ser[0] + + # date dtype don't match a Timestamp object + assert not (ser == ts).any() + assert not (ts == ser).any() + + # timestamp dtype doesn't match date object + assert not (ser2 == dt).any() + assert not (dt == ser2).any() + + +# TODO: reuse assert_invalid_comparison? +def test_date_vs_timestamp_array_comparison(): + # GH#62157 match non-pyarrow behavior + # GH# + ser = pd.Series(["2016-01-01"], dtype="date32[pyarrow]") + ser2 = ser.astype("timestamp[ns][pyarrow]") + ser3 = ser.astype("datetime64[ns]") + + assert not (ser == ser2).any() + assert not (ser2 == ser).any() + assert (ser != ser2).all() + assert (ser2 != ser).all() + + assert not (ser == ser3).any() + assert not (ser3 == ser).any() + assert (ser != ser3).all() + assert (ser3 != ser).all()