Skip to content

Commit 51b363e

Browse files
committed
ENH(string dtype): Implement cumsum for Python-backed strings
1 parent 6bcd303 commit 51b363e

File tree

6 files changed

+113
-22
lines changed

6 files changed

+113
-22
lines changed

doc/source/whatsnew/v2.3.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Other enhancements
3737
updated to work correctly with NumPy >= 2 (:issue:`57739`)
3838
- :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`)
3939
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
40-
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
40+
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns (:issue:`60633`, :issue:`???`)
4141
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
4242

4343
.. ---------------------------------------------------------------------------

pandas/core/arrays/string_.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,100 @@ def _reduce(
870870

871871
raise TypeError(f"Cannot perform reduction '{name}' with string dtype")
872872

873+
def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArray:
874+
"""
875+
Return an ExtensionArray performing an accumulation operation.
876+
877+
The underlying data type might change.
878+
879+
Parameters
880+
----------
881+
name : str
882+
Name of the function, supported values are:
883+
- cummin
884+
- cummax
885+
- cumsum
886+
- cumprod
887+
skipna : bool, default True
888+
If True, skip NA values.
889+
**kwargs
890+
Additional keyword arguments passed to the accumulation function.
891+
Currently, there is no supported kwarg.
892+
893+
Returns
894+
-------
895+
array
896+
897+
Raises
898+
------
899+
NotImplementedError : subclass does not define accumulations
900+
"""
901+
if is_string_dtype(self):
902+
return self._str_accumulate(name=name, skipna=skipna, **kwargs)
903+
return super()._accumulate(name=name, skipna=skipna, **kwargs)
904+
905+
def _str_accumulate(
906+
self, name: str, *, skipna: bool = True, **kwargs
907+
) -> StringArray:
908+
"""
909+
Accumulate implementation for strings, see `_accumulate` docstring for details.
910+
911+
pyarrow.compute does not implement these methods for strings.
912+
"""
913+
if name == "cumprod":
914+
msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
915+
raise TypeError(msg)
916+
917+
# We may need to strip out trailing NA values
918+
tail: np.array | None = None
919+
na_mask: np.array | None = None
920+
ndarray = self._ndarray
921+
np_func = {
922+
"cumsum": np.cumsum,
923+
"cummin": np.minimum.accumulate,
924+
"cummax": np.maximum.accumulate,
925+
}[name]
926+
927+
from pandas.core import missing
928+
929+
if self._hasna:
930+
na_mask = isna(ndarray)
931+
if np.all(na_mask):
932+
return type(self)(ndarray)
933+
if skipna:
934+
if name == "cumsum":
935+
ndarray = np.where(na_mask, "", ndarray)
936+
else:
937+
# We can retain the running min/max by forward/backward filling.
938+
ndarray = ndarray.copy()
939+
missing.pad_or_backfill_inplace(
940+
ndarray.T,
941+
method="pad",
942+
axis=0,
943+
)
944+
missing.pad_or_backfill_inplace(
945+
ndarray.T,
946+
method="backfill",
947+
axis=0,
948+
)
949+
else:
950+
# When not skipping NA values, the result should be null from
951+
# the first NA value onward.
952+
idx = np.argmax(na_mask)
953+
tail = np.empty(len(ndarray) - idx, dtype="object")
954+
tail[:] = np.nan
955+
ndarray = ndarray[:idx]
956+
957+
np_result = np_func(ndarray)
958+
959+
if tail is not None:
960+
np_result = np.hstack((np_result, tail))
961+
elif na_mask is not None:
962+
np_result = np.where(na_mask, np.nan, np_result)
963+
964+
result = type(self)(np_result)
965+
return result
966+
873967
def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:
874968
if self.dtype.na_value is np.nan and result is libmissing.NA:
875969
# the masked_reductions use pd.NA -> convert to np.nan

pandas/tests/apply/test_str.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytest
66

77
from pandas.compat import (
8-
HAS_PYARROW,
98
WASM,
109
)
1110

@@ -166,13 +165,13 @@ def test_agg_cython_table_transform_series(request, series, func, expected):
166165
# GH21224
167166
# test transforming functions in
168167
# pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
169-
if series.dtype == "string" and func == "cumsum" and not HAS_PYARROW:
170-
request.applymarker(
171-
pytest.mark.xfail(
172-
raises=NotImplementedError,
173-
reason="TODO(infer_string) cumsum not yet implemented for string",
174-
)
175-
)
168+
# if series.dtype == "string" and func == "cumsum" and not HAS_PYARROW:
169+
# request.applymarker(
170+
# pytest.mark.xfail(
171+
# raises=NotImplementedError,
172+
# reason="TODO(infer_string) cumsum not yet implemented for string",
173+
# )
174+
# )
176175
warn = None if isinstance(func, str) else FutureWarning
177176
with tm.assert_produces_warning(warn, match="is currently using Series.*"):
178177
result = series.agg(func)

pandas/tests/extension/test_string.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,7 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
196196

197197
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
198198
assert isinstance(ser.dtype, StorageExtensionDtype)
199-
return ser.dtype.storage == "pyarrow" and op_name in [
200-
"cummin",
201-
"cummax",
202-
"cumsum",
203-
]
199+
return op_name in ["cummin", "cummax", "cumsum"]
204200

205201
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
206202
dtype = cast(StringDtype, tm.get_dtype(obj))

pandas/tests/groupby/test_categorical.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,9 @@ def test_observed(request, using_infer_string, observed):
325325
# gh-8138 (back-compat)
326326
# gh-8869
327327

328-
if using_infer_string and not observed:
329-
# TODO(infer_string) this fails with filling the string column with 0
330-
request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))
328+
# if using_infer_string and not observed:
329+
# # TODO(infer_string) this fails with filling the string column with 0
330+
# request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))
331331

332332
cat1 = Categorical(["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True)
333333
cat2 = Categorical(["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True)
@@ -355,10 +355,12 @@ def test_observed(request, using_infer_string, observed):
355355
)
356356
result = gb.sum()
357357
if not observed:
358+
fill_value = "" if using_infer_string else 0
358359
expected = cartesian_product_for_groupers(
359-
expected, [cat1, cat2], list("AB"), fill_value=0
360+
expected, [cat1, cat2], list("AB"), fill_value=fill_value
360361
)
361-
362+
print(result)
363+
print(expected)
362364
tm.assert_frame_equal(result, expected)
363365

364366

pandas/tests/series/test_cumulative.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,12 @@ def test_cumprod_timedelta(self):
266266
],
267267
)
268268
def test_cum_methods_pyarrow_strings(
269-
self, pyarrow_string_dtype, data, op, skipna, expected_data
269+
self, string_dtype_no_object, data, op, skipna, expected_data
270270
):
271271
# https://github.com/pandas-dev/pandas/pull/60633
272-
ser = pd.Series(data, dtype=pyarrow_string_dtype)
272+
ser = pd.Series(data, dtype=string_dtype_no_object)
273273
method = getattr(ser, op)
274-
expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
274+
expected = pd.Series(expected_data, dtype=string_dtype_no_object)
275275
result = method(skipna=skipna)
276276
tm.assert_series_equal(result, expected)
277277

0 commit comments

Comments
 (0)