Skip to content

Commit 5b911f6

Browse files
authored
API: consistent sanitize_array-wrapping for lists in arithmetic ops (#62552)
1 parent b69fb3c commit 5b911f6

File tree

12 files changed

+123
-53
lines changed

12 files changed

+123
-53
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,7 @@ Other API changes
750750
- :class:`Series` "flex" methods like :meth:`Series.add` no longer allow passing a :class:`DataFrame` for ``other``; use the DataFrame reversed method instead (:issue:`46179`)
751751
- :meth:`CategoricalIndex.append` no longer attempts to cast different-dtype indexes to the caller's dtype (:issue:`41626`)
752752
- :meth:`ExtensionDtype.construct_array_type` is now a regular method instead of a ``classmethod`` (:issue:`58663`)
753+
- Arithmetic operations between a :class:`Series`, :class:`Index`, or :class:`ExtensionArray` and a ``list`` now consistently wrap that list in an array equivalent to ``Series(my_list).array``. If you need any other kind of type inference or casting, apply it explicitly before operating (:issue:`62552`)
753754
- Comparison operations between :class:`Index` and :class:`Series` now consistently return :class:`Series` regardless of which object is on the left or right (:issue:`36759`)
754755
- Numpy functions like ``np.isinf`` that return a bool dtype when called on a :class:`Index` object now return a bool-dtype :class:`Index` instead of ``np.ndarray`` (:issue:`52676`)
755756

pandas/core/arrays/masked.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,9 @@ def _maybe_mask_result(
10171017

10181018
return IntegerArray(result, mask, copy=False)
10191019

1020+
elif result.dtype == object:
1021+
result[mask] = self.dtype.na_value
1022+
return result
10201023
else:
10211024
result[mask] = np.nan
10221025
return result

pandas/core/indexes/base.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7285,19 +7285,7 @@ def _cmp_method(self, other, op):
72857285
else:
72867286
other = np.asarray(other)
72877287

7288-
if is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
7289-
# e.g. PeriodArray, Categorical
7290-
result = op(self._values, other)
7291-
7292-
elif isinstance(self._values, ExtensionArray):
7293-
result = op(self._values, other)
7294-
7295-
elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
7296-
# don't pass MultiIndex
7297-
result = ops.comp_method_OBJECT_ARRAY(op, self._values, other)
7298-
7299-
else:
7300-
result = ops.comparison_op(self._values, other, op)
7288+
result = ops.comparison_op(self._values, other, op)
73017289

73027290
return result
73037291

pandas/core/ops/array_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@
5353

5454
from pandas.core import roperator
5555
from pandas.core.computation import expressions
56-
from pandas.core.construction import ensure_wrapped_if_datetimelike
56+
from pandas.core.construction import (
57+
ensure_wrapped_if_datetimelike,
58+
sanitize_array,
59+
)
5760
from pandas.core.ops import missing
5861
from pandas.core.ops.dispatch import should_extension_dispatch
5962
from pandas.core.ops.invalid import invalid_comparison
@@ -261,6 +264,10 @@ def arithmetic_op(left: ArrayLike, right: Any, op):
261264
# and `maybe_prepare_scalar_for_op` has already been called on `right`
262265
# We need to special-case datetime64/timedelta64 dtypes (e.g. because numpy
263266
# casts integer dtypes to timedelta64 when operating with timedelta64 - GH#22390)
267+
if isinstance(right, list):
268+
# GH#62423
269+
right = sanitize_array(right, None)
270+
right = ensure_wrapped_if_datetimelike(right)
264271

265272
if (
266273
should_extension_dispatch(left, right)
@@ -310,7 +317,8 @@ def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
310317
if isinstance(rvalues, list):
311318
# We don't catch tuple here bc we may be comparing e.g. MultiIndex
312319
# to a tuple that represents a single entry, see test_compare_tuple_strs
313-
rvalues = np.asarray(rvalues)
320+
rvalues = sanitize_array(rvalues, None)
321+
rvalues = ensure_wrapped_if_datetimelike(rvalues)
314322

315323
if isinstance(rvalues, (np.ndarray, ABCExtensionArray)):
316324
# TODO: make this treatment consistent across ops and classes.

pandas/core/ops/common.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,16 @@
1111
from pandas._libs.missing import is_matching_na
1212

1313
from pandas.core.dtypes.generic import (
14+
ABCExtensionArray,
1415
ABCIndex,
1516
ABCSeries,
1617
)
1718

19+
from pandas.core.construction import (
20+
ensure_wrapped_if_datetimelike,
21+
sanitize_array,
22+
)
23+
1824
if TYPE_CHECKING:
1925
from collections.abc import Callable
2026

@@ -56,6 +62,7 @@ def _unpack_zerodim_and_defer(method: F, name: str) -> F:
5662
-------
5763
method
5864
"""
65+
is_logical = name.strip("_") in ["or", "xor", "and", "ror", "rxor", "rand"]
5966

6067
@wraps(method)
6168
def new_method(self, other):
@@ -66,6 +73,14 @@ def new_method(self, other):
6673
return NotImplemented
6774

6875
other = item_from_zerodim(other)
76+
if (
77+
isinstance(self, ABCExtensionArray)
78+
and isinstance(other, list)
79+
and not is_logical
80+
):
81+
# See GH#62423
82+
other = sanitize_array(other, None)
83+
other = ensure_wrapped_if_datetimelike(other)
6984

7085
return method(self, other)
7186

pandas/tests/arithmetic/test_numeric.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,23 @@ def test_numeric_cmp_string_numexpr_path(self, box_with_array, monkeypatch):
151151

152152

153153
class TestNumericArraylikeArithmeticWithDatetimeLike:
154+
def test_mul_timedelta_list(self, box_with_array):
155+
# GH#62524
156+
box = box_with_array
157+
left = np.array([3, 4])
158+
left = tm.box_expected(left, box)
159+
160+
right = [Timedelta(days=1), Timedelta(days=2)]
161+
162+
result = left * right
163+
164+
expected = TimedeltaIndex([Timedelta(days=3), Timedelta(days=8)])
165+
expected = tm.box_expected(expected, box)
166+
tm.assert_equal(result, expected)
167+
168+
result2 = right * left
169+
tm.assert_equal(result2, expected)
170+
154171
@pytest.mark.parametrize("box_cls", [np.array, Index, Series])
155172
@pytest.mark.parametrize(
156173
"left", lefts, ids=lambda x: type(x).__name__ + str(x.dtype)

pandas/tests/arithmetic/test_string.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,22 @@ def test_add_2d(any_string_dtype, request):
213213
s + b
214214

215215

216-
def test_add_sequence(any_string_dtype, request):
216+
def test_add_sequence(any_string_dtype, request, using_infer_string):
217217
dtype = any_string_dtype
218-
if dtype == np.dtype(object):
218+
if (
219+
dtype != object
220+
and dtype.storage == "python"
221+
and dtype.na_value is np.nan
222+
and HAS_PYARROW
223+
and using_infer_string
224+
):
225+
mark = pytest.mark.xfail(
226+
reason="As of GH#62522, the list gets wrapped with sanitize_array, "
227+
"which casts to a higher-priority StringArray, so we get "
228+
"NotImplemented."
229+
)
230+
request.applymarker(mark)
231+
if dtype == np.dtype(object) and using_infer_string:
219232
mark = pytest.mark.xfail(reason="Cannot broadcast list")
220233
request.applymarker(mark)
221234

@@ -415,30 +428,45 @@ def test_comparison_methods_array_arrow_extension(comparison_op, any_string_dtyp
415428
tm.assert_extension_array_equal(result, expected)
416429

417430

418-
def test_comparison_methods_list(comparison_op, any_string_dtype):
431+
@pytest.mark.parametrize("box", [pd.array, pd.Index, Series])
432+
def test_comparison_methods_list(comparison_op, any_string_dtype, box, request):
419433
dtype = any_string_dtype
434+
435+
if box is pd.array and dtype != object and dtype.na_value is np.nan:
436+
mark = pytest.mark.xfail(
437+
reason="After wrapping list, op returns NotImplemented, see GH#62522"
438+
)
439+
request.applymarker(mark)
440+
420441
op_name = f"__{comparison_op.__name__}__"
421442

422-
a = pd.array(["a", None, "c"], dtype=dtype)
443+
a = box(pd.array(["a", None, "c"], dtype=dtype))
444+
item = "c"
423445
other = [None, None, "c"]
424446
result = comparison_op(a, other)
425447

426448
# ensure operation is commutative
427449
result2 = comparison_op(other, a)
428450
tm.assert_equal(result, result2)
429451

430-
if dtype == object or dtype.na_value is np.nan:
452+
if dtype == np.dtype(object) or dtype.na_value is np.nan:
431453
if operator.ne == comparison_op:
432454
expected = np.array([True, True, False])
433455
else:
434456
expected = np.array([False, False, False])
435-
expected[-1] = getattr(other[-1], op_name)(a[-1])
436-
result = extract_array(result, extract_numpy=True)
437-
tm.assert_numpy_array_equal(result, expected)
457+
expected[-1] = getattr(item, op_name)(item)
458+
if box is not pd.Index:
459+
# if GH#62766 is addressed this check can be removed
460+
expected = box(expected, dtype=expected.dtype)
461+
tm.assert_equal(result, expected)
438462

439463
else:
440464
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
441465
expected = np.full(len(a), fill_value=None, dtype="object")
442-
expected[-1] = getattr(other[-1], op_name)(a[-1])
466+
expected[-1] = getattr(item, op_name)(item)
443467
expected = pd.array(expected, dtype=expected_dtype)
444-
tm.assert_extension_array_equal(result, expected)
468+
expected = extract_array(expected, extract_numpy=True)
469+
if box is not pd.Index:
470+
# if GH#62766 is addressed this check can be removed
471+
expected = tm.box_expected(expected, box)
472+
tm.assert_equal(result, expected)

pandas/tests/arithmetic/test_timedelta64.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,23 @@ class TestTimedelta64ArithmeticUnsorted:
274274
# Tests moved from type-specific test files but not
275275
# yet sorted/parametrized/de-duplicated
276276

277+
def test_td64_op_with_list(self, box_with_array):
278+
# GH#62353
279+
box = box_with_array
280+
281+
left = TimedeltaIndex(["2D", "4D"])
282+
left = tm.box_expected(left, box)
283+
284+
right = [Timestamp("2016-01-01"), Timestamp("2016-02-01")]
285+
286+
result = left + right
287+
expected = DatetimeIndex(["2016-01-03", "2016-02-05"], dtype="M8[ns]")
288+
expected = tm.box_expected(expected, box)
289+
tm.assert_equal(result, expected)
290+
291+
result2 = right + left
292+
tm.assert_equal(result2, expected)
293+
277294
def test_ufunc_coercions(self):
278295
# normal ops are also tested in tseries/test_timedeltas.py
279296
idx = TimedeltaIndex(["2h", "4h", "6h", "8h", "10h"], freq="2h", name="x")

pandas/tests/arrays/integer/test_arithmetic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,6 @@ def test_error_invalid_values(data, all_arithmetic_operators):
201201
]: # (data[~data.isna()] >= 0).all():
202202
res = ops(str_ser)
203203
expected = pd.Series(["foo" * x for x in data], index=s.index)
204-
expected = expected.fillna(np.nan)
205-
# TODO: doing this fillna to keep tests passing as we make
206-
# assert_almost_equal stricter, but the expected with pd.NA seems
207-
# more-correct than np.nan here.
208204
tm.assert_series_equal(res, expected)
209205
else:
210206
with tm.external_error_raised(TypeError):

pandas/tests/indexes/categorical/test_category.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,10 @@ def test_disallow_addsub_ops(self, func, op_name):
326326
cat_or_list = "'(Categorical|list)' and '(Categorical|list)'"
327327
msg = "|".join(
328328
[
329-
f"cannot perform {op_name} with this index type: CategoricalIndex",
330-
"can only concatenate list",
331329
rf"unsupported operand type\(s\) for [\+-]: {cat_or_list}",
330+
"Object with dtype category cannot perform the numpy op (add|subtract)",
331+
"operation 'r?(add|sub)' not supported for dtype 'str' "
332+
"with dtype 'category'",
332333
]
333334
)
334335
with pytest.raises(TypeError, match=msg):

0 commit comments

Comments
 (0)