Skip to content

Commit 07fe195

Browse files
committed
BUG: groupby.groups with NA categories fails
1 parent 44c5613 commit 07fe195

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

pandas/core/groupby/grouper.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,16 @@
1212

1313
import numpy as np
1414

15+
from pandas._libs import (
16+
algos as libalgos,
17+
)
1518
from pandas._libs.tslibs import OutOfBoundsDatetime
1619
from pandas.errors import InvalidIndexError
1720
from pandas.util._decorators import cache_readonly
1821

1922
from pandas.core.dtypes.common import (
23+
ensure_int64,
24+
ensure_platform_int,
2025
is_list_like,
2126
is_scalar,
2227
)
@@ -38,7 +43,10 @@
3843
)
3944
from pandas.core.series import Series
4045

41-
from pandas.io.formats.printing import pprint_thing
46+
from pandas.io.formats.printing import (
47+
PrettyDict,
48+
pprint_thing,
49+
)
4250

4351
if TYPE_CHECKING:
4452
from collections.abc import (
@@ -668,8 +676,16 @@ def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]:
668676
def groups(self) -> dict[Hashable, Index]:
669677
codes, uniques = self._codes_and_uniques
670678
uniques = Index._with_infer(uniques, name=self.name)
671-
cats = Categorical.from_codes(codes, uniques, validate=False)
672-
return self._index.groupby(cats)
679+
680+
r, counts = libalgos.groupsort_indexer(ensure_platform_int(codes), len(uniques))
681+
counts = ensure_int64(counts).cumsum()
682+
_result = (r[start:end] for start, end in zip(counts, counts[1:]))
683+
result = dict(zip(uniques, _result))
684+
685+
# map to the label
686+
result = {k: self._index.take(v) for k, v in result.items()}
687+
688+
return PrettyDict(result)
673689

674690
@property
675691
def observed_grouping(self) -> Grouping:

pandas/tests/groupby/test_categorical.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,23 @@ def test_observed_groups(observed):
506506
tm.assert_dict_equal(result, expected)
507507

508508

509+
def test_groups_na_category(dropna, observed):
510+
# https://github.com/pandas-dev/pandas/issues/61356
511+
df = DataFrame(
512+
{"cat": Categorical(["a", np.nan, "a"], categories=list("adb"))},
513+
index=list("xyz"),
514+
)
515+
g = df.groupby("cat", observed=observed, dropna=dropna)
516+
517+
result = g.groups
518+
expected = {"a": Index(["x", "z"])}
519+
if not dropna:
520+
expected |= {np.nan: Index(["y"])}
521+
if not observed:
522+
expected |= {"b": Index([]), "d": Index([])}
523+
tm.assert_dict_equal(result, expected)
524+
525+
509526
@pytest.mark.parametrize(
510527
"keys, expected_values, expected_index_levels",
511528
[

0 commit comments

Comments
 (0)