Skip to content

Commit 593653a

Browse files
String dtype: implemen sum reduction
1 parent 2419343 commit 593653a

File tree

4 files changed

+46
-41
lines changed

4 files changed

+46
-41
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
unpack_tuple_and_ellipses,
6969
validate_indices,
7070
)
71+
from pandas.core.nanops import check_below_min_count
7172
from pandas.core.strings.base import BaseStringArrayMethods
7273

7374
from pandas.io._util import _arrow_dtype_mapping
@@ -1705,6 +1706,36 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
17051706
denominator = pc.sqrt_checked(pc.count(self._pa_array))
17061707
return pc.divide_checked(numerator, denominator)
17071708

1709+
elif name == "sum" and (
1710+
pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type)
1711+
):
1712+
1713+
def pyarrow_meth(data, skip_nulls, min_count=0):
1714+
mask = pc.is_null(data) if data.null_count > 0 else None
1715+
if skip_nulls:
1716+
if min_count > 0 and check_below_min_count(
1717+
(len(data),),
1718+
None if mask is None else mask.to_numpy(),
1719+
min_count,
1720+
):
1721+
return pa.scalar(None, type=data.type)
1722+
if data.null_count > 0:
1723+
# binary_join returns null if there is any null ->
1724+
# have to filter out any nulls
1725+
data = data.filter(pc.invert(mask))
1726+
else:
1727+
if mask is not None or check_below_min_count(
1728+
(len(data),), None, min_count
1729+
):
1730+
return pa.scalar(None, type=data.type)
1731+
1732+
if pa.types.is_large_string(data.type):
1733+
data = data.cast(pa.string())
1734+
data_list = pa.ListArray.from_arrays(
1735+
[0, len(data)], data.combine_chunks()
1736+
)[0]
1737+
return pc.binary_join(data_list, "")
1738+
17081739
else:
17091740
pyarrow_name = {
17101741
"median": "quantile",

pandas/core/arrays/string_.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,9 @@ def _reduce(
809809
return self._from_sequence([result], dtype=self.dtype)
810810
return result
811811

812+
if name == "sum":
813+
return nanops.nansum(self._ndarray, skipna=skipna, **kwargs)
814+
812815
raise TypeError(f"Cannot perform reduction '{name}' with string dtype")
813816

814817
def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:

pandas/tests/frame/test_reductions.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def float_frame_with_na():
226226
class TestDataFrameAnalytics:
227227
# ---------------------------------------------------------------------
228228
# Reductions
229-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
230229
@pytest.mark.parametrize("axis", [0, 1])
231230
@pytest.mark.parametrize(
232231
"opname",
@@ -246,17 +245,11 @@ class TestDataFrameAnalytics:
246245
pytest.param("kurt", marks=td.skip_if_no("scipy")),
247246
],
248247
)
249-
def test_stat_op_api_float_string_frame(
250-
self, float_string_frame, axis, opname, using_infer_string
251-
):
252-
if (
253-
(opname in ("sum", "min", "max") and axis == 0)
254-
or opname
255-
in (
256-
"count",
257-
"nunique",
258-
)
259-
) and not (using_infer_string and opname == "sum"):
248+
def test_stat_op_api_float_string_frame(self, float_string_frame, axis, opname):
249+
if (opname in ("sum", "min", "max") and axis == 0) or opname in (
250+
"count",
251+
"nunique",
252+
):
260253
getattr(float_string_frame, opname)(axis=axis)
261254
else:
262255
if opname in ["var", "std", "sem", "skew", "kurt"]:
@@ -432,7 +425,6 @@ def test_stat_operators_attempt_obj_array(self, method, df, axis):
432425
expected[expected.isna()] = None
433426
tm.assert_series_equal(result, expected)
434427

435-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
436428
@pytest.mark.parametrize("op", ["mean", "std", "var", "skew", "kurt", "sem"])
437429
def test_mixed_ops(self, op):
438430
# GH#16116
@@ -466,9 +458,6 @@ def test_mixed_ops(self, op):
466458
with pytest.raises(TypeError, match=msg):
467459
getattr(df, op)()
468460

469-
@pytest.mark.xfail(
470-
using_string_dtype(), reason="sum doesn't work for arrow strings"
471-
)
472461
def test_reduce_mixed_frame(self):
473462
# GH 6806
474463
df = DataFrame(
@@ -608,7 +597,6 @@ def test_sem(self, datetime_frame):
608597
result = nanops.nansem(arr, axis=0)
609598
assert not (result < 0).any()
610599

611-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
612600
@pytest.mark.parametrize(
613601
"dropna, expected",
614602
[
@@ -630,7 +618,7 @@ def test_sem(self, datetime_frame):
630618
"A": [12],
631619
"B": [10.0],
632620
"C": [np.nan],
633-
"D": np.array([np.nan], dtype=object),
621+
"D": Series([np.nan], dtype="str"),
634622
"E": Categorical([np.nan], categories=["a"]),
635623
"F": DatetimeIndex([pd.NaT], dtype="M8[ns]"),
636624
"G": to_timedelta([pd.NaT]),
@@ -672,7 +660,7 @@ def test_mode_dropna(self, dropna, expected):
672660
"A": [12, 12, 19, 11],
673661
"B": [10, 10, np.nan, 3],
674662
"C": [1, np.nan, np.nan, np.nan],
675-
"D": Series([np.nan, np.nan, "a", np.nan], dtype=object),
663+
"D": Series([np.nan, np.nan, "a", np.nan], dtype="str"),
676664
"E": Categorical([np.nan, np.nan, "a", np.nan]),
677665
"F": DatetimeIndex(["NaT", "2000-01-02", "NaT", "NaT"], dtype="M8[ns]"),
678666
"G": to_timedelta(["1 days", "nan", "nan", "nan"]),
@@ -692,7 +680,6 @@ def test_mode_dropna(self, dropna, expected):
692680
expected = DataFrame(expected)
693681
tm.assert_frame_equal(result, expected)
694682

695-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
696683
def test_mode_sortwarning(self, using_infer_string):
697684
# Check for the warning that is raised when the mode
698685
# results cannot be sorted
@@ -1354,11 +1341,8 @@ def test_any_all_extra(self):
13541341
result = df[["C"]].all(axis=None).item()
13551342
assert result is True
13561343

1357-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
13581344
@pytest.mark.parametrize("axis", [0, 1])
1359-
def test_any_all_object_dtype(
1360-
self, axis, all_boolean_reductions, skipna, using_infer_string
1361-
):
1345+
def test_any_all_object_dtype(self, axis, all_boolean_reductions, skipna):
13621346
# GH#35450
13631347
df = DataFrame(
13641348
data=[
@@ -1368,13 +1352,8 @@ def test_any_all_object_dtype(
13681352
[np.nan, np.nan, "5", np.nan],
13691353
]
13701354
)
1371-
if using_infer_string:
1372-
# na in object is True while in string pyarrow numpy it's false
1373-
val = not axis == 0 and not skipna and all_boolean_reductions == "all"
1374-
else:
1375-
val = True
13761355
result = getattr(df, all_boolean_reductions)(axis=axis, skipna=skipna)
1377-
expected = Series([True, True, val, True])
1356+
expected = Series([True, True, True, True])
13781357
tm.assert_series_equal(result, expected)
13791358

13801359
def test_any_datetime(self):
@@ -1939,7 +1918,6 @@ def test_sum_timedelta64_skipna_false():
19391918
tm.assert_series_equal(result, expected)
19401919

19411920

1942-
@pytest.mark.xfail(using_string_dtype(), reason="sum doesn't work with arrow strings")
19431921
def test_mixed_frame_with_integer_sum():
19441922
# https://github.com/pandas-dev/pandas/issues/34520
19451923
df = DataFrame([["a", 1]], columns=list("ab"))

pandas/tests/series/test_reductions.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,11 @@ def test_validate_stat_keepdims():
166166
np.sum(ser, keepdims=True)
167167

168168

169-
@pytest.mark.xfail(
170-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
171-
)
172-
def test_mean_with_convertible_string_raises(using_infer_string):
169+
def test_mean_with_convertible_string_raises():
173170
# GH#44008
174171
ser = Series(["1", "2"])
175-
if using_infer_string:
176-
msg = "does not support"
177-
with pytest.raises(TypeError, match=msg):
178-
ser.sum()
179-
else:
180-
assert ser.sum() == "12"
172+
assert ser.sum() == "12"
173+
181174
msg = "Could not convert string '12' to numeric|does not support"
182175
with pytest.raises(TypeError, match=msg):
183176
ser.mean()

0 commit comments

Comments
 (0)