Skip to content

Commit 2b55311

Browse files
committed
fix remaining tests
1 parent 48f6a8b commit 2b55311

File tree

12 files changed

+197
-103
lines changed

12 files changed

+197
-103
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,66 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
396396
if len(values) == 0:
397397
# Retain our dtype
398398
return self[:0].copy()
399-
arr = pa.array(values, from_pandas=True)
399+
400+
try:
401+
arr = pa.array(values, from_pandas=True)
402+
except (ValueError, TypeError):
403+
# e.g. test_by_column_values_with_same_starting_value with nested
404+
# values, one entry of which is an ArrowStringArray
405+
# or test_agg_lambda_complex128_dtype_conversion for complex values
406+
return super()._cast_pointwise_result(values)
407+
408+
if pa.types.is_duration(arr.type):
409+
# workaround for https://github.com/apache/arrow/issues/40620
410+
result = ArrowExtensionArray._from_sequence(values)
411+
if pa.types.is_duration(self._pa_array.type):
412+
result = result.astype(self.dtype)
413+
elif pa.types.is_timestamp(self._pa_array.type):
414+
# Try to retain original unit
415+
new_dtype = ArrowDtype(pa.duration(self._pa_array.type.unit))
416+
try:
417+
result = result.astype(new_dtype)
418+
except ValueError:
419+
pass
420+
elif pa.types.is_date64(self._pa_array.type):
421+
# Try to match unit we get on non-pointwise op
422+
dtype = ArrowDtype(pa.duration("ms"))
423+
result = result.astype(dtype)
424+
elif pa.types.is_date(self._pa_array.type):
425+
# Try to match unit we get on non-pointwise op
426+
dtype = ArrowDtype(pa.duration("s"))
427+
result = result.astype(dtype)
428+
return result
429+
430+
elif pa.types.is_date(arr.type) and pa.types.is_date(self._pa_array.type):
431+
arr = arr.cast(self._pa_array.type)
432+
elif pa.types.is_time(arr.type) and pa.types.is_time(self._pa_array.type):
433+
arr = arr.cast(self._pa_array.type)
434+
elif pa.types.is_decimal(arr.type) and pa.types.is_decimal(self._pa_array.type):
435+
arr = arr.cast(self._pa_array.type)
436+
elif pa.types.is_integer(arr.type) and pa.types.is_integer(self._pa_array.type):
437+
try:
438+
arr = arr.cast(self._pa_array.type)
439+
except pa.lib.ArrowInvalid:
440+
# e.g. test_combine_add if we can't cast
441+
pass
442+
elif pa.types.is_floating(arr.type) and pa.types.is_floating(
443+
self._pa_array.type
444+
):
445+
try:
446+
arr = arr.cast(self._pa_array.type)
447+
except pa.lib.ArrowInvalid:
448+
# e.g. test_combine_add if we can't cast
449+
pass
450+
400451
if isinstance(self.dtype, StringDtype):
401452
if pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type):
402453
# ArrowStringArrayNumpySemantics
403-
return type(self)(arr)
454+
return type(self)(arr).astype(self.dtype)
455+
if self.dtype.na_value is np.nan:
456+
# ArrowEA has different semantics, so we return numpy-based
457+
# result instead
458+
return super()._cast_pointwise_result(values)
404459
return ArrowExtensionArray(arr)
405460
return type(self)(arr)
406461

pandas/core/arrays/masked.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pandas.util._decorators import doc
2727

2828
from pandas.core.dtypes.base import ExtensionDtype
29+
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
2930
from pandas.core.dtypes.common import (
3031
is_bool,
3132
is_integer_dtype,
@@ -149,7 +150,15 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self:
149150

150151
def _cast_pointwise_result(self, values) -> ArrayLike:
    """
    Infer a nullable array from pointwise results, keeping our width.

    Parameters
    ----------
    values : list-like
        Scalars produced by applying an operation element by element.

    Returns
    -------
    ArrayLike
        The array inferred by ``lib.maybe_convert_objects`` with
        ``convert_to_nullable_dtype=True``.  When both ``self`` and the
        inferred result are integer-kind ("i"/"u"), or both are
        float-kind, the result's data is additionally passed through
        ``maybe_downcast_to_dtype`` targeting ``self.dtype.numpy_dtype``
        and re-wrapped together with the inferred mask.
    """
    values = np.asarray(values, dtype=object)
    result = lib.maybe_convert_objects(values, convert_to_nullable_dtype=True)

    lkind = self.dtype.kind
    rkind = result.dtype.kind
    if (lkind in "iu" and rkind in "iu") or (lkind == rkind == "f"):
        # NOTE(review): this branch reads result._data / result._mask, so it
        # assumes the inferred result is a masked array here -- confirm
        # against lib.maybe_convert_objects.
        new_data = maybe_downcast_to_dtype(
            result._data, dtype=self.dtype.numpy_dtype
        )
        result = type(result)(new_data, result._mask)
    return result
153162

154163
@classmethod
155164
@doc(ExtensionArray._empty)

pandas/core/arrays/numpy_.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from pandas.compat.numpy import function as nv
1515

1616
from pandas.core.dtypes.astype import astype_array
17-
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
17+
from pandas.core.dtypes.cast import (
18+
construct_1d_object_array_from_listlike,
19+
maybe_downcast_to_dtype,
20+
)
1821
from pandas.core.dtypes.common import pandas_dtype
1922
from pandas.core.dtypes.dtypes import NumpyEADtype
2023
from pandas.core.dtypes.missing import isna
@@ -34,6 +37,7 @@
3437
from collections.abc import Callable
3538

3639
from pandas._typing import (
40+
ArrayLike,
3741
AxisInt,
3842
Dtype,
3943
FillnaOptions,
@@ -145,6 +149,24 @@ def _from_sequence(
145149
result = result.copy()
146150
return cls(result)
147151

152+
def _cast_pointwise_result(self, values) -> ArrayLike:
    """
    Cast pointwise results, trying to keep our numeric dtype.

    Starts from the parent class's ``_cast_pointwise_result`` and then:

    * if ``self`` and the result are both integer-kind ("i"/"u"), both
      float-kind or both complex-kind, passes the result through
      ``maybe_downcast_to_dtype`` targeting ``self.dtype.numpy_dtype``;
    * if the result is datetime64-kind ("M"), wraps it with ``pd.array``.

    Parameters
    ----------
    values : list-like
        Scalars produced by applying an operation element by element.

    Returns
    -------
    ArrayLike
    """
    result = super()._cast_pointwise_result(values)
    lkind = self.dtype.kind
    rkind = result.dtype.kind

    same_numeric_family = (
        (lkind in "iu" and rkind in "iu")
        or lkind == rkind == "f"
        or lkind == rkind == "c"
    )
    if same_numeric_family:
        result = maybe_downcast_to_dtype(result, self.dtype.numpy_dtype)
    elif rkind == "M":
        # Ensure potential subsequent .astype(object) doesn't incorrectly
        # convert Timestamps to ints
        from pandas import array as pd_array

        result = pd_array(result, copy=False)
    return result
169+
148170
# ------------------------------------------------------------------------
149171
# Data
150172

pandas/core/arrays/sparse/array.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,20 @@ def _from_factorized(cls, values, original) -> Self:
609609

610610
def _cast_pointwise_result(self, values):
    """
    Cast pointwise results to a sparse array, keeping our dtype if lossless.

    The parent class's ``_cast_pointwise_result`` is applied first.  If the
    inferred result has the same dtype kind as ``self``, a cast to
    ``self.dtype`` is attempted and kept only when every element
    round-trips (equal, or missing on both sides).  In every other case
    the dtype is inferred by ``_from_sequence`` without a ``dtype``.

    Parameters
    ----------
    values : list-like
        Scalars produced by applying an operation element by element.

    Returns
    -------
    Same type as ``self``
    """
    result = super()._cast_pointwise_result(values)

    if result.dtype.kind == self.dtype.kind:
        try:
            # e.g. test_groupby_agg_extension
            res = type(self)._from_sequence(result, dtype=self.dtype)
            if ((res == result) | (isna(result) & res.isna())).all():
                # This does not hold for e.g.
                # test_arith_frame_with_scalar[0-__truediv__]
                return res
        except (ValueError, TypeError):
            # The cast (or the round-trip comparison) is not possible;
            # fall through to plain inference below.
            pass

    # e.g. test_combine_le avoid casting bools to Sparse[float64, nan]
    return type(self)._from_sequence(result)
613626

614627
# ------------------------------------------------------------------------
615628
# Data

pandas/core/indexes/base.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6397,18 +6397,20 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None):
63976397
if not new_values.size:
63986398
# empty
63996399
dtype = self.dtype
6400-
return Index(new_values, dtype=dtype, copy=False, name=self.name)
6400+
elif isinstance(new_values, Categorical):
6401+
# cast_pointwise_result is unnecessary
6402+
dtype = new_values.dtype
64016403
else:
6404+
if isinstance(self, MultiIndex):
6405+
arr = self[:0].to_flat_index().array
6406+
else:
6407+
arr = self[:0].array
64026408
# e.g. if we are floating and new_values is all ints, then we
64036409
# don't want to cast back to floating. But if we are UInt64
64046410
# and new_values is all ints, we want to try.
6405-
if isinstance(self._values, np.ndarray):
6406-
return Index(new_values, dtype=dtype, copy=False, name=self.name)
6407-
else:
6408-
new_values = self._values._cast_pointwise_result(new_values)
6409-
return Index(
6410-
new_values, dtype=new_values.dtype, copy=False, name=self.name
6411-
)
6411+
new_values = arr._cast_pointwise_result(new_values)
6412+
dtype = new_values.dtype
6413+
return Index(new_values, dtype=dtype, copy=False, name=self.name)
64126414

64136415
# TODO: De-duplicate with map, xref GH#32349
64146416
@final

pandas/tests/extension/base/methods.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,18 @@ def test_combine_le(self, data_repeated):
367367
)
368368
tm.assert_series_equal(result, expected)
369369

370+
def _construct_for_combine_add(self, left, right):
371+
if isinstance(right, type(left)):
372+
return left._from_sequence(
373+
[a + b for (a, b) in zip(list(left), list(right))],
374+
dtype=left.dtype,
375+
)
376+
else:
377+
return left._from_sequence(
378+
[a + right for a in list(left)],
379+
dtype=left.dtype,
380+
)
381+
370382
def test_combine_add(self, data_repeated):
371383
# GH 20825
372384
orig_data1, orig_data2 = data_repeated(2)
@@ -377,26 +389,22 @@ def test_combine_add(self, data_repeated):
377389
# we will expect Series.combine to raise as well.
378390
try:
379391
with np.errstate(over="ignore"):
380-
expected = pd.Series(
381-
orig_data1._from_sequence(
382-
[a + b for (a, b) in zip(list(orig_data1), list(orig_data2))]
383-
)
384-
)
392+
arr = self._construct_for_combine_add(orig_data1, orig_data2)
385393
except TypeError:
386394
# If the operation is not supported pointwise for our scalars,
387395
# then Series.combine should also raise
388396
with pytest.raises(TypeError):
389397
s1.combine(s2, lambda x1, x2: x1 + x2)
390398
return
399+
expected = pd.Series(arr)
391400

392401
result = s1.combine(s2, lambda x1, x2: x1 + x2)
393402
tm.assert_series_equal(result, expected)
394403

395404
val = s1.iloc[0]
396405
result = s1.combine(val, lambda x1, x2: x1 + x2)
397-
expected = pd.Series(
398-
orig_data1._from_sequence([a + val for a in list(orig_data1)])
399-
)
406+
arr = self._construct_for_combine_add(orig_data1, val)
407+
expected = pd.Series(arr)
400408
tm.assert_series_equal(result, expected)
401409

402410
def test_combine_first(self, data):

pandas/tests/extension/decimal/test_decimal.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import decimal
4-
import operator
54

65
import numpy as np
76
import pytest
@@ -282,33 +281,10 @@ def _create_arithmetic_method(cls, op):
282281
DecimalArrayWithoutCoercion._add_arithmetic_ops()
283282

284283

285-
def test_combine_from_sequence_raises(monkeypatch):
286-
# https://github.com/pandas-dev/pandas/issues/22850
287-
cls = DecimalArrayWithoutFromSequence
288-
289-
def construct_array_type(self):
290-
return DecimalArrayWithoutFromSequence
291-
292-
monkeypatch.setattr(DecimalDtype, "construct_array_type", construct_array_type)
293-
294-
arr = cls([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
295-
ser = pd.Series(arr)
296-
result = ser.combine(ser, operator.add)
297-
298-
# note: object dtype
299-
expected = pd.Series(
300-
[decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object"
301-
)
302-
tm.assert_series_equal(result, expected)
303-
304-
305-
@pytest.mark.parametrize(
306-
"class_", [DecimalArrayWithoutFromSequence, DecimalArrayWithoutCoercion]
307-
)
308-
def test_scalar_ops_from_sequence_raises(class_):
284+
def test_scalar_ops_from_sequence_raises():
309285
# op(EA, EA) should return an EA, or an ndarray if it's not possible
310286
# to return an EA with the return values.
311-
arr = class_([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
287+
arr = DecimalArrayWithoutCoercion([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
312288
result = arr + arr
313289
expected = np.array(
314290
[decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object"

pandas/tests/extension/json/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def _from_sequence(cls, scalars, *, dtype=None, copy=False):
9090
def _from_factorized(cls, values, original):
9191
return cls([UserDict(x) for x in values if x != ()])
9292

93+
def _cast_pointwise_result(self, values):
    """
    Cast pointwise results back to our own array type when possible.

    The parent class's ``_cast_pointwise_result`` is applied first; the
    outcome is then passed to ``_from_sequence`` with ``dtype=self.dtype``.
    If that raises ``ValueError`` or ``TypeError`` the parent's result is
    returned unchanged (deliberate best-effort fallback).
    """
    result = super()._cast_pointwise_result(values)
    try:
        return type(self)._from_sequence(result, dtype=self.dtype)
    except (ValueError, TypeError):
        # The results cannot be held by this array type; keep the
        # generically inferred result instead.
        return result
99+
93100
def __getitem__(self, item):
94101
if isinstance(item, tuple):
95102
item = unpack_tuple_and_ellipses(item)

pandas/tests/extension/test_arrow.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from pandas.errors import Pandas4Warning
4848

49+
from pandas.core.dtypes.common import pandas_dtype
4950
from pandas.core.dtypes.dtypes import (
5051
ArrowDtype,
5152
CategoricalDtypeType,
@@ -271,6 +272,26 @@ def data_for_twos(data):
271272

272273

273274
class TestArrowArray(base.ExtensionTests):
275+
def _construct_for_combine_add(self, left, right):
276+
dtype = left.dtype
277+
278+
# in a couple cases, addition is not dtype-preserving
279+
if dtype == "bool[pyarrow]":
280+
dtype = pandas_dtype("int64[pyarrow]")
281+
elif dtype == "int8[pyarrow]" and isinstance(right, type(left)):
282+
dtype = pandas_dtype("int64[pyarrow]")
283+
284+
if isinstance(right, type(left)):
285+
return left._from_sequence(
286+
[a + b for (a, b) in zip(list(left), list(right))],
287+
dtype=dtype,
288+
)
289+
else:
290+
return left._from_sequence(
291+
[a + right for a in list(left)],
292+
dtype=dtype,
293+
)
294+
274295
def test_compare_scalar(self, data, comparison_op):
275296
ser = pd.Series(data)
276297
self._compare_other(ser, data, comparison_op, data[0])
@@ -797,14 +818,24 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
797818
if op_name in ["eq", "ne", "lt", "le", "gt", "ge"]:
798819
return pointwise_result.astype("boolean[pyarrow]")
799820

821+
original_dtype = tm.get_dtype(expected)
822+
800823
was_frame = False
801824
if isinstance(expected, pd.DataFrame):
802825
was_frame = True
803826
expected_data = expected.iloc[:, 0]
804-
original_dtype = obj.iloc[:, 0].dtype
805827
else:
806828
expected_data = expected
807-
original_dtype = obj.dtype
829+
830+
# the pointwise method will have retained our original dtype, while
831+
# the op(ser, other) version will have cast to 64bit
832+
if type(other) is int and op_name not in ["__floordiv__"]:
833+
if original_dtype.kind == "f":
834+
return expected.astype("float64[pyarrow]")
835+
else:
836+
return expected.astype("int64[pyarrow]")
837+
elif type(other) is float:
838+
return expected.astype("float64[pyarrow]")
808839

809840
orig_pa_type = original_dtype.pyarrow_dtype
810841
if not was_frame and isinstance(other, pd.Series):
@@ -836,29 +867,7 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
836867

837868
pa_expected = pa.array(expected_data._values)
838869

839-
if pa.types.is_duration(pa_expected.type):
840-
if pa.types.is_date(orig_pa_type):
841-
if pa.types.is_date64(orig_pa_type):
842-
# TODO: why is this different vs date32?
843-
unit = "ms"
844-
else:
845-
unit = "s"
846-
else:
847-
# pyarrow sees sequence of datetime/timedelta objects and defaults
848-
# to "us" but the non-pointwise op retains unit
849-
# timestamp or duration
850-
unit = orig_pa_type.unit
851-
if type(other) in [datetime, timedelta] and unit in ["s", "ms"]:
852-
# pydatetime/pytimedelta objects have microsecond reso, so we
853-
# take the higher reso of the original and microsecond. Note
854-
# this matches what we would do with DatetimeArray/TimedeltaArray
855-
unit = "us"
856-
857-
pa_expected = pa_expected.cast(f"duration[{unit}]")
858-
859-
elif pa.types.is_decimal(pa_expected.type) and pa.types.is_decimal(
860-
orig_pa_type
861-
):
870+
if pa.types.is_decimal(pa_expected.type) and pa.types.is_decimal(orig_pa_type):
862871
# decimal precision can resize in the result type depending on data
863872
# just compare the float values
864873
alt = getattr(obj, op_name)(other)

0 commit comments

Comments
 (0)