Skip to content

Commit 06f5ff3

Browse files
committed
add key_dtype_str parameter to BaseGrouper constructor, correct def result_index_and_ids and def groups
1 parent 5007ce2 commit 06f5ff3

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

pandas/core/groupby/groupby.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,7 @@ def __iter__(self) -> Iterator[tuple[Hashable, NDFrameT]]:
10121012
# mypy: Argument 1 to "len" has incompatible type "Hashable"; expected "Sized"
10131013
if (
10141014
(is_list_like(level) and len(level) == 1) # type: ignore[arg-type]
1015-
or (isinstance(keys, list))
1015+
or (isinstance(keys, list) and len(keys) == 1)
10161016
):
10171017
# GH#42795 - when keys is a list, return tuples even when length is 1
10181018
result = (((key,), group) for key, group in result)

pandas/core/groupby/grouper.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,9 @@ def get_grouper(
780780
elif isinstance(key, ops.BaseGrouper):
781781
return key, frozenset(), obj
782782

783+
key_dtype_str = False
783784
if not isinstance(key, list):
785+
key_dtype_str = True
784786
keys = [key]
785787
match_axis_length = False
786788
else:
@@ -895,7 +897,6 @@ def is_in_obj(gpr) -> bool:
895897
if not isinstance(gpr, Grouping)
896898
else gpr
897899
)
898-
899900
groupings.append(ping)
900901

901902
if len(groupings) == 0 and len(obj):
@@ -904,7 +905,9 @@ def is_in_obj(gpr) -> bool:
904905
groupings.append(Grouping(Index([], dtype="int"), np.array([], dtype=np.intp)))
905906

906907
# create the internals grouper
907-
grouper = ops.BaseGrouper(group_axis, groupings, sort=sort, dropna=dropna)
908+
grouper = ops.BaseGrouper(
909+
group_axis, groupings, sort=sort, dropna=dropna, key_dtype_str=key_dtype_str
910+
)
908911
return grouper, frozenset(exclusions), obj
909912

910913

pandas/core/groupby/ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,13 +584,15 @@ def __init__(
584584
groupings: Sequence[grouper.Grouping],
585585
sort: bool = True,
586586
dropna: bool = True,
587+
key_dtype_str: bool = True,
587588
) -> None:
588589
assert isinstance(axis, Index), axis
589590

590591
self.axis = axis
591592
self._groupings: list[grouper.Grouping] = list(groupings)
592593
self._sort = sort
593594
self.dropna = dropna
595+
self.key_dtype_str = key_dtype_str
594596

595597
@property
596598
def groupings(self) -> list[grouper.Grouping]:
@@ -640,6 +642,7 @@ def _get_splitter(self, data: NDFrame) -> DataSplitter:
640642
@cache_readonly
641643
def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
642644
"""dict {group name -> group indices}"""
645+
643646
if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
644647
# This shows unused categories in indices GH#38642
645648
return self.groupings[0].indices
@@ -703,7 +706,10 @@ def size(self) -> Series:
703706
@cache_readonly
704707
def groups(self) -> dict[Hashable, Index]:
705708
"""dict {group name -> group labels}"""
706-
if len(self.groupings) == 1 and not isinstance(self.shape, tuple):
709+
if len(self.groupings) == 1 and not self.key_dtype_str:
710+
result = self.groupings[0].groups
711+
712+
if self.key_dtype_str and len(self.groupings) == 1:
707713
return self.groupings[0].groups
708714
result_index, ids = self.result_index_and_ids
709715
values = result_index._values
@@ -764,7 +770,7 @@ def result_index_and_ids(self) -> tuple[Index, npt.NDArray[np.intp]]:
764770
if ping._passed_categorical:
765771
levels[k] = level.set_categories(ping._orig_cats)
766772

767-
if len(self.groupings) == 1:
773+
if self.key_dtype_str and len(self.groupings) == 1:
768774
result_index = levels[0]
769775
result_index.name = self.names[0]
770776
ids = ensure_platform_int(self.codes[0])

0 commit comments

Comments
 (0)