Skip to content

Commit 9563f04

Browse files
Experiment for size
1 parent d0d08a9 commit 9563f04

File tree

4 files changed, with 36 additions and 34 deletions.

pandas-stubs/core/frame.pyi

Lines changed: 10 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,7 @@ from re import Pattern
1212
from typing import (
1313
Any,
1414
ClassVar,
15-
Generic,
1615
Literal,
17-
TypeVar,
1816
overload,
1917
)
2018

@@ -27,10 +25,7 @@ from pandas import (
2725
)
2826
from pandas.core.arraylike import OpsMixin
2927
from pandas.core.generic import NDFrame
30-
from pandas.core.groupby.generic import (
31-
DataFrameGroupBy,
32-
SeriesGroupBy,
33-
)
28+
from pandas.core.groupby.generic import DataFrameGroupBy
3429
from pandas.core.groupby.grouper import Grouper
3530
from pandas.core.indexers import BaseIndexer
3631
from pandas.core.indexes.base import Index
@@ -79,7 +74,6 @@ from pandas._typing import (
7974
Axis,
8075
AxisColumn,
8176
AxisIndex,
82-
ByT,
8377
CalculationMethod,
8478
ColspaceArgType,
8579
CompressionOptions,
@@ -235,11 +229,6 @@ class _LocIndexerFrame(_LocIndexer):
235229
value: Scalar | NAType | NaTType | ArrayLike | Series | list | None,
236230
) -> None: ...
237231

238-
_TT = TypeVar("TT", bound=Literal[True, False])
239-
240-
class DataFrameGroupByGen(DataFrameGroupBy[ByT], Generic[ByT, _TT]): ...
241-
class SeriesGroupByGen(SeriesGroupBy, Generic[_TT, ByT]): ...
242-
243232
class DataFrame(NDFrame, OpsMixin):
244233
__hash__: ClassVar[None] # type: ignore[assignment]
245234

@@ -1073,7 +1062,7 @@ class DataFrame(NDFrame, OpsMixin):
10731062
group_keys: _bool = ...,
10741063
observed: _bool | NoDefault = ...,
10751064
dropna: _bool = ...,
1076-
) -> DataFrameGroupByGen[Scalar, Literal[True]]: ...
1065+
) -> DataFrameGroupBy[Scalar, Literal[True]]: ...
10771066
@overload
10781067
def groupby(
10791068
self,
@@ -1085,7 +1074,7 @@ class DataFrame(NDFrame, OpsMixin):
10851074
group_keys: _bool = ...,
10861075
observed: _bool | NoDefault = ...,
10871076
dropna: _bool = ...,
1088-
) -> DataFrameGroupByGen[Scalar, Literal[False]]: ...
1077+
) -> DataFrameGroupBy[Scalar, Literal[False]]: ...
10891078
@overload
10901079
def groupby(
10911080
self,
@@ -1097,7 +1086,7 @@ class DataFrame(NDFrame, OpsMixin):
10971086
group_keys: _bool = ...,
10981087
observed: _bool | NoDefault = ...,
10991088
dropna: _bool = ...,
1100-
) -> DataFrameGroupBy[Timestamp]: ...
1089+
) -> DataFrameGroupBy[Timestamp, bool]: ...
11011090
@overload
11021091
def groupby(
11031092
self,
@@ -1109,7 +1098,7 @@ class DataFrame(NDFrame, OpsMixin):
11091098
group_keys: _bool = ...,
11101099
observed: _bool | NoDefault = ...,
11111100
dropna: _bool = ...,
1112-
) -> DataFrameGroupBy[Timedelta]: ...
1101+
) -> DataFrameGroupBy[Timedelta, bool]: ...
11131102
@overload
11141103
def groupby(
11151104
self,
@@ -1121,7 +1110,7 @@ class DataFrame(NDFrame, OpsMixin):
11211110
group_keys: _bool = ...,
11221111
observed: _bool | NoDefault = ...,
11231112
dropna: _bool = ...,
1124-
) -> DataFrameGroupBy[Period]: ...
1113+
) -> DataFrameGroupBy[Period, bool]: ...
11251114
@overload
11261115
def groupby(
11271116
self,
@@ -1133,7 +1122,7 @@ class DataFrame(NDFrame, OpsMixin):
11331122
group_keys: _bool = ...,
11341123
observed: _bool | NoDefault = ...,
11351124
dropna: _bool = ...,
1136-
) -> DataFrameGroupBy[IntervalT]: ...
1125+
) -> DataFrameGroupBy[IntervalT, bool]: ...
11371126
@overload
11381127
def groupby(
11391128
self,
@@ -1145,7 +1134,7 @@ class DataFrame(NDFrame, OpsMixin):
11451134
group_keys: _bool = ...,
11461135
observed: _bool | NoDefault = ...,
11471136
dropna: _bool = ...,
1148-
) -> DataFrameGroupBy[tuple]: ...
1137+
) -> DataFrameGroupBy[tuple, bool]: ...
11491138
@overload
11501139
def groupby(
11511140
self,
@@ -1157,7 +1146,7 @@ class DataFrame(NDFrame, OpsMixin):
11571146
group_keys: _bool = ...,
11581147
observed: _bool | NoDefault = ...,
11591148
dropna: _bool = ...,
1160-
) -> DataFrameGroupBy[SeriesByT]: ...
1149+
) -> DataFrameGroupBy[SeriesByT, bool]: ...
11611150
@overload
11621151
def groupby(
11631152
self,
@@ -1169,7 +1158,7 @@ class DataFrame(NDFrame, OpsMixin):
11691158
group_keys: _bool = ...,
11701159
observed: _bool | NoDefault = ...,
11711160
dropna: _bool = ...,
1172-
) -> DataFrameGroupBy[Any]: ...
1161+
) -> DataFrameGroupBy[Any, bool]: ...
11731162
def pivot(
11741163
self,
11751164
*,

pandas-stubs/core/groupby/generic.pyi

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ from typing import (
1111
Generic,
1212
Literal,
1313
NamedTuple,
14+
TypeVar,
1415
final,
1516
overload,
1617
)
@@ -182,7 +183,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
182183
self,
183184
) -> Iterator[tuple[ByT, Series[S1]]]: ...
184185

185-
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
186+
_TT = TypeVar("_TT", bound=Literal[True, False])
187+
188+
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
186189
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
187190
@overload # type: ignore[override]
188191
def apply( # type: ignore[overload-overlap]
@@ -236,7 +239,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
236239
@overload
237240
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride]
238241
self, key: Iterable[Hashable] | slice
239-
) -> DataFrameGroupBy[ByT]: ...
242+
) -> DataFrameGroupBy[ByT, bool]: ...
240243
def nunique(self, dropna: bool = ...) -> DataFrame: ...
241244
def idxmax(
242245
self,
@@ -388,3 +391,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
388391
def __iter__( # pyright: ignore[reportIncompatibleMethodOverride]
389392
self,
390393
) -> Iterator[tuple[ByT, DataFrame]]: ...
394+
@overload
395+
def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ...
396+
@overload
397+
def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ...

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,7 @@ from typing import (
1818

1919
import numpy as np
2020
from pandas.core.base import SelectionMixin
21-
from pandas.core.frame import (
22-
DataFrame,
23-
DataFrameGroupByGen,
24-
)
21+
from pandas.core.frame import DataFrame
2522
from pandas.core.groupby import (
2623
generic,
2724
ops,
@@ -56,7 +53,6 @@ from pandas._typing import (
5653
AnyArrayLike,
5754
Axis,
5855
AxisInt,
59-
ByT,
6056
CalculationMethod,
6157
Dtype,
6258
Frequency,
@@ -236,13 +232,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
236232
def sem(
237233
self: GroupBy[DataFrame], ddof: int = ..., numeric_only: bool = ...
238234
) -> DataFrame: ...
239-
@final
240-
@overload
241235
def size(self: GroupBy[Series]) -> Series[int]: ...
242-
@overload
243-
def size(self: DataFrameGroupByGen[ByT, Literal[True]]) -> Series[int]: ... # type: ignore[misc]
244-
@overload
245-
def size(self: DataFrameGroupByGen[ByT, Literal[False]]) -> DataFrame: ... # type: ignore[misc]
246236
@final
247237
def sum(
248238
self,

tests/test_frame.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1043,6 +1043,22 @@ def test_types_groupby_as_index() -> None:
10431043
)
10441044

10451045

1046+
def test_types_groupby_size() -> None:
1047+
"""Test for GH886."""
1048+
data = [
1049+
{"date": "2023-12-01", "val": 12},
1050+
{"date": "2023-12-02", "val": 2},
1051+
{"date": "2023-12-03", "val": 1},
1052+
{"date": "2023-12-03", "val": 10},
1053+
]
1054+
1055+
df = pd.DataFrame(data)
1056+
groupby = df.groupby("date")
1057+
size = groupby.size()
1058+
frame = size.to_frame()
1059+
check(assert_type(frame.reset_index(), pd.DataFrame), pd.DataFrame)
1060+
1061+
10461062
def test_types_groupby() -> None:
10471063
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]})
10481064
df.index.name = "ind"

0 commit comments

Comments (0)