Skip to content

Commit 7f6206c

Browse files
BUG: fix fill value for grouped sum in case of unobserved categories for string dtype (empty string instead of 0) (pandas-dev#61909)
1 parent 1aa7f75 commit 7f6206c

File tree

6 files changed

+40
-16
lines changed

6 files changed

+40
-16
lines changed

pandas/_libs/groupby.pyi

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@ def group_sum(
6565
result_mask: np.ndarray | None = ...,
6666
min_count: int = ...,
6767
is_datetimelike: bool = ...,
68+
initial: object = ...,
6869
) -> None: ...
6970
def group_prod(
7071
out: np.ndarray, # int64float_t[:, ::1]

pandas/_libs/groupby.pyx

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -672,6 +672,7 @@ def group_sum(
672672
uint8_t[:, ::1] result_mask=None,
673673
Py_ssize_t min_count=0,
674674
bint is_datetimelike=False,
675+
object initial=0,
675676
) -> None:
676677
"""
677678
Only aggregates on axis=0 using Kahan summation
@@ -689,9 +690,15 @@ def group_sum(
689690
raise ValueError("len(index) != len(labels)")
690691

691692
nobs = np.zeros((<object>out).shape, dtype=np.int64)
692-
# the below is equivalent to `np.zeros_like(out)` but faster
693-
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
694-
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
693+
if initial == 0:
694+
# the below is equivalent to `np.zeros_like(out)` but faster
695+
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
696+
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
697+
else:
698+
# in practice this path is only taken for strings to use empty string as initial
699+
assert sum_t is object
700+
sumx = np.full((<object>out).shape, initial, dtype=object)
701+
# object code path does not use `compensation`
695702

696703
N, K = (<object>values).shape
697704

pandas/core/arrays/base.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2366,6 +2366,7 @@ def _groupby_op(
23662366
kind = WrappedCythonOp.get_kind_from_how(how)
23672367
op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)
23682368

2369+
initial: Any = 0
23692370
# GH#43682
23702371
if isinstance(self.dtype, StringDtype):
23712372
# StringArray
@@ -2389,6 +2390,7 @@ def _groupby_op(
23892390

23902391
arr = self
23912392
if op.how == "sum":
2393+
initial = ""
23922394
# https://github.com/pandas-dev/pandas/issues/60229
23932395
# All NA should result in the empty string.
23942396
if min_count == 0:
@@ -2405,6 +2407,7 @@ def _groupby_op(
24052407
ngroups=ngroups,
24062408
comp_ids=ids,
24072409
mask=None,
2410+
initial=initial,
24082411
**kwargs,
24092412
)
24102413

pandas/core/groupby/ops.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import functools
1212
from typing import (
1313
TYPE_CHECKING,
14+
Any,
1415
Callable,
1516
Generic,
1617
final,
@@ -317,6 +318,7 @@ def _cython_op_ndim_compat(
317318
comp_ids: np.ndarray,
318319
mask: npt.NDArray[np.bool_] | None = None,
319320
result_mask: npt.NDArray[np.bool_] | None = None,
321+
initial: Any = 0,
320322
**kwargs,
321323
) -> np.ndarray:
322324
if values.ndim == 1:
@@ -333,6 +335,7 @@ def _cython_op_ndim_compat(
333335
comp_ids=comp_ids,
334336
mask=mask,
335337
result_mask=result_mask,
338+
initial=initial,
336339
**kwargs,
337340
)
338341
if res.shape[0] == 1:
@@ -348,6 +351,7 @@ def _cython_op_ndim_compat(
348351
comp_ids=comp_ids,
349352
mask=mask,
350353
result_mask=result_mask,
354+
initial=initial,
351355
**kwargs,
352356
)
353357

@@ -361,6 +365,7 @@ def _call_cython_op(
361365
comp_ids: np.ndarray,
362366
mask: npt.NDArray[np.bool_] | None,
363367
result_mask: npt.NDArray[np.bool_] | None,
368+
initial: Any = 0,
364369
**kwargs,
365370
) -> np.ndarray: # np.ndarray[ndim=2]
366371
orig_values = values
@@ -415,6 +420,10 @@ def _call_cython_op(
415420
"first",
416421
"sum",
417422
]:
423+
if self.how == "sum":
424+
# pass in through kwargs only for sum (other functions don't have
425+
# the keyword)
426+
kwargs["initial"] = initial
418427
func(
419428
out=result,
420429
counts=counts,

pandas/tests/groupby/test_categorical.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,14 @@ def f(a):
3232
return a
3333

3434
index = MultiIndex.from_product(map(f, args), names=names)
35+
if isinstance(fill_value, dict):
36+
# fill_value is a dict mapping column names to fill values
37+
# -> reindex column by column (reindex itself does not support this)
38+
res = {}
39+
for col in result.columns:
40+
res[col] = result[col].reindex(index, fill_value=fill_value[col])
41+
return DataFrame(res, index=index).sort_index()
42+
3543
return result.reindex(index, fill_value=fill_value).sort_index()
3644

3745

@@ -340,18 +348,14 @@ def test_apply(ordered):
340348

341349

342350
@pytest.mark.filterwarnings("ignore:invalid value encountered in cast:RuntimeWarning")
343-
def test_observed(request, using_infer_string, observed):
351+
def test_observed(observed, using_infer_string):
344352
# multiple groupers, don't re-expand the output space
345353
# of the grouper
346354
# gh-14942 (implement)
347355
# gh-10132 (back-compat)
348356
# gh-8138 (back-compat)
349357
# gh-8869
350358

351-
if using_infer_string and not observed:
352-
# TODO(infer_string) this fails with filling the string column with 0
353-
request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))
354-
355359
cat1 = Categorical(["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True)
356360
cat2 = Categorical(["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True)
357361
df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]})
@@ -379,7 +383,10 @@ def test_observed(request, using_infer_string, observed):
379383
result = gb.sum()
380384
if not observed:
381385
expected = cartesian_product_for_groupers(
382-
expected, [cat1, cat2], list("AB"), fill_value=0
386+
expected,
387+
[cat1, cat2],
388+
list("AB"),
389+
fill_value={"values": 0, "C": ""} if using_infer_string else 0,
383390
)
384391

385392
tm.assert_frame_equal(result, expected)

pandas/tests/groupby/test_timegrouper.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,6 @@
1010
import pytest
1111
import pytz
1212

13-
from pandas._config import using_string_dtype
14-
1513
import pandas as pd
1614
from pandas import (
1715
DataFrame,
@@ -75,10 +73,7 @@ def groupby_with_truncated_bingrouper(frame_for_truncated_bingrouper):
7573

7674

7775
class TestGroupBy:
78-
# TODO(infer_string) resample sum introduces 0's
79-
# https://github.com/pandas-dev/pandas/issues/60229
80-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
81-
def test_groupby_with_timegrouper(self):
76+
def test_groupby_with_timegrouper(self, using_infer_string):
8277
# GH 4161
8378
# TimeGrouper requires a sorted index
8479
# also verifies that the resultant index has the correct name
@@ -112,11 +107,13 @@ def test_groupby_with_timegrouper(self):
112107
unit=df.index.unit,
113108
)
114109
expected = DataFrame(
115-
{"Buyer": 0, "Quantity": 0},
110+
{"Buyer": "" if using_infer_string else 0, "Quantity": 0},
116111
index=exp_dti,
117112
)
118113
# Cast to object to avoid implicit cast when setting entry to "CarlCarlCarl"
119114
expected = expected.astype({"Buyer": object})
115+
if using_infer_string:
116+
expected = expected.astype({"Buyer": "str"})
120117
expected.iloc[0, 0] = "CarlCarlCarl"
121118
expected.iloc[6, 0] = "CarlCarl"
122119
expected.iloc[18, 0] = "Joe"

0 commit comments

Comments
 (0)