Skip to content

Commit ab30d87

Browse files
fix object-dtype implementation + update tests
1 parent 593653a commit ab30d87

File tree

5 files changed

+24
-29
lines changed

5 files changed

+24
-29
lines changed

pandas/core/array_algos/masked_reductions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def _reductions(
6262
):
6363
return libmissing.NA
6464

65+
if values.dtype == np.dtype(object):
66+
values = values[~mask]
67+
return func(values, axis=axis, **kwargs)
6568
return func(values, where=~mask, axis=axis, **kwargs)
6669

6770

pandas/core/arrays/string_.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -803,15 +803,12 @@ def _reduce(
803803
else:
804804
return nanops.nanall(self._ndarray, skipna=skipna)
805805

806-
if name in ["min", "max"]:
807-
result = getattr(self, name)(skipna=skipna, axis=axis)
806+
if name in ["min", "max", "sum"]:
807+
result = getattr(self, name)(skipna=skipna, axis=axis, **kwargs)
808808
if keepdims:
809809
return self._from_sequence([result], dtype=self.dtype)
810810
return result
811811

812-
if name == "sum":
813-
return nanops.nansum(self._ndarray, skipna=skipna, **kwargs)
814-
815812
raise TypeError(f"Cannot perform reduction '{name}' with string dtype")
816813

817814
def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:
@@ -834,6 +831,20 @@ def max(self, axis=None, skipna: bool = True, **kwargs) -> Scalar:
834831
)
835832
return self._wrap_reduction_result(axis, result)
836833

834+
def sum(
    self,
    *,
    axis: AxisInt | None = None,
    skipna: bool = True,
    min_count: int = 0,
    **kwargs,
) -> Scalar:
    """
    Return the sum of the array's elements via ``masked_reductions.sum``.

    For string data the sum is a concatenation, e.g. summing
    ``["a", "b", "c"]`` gives ``"abc"``.

    Parameters
    ----------
    axis : int or None, default None
        Passed only to ``_wrap_reduction_result``; it is not forwarded to
        the masked reduction itself.
    skipna : bool, default True
        Forwarded to ``masked_reductions.sum`` together with the
        ``isna()`` mask.
    min_count : int, default 0
        Forwarded to ``masked_reductions.sum``.
        NOTE(review): presumably the minimum number of valid values
        required for a non-missing result -- confirm against
        ``masked_reductions``.
    **kwargs
        Extra keyword arguments; only checked by ``nv.validate_sum``.

    Returns
    -------
    Scalar
        The reduction result after ``_wrap_reduction_result``.
    """
    nv.validate_sum((), kwargs)
    result = masked_reductions.sum(
        values=self._ndarray,
        mask=self.isna(),
        skipna=skipna,
        # Previously accepted but dropped; forward it so callers'
        # ``min_count`` is honoured. The default of 0 is unchanged.
        min_count=min_count,
    )
    return self._wrap_reduction_result(axis, result)
847+
837848
def value_counts(self, dropna: bool = True) -> Series:
838849
from pandas.core.algorithms import value_counts_internal as value_counts
839850

pandas/tests/arrays/string_/test_string.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,14 +444,12 @@ def test_astype_float(dtype, any_float_dtype):
444444
tm.assert_series_equal(result, expected)
445445

446446

447-
@pytest.mark.xfail(reason="Not implemented StringArray.sum")
448447
def test_reduce(skipna, dtype):
449448
arr = pd.Series(["a", "b", "c"], dtype=dtype)
450449
result = arr.sum(skipna=skipna)
451450
assert result == "abc"
452451

453452

454-
@pytest.mark.xfail(reason="Not implemented StringArray.sum")
455453
def test_reduce_missing(skipna, dtype):
456454
arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype)
457455
result = arr.sum(skipna=skipna)

pandas/tests/extension/test_arrow.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -480,10 +480,11 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
480480
pass
481481
else:
482482
return False
483+
elif pa.types.is_binary(pa_dtype) and op_name == "sum":
484+
return False
483485
elif (
484486
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
485487
) and op_name in [
486-
"sum",
487488
"mean",
488489
"median",
489490
"prod",
@@ -582,6 +583,8 @@ def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
582583
cmp_dtype = "float64[pyarrow]"
583584
elif op_name in ["sum", "prod"] and pa.types.is_boolean(pa_type):
584585
cmp_dtype = "uint64[pyarrow]"
586+
elif op_name == "sum" and pa.types.is_string(pa_type):
587+
cmp_dtype = arr.dtype
585588
else:
586589
cmp_dtype = {
587590
"i": "int64[pyarrow]",
@@ -613,26 +616,6 @@ def test_median_not_approximate(self, typ):
613616
result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median()
614617
assert result == 1.5
615618

616-
def test_in_numeric_groupby(self, data_for_grouping):
617-
dtype = data_for_grouping.dtype
618-
if is_string_dtype(dtype):
619-
df = pd.DataFrame(
620-
{
621-
"A": [1, 1, 2, 2, 3, 3, 1, 4],
622-
"B": data_for_grouping,
623-
"C": [1, 1, 1, 1, 1, 1, 1, 1],
624-
}
625-
)
626-
627-
expected = pd.Index(["C"])
628-
msg = re.escape(f"agg function failed [how->sum,dtype->{dtype}")
629-
with pytest.raises(TypeError, match=msg):
630-
df.groupby("A").sum()
631-
result = df.groupby("A").sum(numeric_only=True).columns
632-
tm.assert_index_equal(result, expected)
633-
else:
634-
super().test_in_numeric_groupby(data_for_grouping)
635-
636619
def test_construct_from_string_own_name(self, dtype, request):
637620
pa_dtype = dtype.pyarrow_dtype
638621
if pa.types.is_decimal(pa_dtype):

pandas/tests/extension/test_string.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _get_expected_exception(
188188

189189
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
190190
return (
191-
op_name in ["min", "max"]
191+
op_name in ["min", "max", "sum"]
192192
or ser.dtype.na_value is np.nan # type: ignore[union-attr]
193193
and op_name in ("any", "all")
194194
)

0 commit comments

Comments
 (0)