Commit 117175b

Revert old approach: remove min/max methods from ArrowExtensionArray
1 parent dd936ad commit 117175b

4 files changed: +49 additions, -87 deletions

pandas/core/arrays/arrow/array.py

Lines changed: 47 additions & 68 deletions
@@ -12,7 +12,6 @@
     overload,
 )
 import unicodedata
-import warnings

 import numpy as np

@@ -23,12 +22,11 @@
     timezones,
 )
 from pandas.compat import (
-    HAS_PYARROW,
-    pa_version_under12p1,
+    pa_version_under10p1,
+    pa_version_under11p0,
     pa_version_under13p0,
 )
 from pandas.util._decorators import doc
-from pandas.util._exceptions import find_stack_level

 from pandas.core.dtypes.cast import (
     can_hold_element,
@@ -65,7 +63,6 @@
 from pandas.core.arrays.masked import BaseMaskedArray
 from pandas.core.arrays.string_ import StringDtype
 import pandas.core.common as com
-from pandas.core.construction import extract_array
 from pandas.core.indexers import (
     check_array_indexer,
     unpack_tuple_and_ellipses,
@@ -77,7 +74,7 @@
 from pandas.io._util import _arrow_dtype_mapping
 from pandas.tseries.frequencies import to_offset

-if HAS_PYARROW:
+if not pa_version_under10p1:
     import pyarrow as pa
     import pyarrow.compute as pc

@@ -211,6 +208,16 @@ def floordiv_compat(
     from pandas.core.arrays.timedeltas import TimedeltaArray


+def get_unit_from_pa_dtype(pa_dtype) -> str:
+    # https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804
+    if pa_version_under11p0:
+        unit = str(pa_dtype).split("[", 1)[-1][:-1]
+        if unit not in ["s", "ms", "us", "ns"]:
+            raise ValueError(pa_dtype)
+        return unit
+    return pa_dtype.unit
+
+
 def to_pyarrow_type(
     dtype: ArrowDtype | pa.DataType | Dtype | None,
 ) -> pa.DataType | None:
@@ -293,7 +300,7 @@ class ArrowExtensionArray(
     _dtype: ArrowDtype

     def __init__(self, values: pa.Array | pa.ChunkedArray) -> None:
-        if pa_version_under12p1:
+        if pa_version_under10p1:
             msg = "pyarrow>=10.0.1 is required for PyArrow backed ArrowExtensionArray."
             raise ImportError(msg)
         if isinstance(values, pa.Array):
@@ -503,33 +510,6 @@ def _box_pa_array(
                 value = to_timedelta(value, unit=pa_type.unit).as_unit(pa_type.unit)
                 value = value.to_numpy()

-        if pa_type is not None and pa.types.is_timestamp(pa_type):
-            # Use DatetimeArray to exclude Decimal(NaN) (GH#61774) and
-            # ensure constructor treats tznaive the same as non-pyarrow
-            # dtypes (GH#61775)
-            from pandas.core.arrays.datetimes import (
-                DatetimeArray,
-                tz_to_dtype,
-            )
-
-            pass_dtype = tz_to_dtype(tz=pa_type.tz, unit=pa_type.unit)
-            value = extract_array(value, extract_numpy=True)
-            if isinstance(value, DatetimeArray):
-                dta = value
-            else:
-                dta = DatetimeArray._from_sequence(
-                    value, copy=copy, dtype=pass_dtype
-                )
-            dta_mask = dta.isna()
-            value_i8 = cast("npt.NDArray", dta.view("i8"))
-            if not value_i8.flags["WRITEABLE"]:
-                # e.g. test_setitem_frame_2d_values
-                value_i8 = value_i8.copy()
-                dta = DatetimeArray._from_sequence(value_i8, dtype=dta.dtype)
-            value_i8[dta_mask] = 0  # GH#61776 avoid __sub__ overflow
-            pa_array = pa.array(dta._ndarray, type=pa_type, mask=dta_mask)
-            return pa_array
-
         try:
             pa_array = pa.array(value, type=pa_type, from_pandas=True)
         except (pa.ArrowInvalid, pa.ArrowTypeError):
@@ -854,25 +834,6 @@ def _logical_method(self, other, op) -> Self:
         # integer types. Otherwise these are boolean ops.
         if pa.types.is_integer(self._pa_array.type):
             return self._evaluate_op_method(other, op, ARROW_BIT_WISE_FUNCS)
-        elif (
-            (
-                pa.types.is_string(self._pa_array.type)
-                or pa.types.is_large_string(self._pa_array.type)
-            )
-            and op in (roperator.ror_, roperator.rand_, roperator.rxor)
-            and isinstance(other, np.ndarray)
-            and other.dtype == bool
-        ):
-            # GH#60234 backward compatibility for the move to StringDtype in 3.0
-            op_name = op.__name__[1:].strip("_")
-            warnings.warn(
-                f"'{op_name}' operations between boolean dtype and {self.dtype} are "
-                "deprecated and will raise in a future version. Explicitly "
-                "cast the strings to a boolean dtype before operating instead.",
-                FutureWarning,
-                stacklevel=find_stack_level(),
-            )
-            return op(other, self.astype(bool))
         else:
             return self._evaluate_op_method(other, op, ARROW_LOGICAL_FUNCS)

@@ -1238,6 +1199,10 @@ def factorize(
         null_encoding = "mask" if use_na_sentinel else "encode"

         data = self._pa_array
+        pa_type = data.type
+        if pa_version_under11p0 and pa.types.is_duration(pa_type):
+            # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
+            data = data.cast(pa.int64())

         if pa.types.is_dictionary(data.type):
             if null_encoding == "encode":
@@ -1262,6 +1227,8 @@ def factorize(
         )
         uniques = type(self)(combined.dictionary)

+        if pa_version_under11p0 and pa.types.is_duration(pa_type):
+            uniques = cast(ArrowExtensionArray, uniques.astype(self.dtype))
         return indices, uniques

     def reshape(self, *args, **kwargs):
@@ -1548,7 +1515,19 @@ def unique(self) -> Self:
         -------
         ArrowExtensionArray
         """
-        pa_result = pc.unique(self._pa_array)
+        pa_type = self._pa_array.type
+
+        if pa_version_under11p0 and pa.types.is_duration(pa_type):
+            # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
+            data = self._pa_array.cast(pa.int64())
+        else:
+            data = self._pa_array
+
+        pa_result = pc.unique(data)
+
+        if pa_version_under11p0 and pa.types.is_duration(pa_type):
+            pa_result = pa_result.cast(pa_type)
+
         return type(self)(pa_result)

     def value_counts(self, dropna: bool = True) -> Series:
@@ -1568,12 +1547,18 @@ def value_counts(self, dropna: bool = True) -> Series:
         --------
         Series.value_counts
         """
+        pa_type = self._pa_array.type
+        if pa_version_under11p0 and pa.types.is_duration(pa_type):
+            # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
+            data = self._pa_array.cast(pa.int64())
+        else:
+            data = self._pa_array
+
         from pandas import (
             Index,
             Series,
         )

-        data = self._pa_array
         vc = data.value_counts()

         values = vc.field(0)
@@ -1583,6 +1568,9 @@
             values = values.filter(mask)
             counts = counts.filter(mask)

+        if pa_version_under11p0 and pa.types.is_duration(pa_type):
+            values = values.cast(pa_type)
+
         counts = ArrowExtensionArray(counts)

         index = Index(type(self)(values))
@@ -1876,7 +1864,8 @@ def pyarrow_meth(data, skip_nulls, min_count=0): # type: ignore[misc]
         if pa.types.is_duration(pa_type):
             result = result.cast(pa_type)
         elif pa.types.is_time(pa_type):
-            result = result.cast(pa.duration(pa_type.unit))
+            unit = get_unit_from_pa_dtype(pa_type)
+            result = result.cast(pa.duration(unit))
         elif pa.types.is_date(pa_type):
             # go with closest available unit, i.e. "s"
             result = result.cast(pa.duration("s"))
@@ -1957,10 +1946,8 @@ def _explode(self):
         fill_value = pa.scalar([None], type=self._pa_array.type)
         mask = counts == 0
         if mask.any():
-            # pc.if_else here is similar to `values[mask] = fill_value`
-            # but this avoids an object-dtype round-trip.
-            pa_values = pc.if_else(~mask, values._pa_array, fill_value)
-            values = type(self)(pa_values)
+            values = values.copy()
+            values[mask] = fill_value
             counts = counts.copy()
             counts[mask] = 1
         values = values.fillna(fill_value)
@@ -2969,14 +2956,6 @@ def _dt_tz_convert(self, tz) -> Self:
         result = self._pa_array.cast(pa.timestamp(current_unit, tz))
         return type(self)(result)

-    def max(self, *, skipna: bool = True, axis: int | None = 0, **kwargs):
-        """Return the maximum value of the array."""
-        return self._reduce("max", skipna=skipna, **kwargs)
-
-    def min(self, *, skipna: bool = True, axis: int | None = 0, **kwargs):
-        """Return the minimum value of the array."""
-        return self._reduce("min", skipna=skipna, **kwargs)
-


 def transpose_homogeneous_pyarrow(
     arrays: Sequence[ArrowExtensionArray],

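The get_unit_from_pa_dtype helper restored above matters on pyarrow < 11, where it parses the unit out of the dtype's string form instead of reading pa_dtype.unit; the reduction path then uses it when casting time results to durations. A small illustration of that parsing, assuming a recent pyarrow is installed (the variable names here are illustrative):

import pyarrow as pa

# The pre-pyarrow-11 branch slices the unit out of the dtype's repr,
# e.g. "time64[us]" -> "us"; newer pyarrow exposes it as .unit directly.
pa_dtype = pa.time64("us")
assert str(pa_dtype) == "time64[us]"
assert str(pa_dtype).split("[", 1)[-1][:-1] == "us"
assert pa_dtype.unit == "us"
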
pandas/core/indexing.py

Lines changed: 1 addition & 1 deletion
@@ -1612,7 +1612,7 @@ def _validate_key(self, key, axis: AxisInt) -> None:

             if len(arr):
                 # convert to numpy array for min/max with ExtensionArrays
-                if hasattr(arr, 'to_numpy'):
+                if hasattr(arr, "to_numpy"):
                     np_arr = arr.to_numpy()
                 else:
                     np_arr = np.asarray(arr)
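
The branch touched here (only the quote style changes) converts ExtensionArray-backed positional keys to numpy before taking min/max, which is presumably why ArrowExtensionArray no longer needs its own wrappers. A standalone sketch of that fallback, assuming pyarrow is installed; indexer_bounds is an illustrative name, not a pandas function:

import numpy as np
import pandas as pd

def indexer_bounds(arr):
    # Same pattern as the diff above: prefer to_numpy() when the key is an
    # ExtensionArray (e.g. ArrowExtensionArray), otherwise coerce via asarray.
    if hasattr(arr, "to_numpy"):
        np_arr = arr.to_numpy()
    else:
        np_arr = np.asarray(arr)
    return np_arr.min(), np_arr.max()

key = pd.array([0, 2], dtype="int64[pyarrow]")
assert indexer_bounds(key) == (0, 2)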

pandas/tests/arrays/test_array.py

Lines changed: 0 additions & 13 deletions
@@ -530,16 +530,3 @@ def test_array_to_numpy_na():
     result = arr.to_numpy(na_value=True, dtype=bool)
     expected = np.array([True, True])
     tm.assert_numpy_array_equal(result, expected)
-
-
-def test_array_max_min():
-    pytest.importorskip("pyarrow")
-    # GH#61311
-    df = pd.DataFrame({"a": [1, 2], "c": [0, 2], "d": ["c", "a"]})
-    expected = df.iloc[:, df["c"]]
-    df_pyarrow = pd.DataFrame(
-        {"a": [1, 2], "c": [0, 2], "d": ["c", "a"]}
-    ).convert_dtypes(dtype_backend="pyarrow")
-    result = df_pyarrow.iloc[:, df_pyarrow["c"]]
-    expected_pyarrow = expected.convert_dtypes(dtype_backend="pyarrow")
-    tm.assert_frame_equal(result, expected_pyarrow)
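
The deleted test_array_max_min covered GH#61311 (positional column selection keyed by a pyarrow-backed integer column); the same scenario remains covered by test_iloc_arrow_extension_array below. What it asserted, as a quick sketch assuming pyarrow is installed:

import pandas as pd

df = pd.DataFrame({"a": [1, 2], "c": [0, 2], "d": ["c", "a"]})
df_pyarrow = df.convert_dtypes(dtype_backend="pyarrow")

# Selecting columns by position with an arrow-backed integer key should
# match the numpy-backed result.
result = df_pyarrow.iloc[:, df_pyarrow["c"]]
assert list(result.columns) == ["a", "d"]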

pandas/tests/indexing/test_iloc.py

Lines changed: 1 addition & 5 deletions
@@ -1487,11 +1487,7 @@ def test_iloc_arrow_extension_array(self):
         # GH#61311
         pytest.importorskip("pyarrow")

-        df = DataFrame({
-            "a": [1, 2],
-            "c": [0, 2],
-            "d": ["c", "a"]
-        })
+        df = DataFrame({"a": [1, 2], "c": [0, 2], "d": ["c", "a"]})

         df_arrow = df.convert_dtypes(dtype_backend="pyarrow")
         result = df_arrow.iloc[:, df_arrow["c"]]
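
Since the reverted min/max wrappers only delegated to _reduce, reductions on pyarrow-backed data are still expected to flow through the normal Series path. A final sanity check, assuming pyarrow is installed:

import pandas as pd

# Series-level reductions dispatch to ArrowExtensionArray._reduce, so the
# array class does not need dedicated min/max methods.
ser = pd.Series([1, 2, None], dtype="int64[pyarrow]")
assert ser.min() == 1
assert ser.max() == 2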
