Skip to content

Commit d4ea91e

Browse files
GH203 Split groupby with as_index
1 parent fecd8e9 commit d4ea91e

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ from pandas import (
2525
)
2626
from pandas.core.arraylike import OpsMixin
2727
from pandas.core.generic import NDFrame
28-
from pandas.core.groupby.generic import DataFrameGroupBy
28+
from pandas.core.groupby.generic import (
29+
DataFrameGroupBy,
30+
SeriesGroupBy,
31+
)
2932
from pandas.core.groupby.grouper import Grouper
3033
from pandas.core.indexers import BaseIndexer
3134
from pandas.core.indexes.base import Index
@@ -1052,18 +1055,30 @@ class DataFrame(NDFrame, OpsMixin):
10521055
errors: IgnoreRaise = ...,
10531056
) -> None: ...
10541057
@overload
1055-
def groupby(
1058+
def groupby( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
10561059
self,
10571060
by: Scalar,
10581061
axis: AxisIndex | NoDefault = ...,
10591062
level: IndexLabel | None = ...,
1060-
as_index: _bool = ...,
1063+
as_index: Literal[False] = ...,
10611064
sort: _bool = ...,
10621065
group_keys: _bool = ...,
10631066
observed: _bool | NoDefault = ...,
10641067
dropna: _bool = ...,
10651068
) -> DataFrameGroupBy[Scalar]: ...
10661069
@overload
1070+
def groupby(
1071+
self,
1072+
by: Scalar,
1073+
axis: AxisIndex | NoDefault = ...,
1074+
level: IndexLabel | None = ...,
1075+
as_index: Literal[True] = True,
1076+
sort: _bool = ...,
1077+
group_keys: _bool = ...,
1078+
observed: _bool | NoDefault = ...,
1079+
dropna: _bool = ...,
1080+
) -> SeriesGroupBy: ...
1081+
@overload
10671082
def groupby(
10681083
self,
10691084
by: DatetimeIndex,

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
236236
@overload
237237
def size(self: GroupBy[Series]) -> Series[int]: ...
238238
@overload # return type depends on `as_index` for dataframe groupby
239-
def size(self: GroupBy[DataFrame]) -> DataFrame | Series[int]: ...
239+
def size(self: GroupBy[DataFrame]) -> DataFrame: ...
240240
@final
241241
def sum(
242242
self,

tests/test_frame.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,24 @@ def test_types_pivot_table() -> None:
10251025
)
10261026

10271027

1028+
def test_types_groupby_as_index() -> None:
1029+
df = pd.DataFrame({"a": [1, 2, 3]})
1030+
check(
1031+
assert_type(
1032+
df.groupby("a", as_index=False).size(),
1033+
pd.DataFrame,
1034+
),
1035+
pd.DataFrame,
1036+
)
1037+
check(
1038+
assert_type(
1039+
df.groupby("a", as_index=True).size(),
1040+
"pd.Series[int]",
1041+
),
1042+
pd.Series,
1043+
)
1044+
1045+
10281046
def test_types_groupby() -> None:
10291047
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]})
10301048
df.index.name = "ind"
@@ -1048,7 +1066,7 @@ def test_types_groupby() -> None:
10481066

10491067
df1: pd.DataFrame = df.groupby(by="col1").agg("sum")
10501068
df2: pd.DataFrame = df.groupby(level="ind").aggregate("sum")
1051-
df3: pd.DataFrame = df.groupby(by="col1", sort=False, as_index=True).transform(
1069+
df3: pd.Series = df.groupby(by="col1", sort=False, as_index=True).transform(
10521070
lambda x: x.max()
10531071
)
10541072
df4: pd.DataFrame = df.groupby(by=["col1", "col2"]).count()

0 commit comments

Comments
 (0)