Skip to content

Commit 2b69e46

Browse files
committed
move key_dtype_str from BaseGrouper to Grouping
1 parent 06f5ff3 commit 2b69e46

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

pandas/core/groupby/grouper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ def __init__(
440440
in_axis: bool = False,
441441
dropna: bool = True,
442442
uniques: ArrayLike | None = None,
443+
key_dtype_str: bool = False,
443444
) -> None:
444445
self.level = level
445446
self._orig_grouper = grouper
@@ -452,6 +453,7 @@ def __init__(
452453
self.in_axis = in_axis
453454
self._dropna = dropna
454455
self._uniques = uniques
456+
self.key_dtype_str = key_dtype_str
455457

456458
# we have a single grouper which may be a myriad of things,
457459
# some of which are dependent on the passing in level
@@ -666,6 +668,8 @@ def groups(self) -> dict[Hashable, Index]:
666668
codes, uniques = self._codes_and_uniques
667669
uniques = Index._with_infer(uniques, name=self.name)
668670
cats = Categorical.from_codes(codes, uniques, validate=False)
671+
if not self.key_dtype_str:
672+
cats = [(i,) for i in cats]
669673
return self._index.groupby(cats)
670674

671675
@property
@@ -893,6 +897,7 @@ def is_in_obj(gpr) -> bool:
893897
observed=observed,
894898
in_axis=in_axis,
895899
dropna=dropna,
900+
key_dtype_str=key_dtype_str,
896901
)
897902
if not isinstance(gpr, Grouping)
898903
else gpr
@@ -905,9 +910,7 @@ def is_in_obj(gpr) -> bool:
905910
groupings.append(Grouping(Index([], dtype="int"), np.array([], dtype=np.intp)))
906911

907912
# create the internals grouper
908-
grouper = ops.BaseGrouper(
909-
group_axis, groupings, sort=sort, dropna=dropna, key_dtype_str=key_dtype_str
910-
)
913+
grouper = ops.BaseGrouper(group_axis, groupings, sort=sort, dropna=dropna)
911914
return grouper, frozenset(exclusions), obj
912915

913916

pandas/core/groupby/ops.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -584,15 +584,13 @@ def __init__(
584584
groupings: Sequence[grouper.Grouping],
585585
sort: bool = True,
586586
dropna: bool = True,
587-
key_dtype_str: bool = True,
588587
) -> None:
589588
assert isinstance(axis, Index), axis
590589

591590
self.axis = axis
592591
self._groupings: list[grouper.Grouping] = list(groupings)
593592
self._sort = sort
594593
self.dropna = dropna
595-
self.key_dtype_str = key_dtype_str
596594

597595
@property
598596
def groupings(self) -> list[grouper.Grouping]:
@@ -706,10 +704,7 @@ def size(self) -> Series:
706704
@cache_readonly
707705
def groups(self) -> dict[Hashable, Index]:
708706
"""dict {group name -> group labels}"""
709-
if len(self.groupings) == 1 and not self.key_dtype_str:
710-
result = self.groupings[0].groups
711-
712-
if self.key_dtype_str and len(self.groupings) == 1:
707+
if len(self.groupings) == 1:
713708
return self.groupings[0].groups
714709
result_index, ids = self.result_index_and_ids
715710
values = result_index._values
@@ -770,7 +765,7 @@ def result_index_and_ids(self) -> tuple[Index, npt.NDArray[np.intp]]:
770765
if ping._passed_categorical:
771766
levels[k] = level.set_categories(ping._orig_cats)
772767

773-
if self.key_dtype_str and len(self.groupings) == 1:
768+
if len(self.groupings) == 1:
774769
result_index = levels[0]
775770
result_index.name = self.names[0]
776771
ids = ensure_platform_int(self.codes[0])

0 commit comments

Comments
 (0)