Skip to content

Commit f2e6e38

Browse files
Update to the fix
1 parent d4ea91e commit f2e6e38

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ from re import Pattern
1212
from typing import (
1313
Any,
1414
ClassVar,
15+
Generic,
1516
Literal,
17+
TypeVar,
18+
Union,
1619
overload,
1720
)
1821

@@ -77,6 +80,7 @@ from pandas._typing import (
7780
Axis,
7881
AxisColumn,
7982
AxisIndex,
83+
ByT,
8084
CalculationMethod,
8185
ColspaceArgType,
8286
CompressionOptions,
@@ -232,6 +236,14 @@ class _LocIndexerFrame(_LocIndexer):
232236
value: Scalar | NAType | NaTType | ArrayLike | Series | list | None,
233237
) -> None: ...
234238

239+
TT = TypeVar("TT", bound=Union[Literal[True], Literal[False]])
240+
241+
class DataFrameGroupByGen(DataFrameGroupBy[ByT], Generic[ByT, TT]):
242+
pass
243+
244+
class SeriesGroupByGen(SeriesGroupBy, Generic[TT, ByT]):
245+
pass
246+
235247
class DataFrame(NDFrame, OpsMixin):
236248
__hash__: ClassVar[None] # type: ignore[assignment]
237249

@@ -1055,29 +1067,29 @@ class DataFrame(NDFrame, OpsMixin):
10551067
errors: IgnoreRaise = ...,
10561068
) -> None: ...
10571069
@overload
1058-
def groupby( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
1070+
def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload
10591071
self,
10601072
by: Scalar,
10611073
axis: AxisIndex | NoDefault = ...,
10621074
level: IndexLabel | None = ...,
1063-
as_index: Literal[False] = ...,
1075+
as_index: Literal[True] = True,
10641076
sort: _bool = ...,
10651077
group_keys: _bool = ...,
10661078
observed: _bool | NoDefault = ...,
10671079
dropna: _bool = ...,
1068-
) -> DataFrameGroupBy[Scalar]: ...
1080+
) -> DataFrameGroupByGen[Scalar, Literal[True]]: ...
10691081
@overload
10701082
def groupby(
10711083
self,
10721084
by: Scalar,
10731085
axis: AxisIndex | NoDefault = ...,
10741086
level: IndexLabel | None = ...,
1075-
as_index: Literal[True] = True,
1087+
as_index: Literal[False] = ...,
10761088
sort: _bool = ...,
10771089
group_keys: _bool = ...,
10781090
observed: _bool | NoDefault = ...,
10791091
dropna: _bool = ...,
1080-
) -> SeriesGroupBy: ...
1092+
) -> DataFrameGroupByGen[Scalar, Literal[False]]: ...
10811093
@overload
10821094
def groupby(
10831095
self,

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ from typing import (
1818

1919
import numpy as np
2020
from pandas.core.base import SelectionMixin
21-
from pandas.core.frame import DataFrame
21+
from pandas.core.frame import (
22+
DataFrame,
23+
DataFrameGroupByGen,
24+
)
2225
from pandas.core.groupby import (
2326
generic,
2427
ops,
@@ -53,6 +56,7 @@ from pandas._typing import (
5356
AnyArrayLike,
5457
Axis,
5558
AxisInt,
59+
ByT,
5660
CalculationMethod,
5761
Dtype,
5862
Frequency,
@@ -235,8 +239,10 @@ class GroupBy(BaseGroupBy[NDFrameT]):
235239
@final
236240
@overload
237241
def size(self: GroupBy[Series]) -> Series[int]: ...
238-
@overload # return type depends on `as_index` for dataframe groupby
239-
def size(self: GroupBy[DataFrame]) -> DataFrame: ...
242+
@overload
243+
def size(self: DataFrameGroupByGen[ByT, Literal[True]]) -> Series[int]: ... # type: ignore[misc]
244+
@overload
245+
def size(self: DataFrameGroupByGen[ByT, Literal[False]]) -> DataFrame: ... # type: ignore[misc]
240246
@final
241247
def sum(
242248
self,

tests/test_frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ def test_types_groupby() -> None:
10661066

10671067
df1: pd.DataFrame = df.groupby(by="col1").agg("sum")
10681068
df2: pd.DataFrame = df.groupby(level="ind").aggregate("sum")
1069-
df3: pd.Series = df.groupby(by="col1", sort=False, as_index=True).transform(
1069+
df3: pd.DataFrame = df.groupby(by="col1", sort=False, as_index=True).transform(
10701070
lambda x: x.max()
10711071
)
10721072
df4: pd.DataFrame = df.groupby(by=["col1", "col2"]).count()

0 commit comments

Comments
 (0)