Skip to content

Commit 9883723

Browse files
GH203 Create new overload for DatetimeIndex
1 parent f6e7a4b commit 9883723

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,17 +1076,29 @@ class DataFrame(NDFrame, OpsMixin):
10761076
dropna: _bool = ...,
10771077
) -> DataFrameGroupBy[Scalar, Literal[False]]: ...
10781078
@overload
1079+
def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload
1080+
self,
1081+
by: DatetimeIndex,
1082+
axis: AxisIndex | NoDefault = ...,
1083+
level: IndexLabel | None = ...,
1084+
as_index: Literal[True] = True,
1085+
sort: _bool = ...,
1086+
group_keys: _bool = ...,
1087+
observed: _bool | NoDefault = ...,
1088+
dropna: _bool = ...,
1089+
) -> DataFrameGroupBy[Timestamp, Literal[True]]: ...
1090+
@overload
10791091
def groupby(
10801092
self,
10811093
by: DatetimeIndex,
10821094
axis: AxisIndex | NoDefault = ...,
10831095
level: IndexLabel | None = ...,
1084-
as_index: _bool = ...,
1096+
as_index: Literal[False] = ...,
10851097
sort: _bool = ...,
10861098
group_keys: _bool = ...,
10871099
observed: _bool | NoDefault = ...,
10881100
dropna: _bool = ...,
1089-
) -> DataFrameGroupBy[Timestamp, bool]: ...
1101+
) -> DataFrameGroupBy[Timestamp, Literal[False]]: ...
10901102
@overload
10911103
def groupby(
10921104
self,

pandas-stubs/core/groupby/generic.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ from typing_extensions import (
3030
)
3131

3232
from pandas._libs.lib import NoDefault
33+
from pandas._libs.tslibs.timestamps import Timestamp
3334
from pandas._typing import (
3435
S1,
3536
AggFuncTypeBase,
@@ -395,3 +396,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
395396
def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ...
396397
@overload
397398
def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ...
399+
@overload
400+
def size(self: DataFrameGroupBy[Timestamp, Literal[True]]) -> Series[int]: ...
401+
@overload
402+
def size(self: DataFrameGroupBy[Timestamp, Literal[False]]) -> DataFrame: ...

tests/test_frame.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,27 @@ def test_types_groupby_as_index() -> None:
10431043
)
10441044

10451045

1046+
def test_types_groupby_as_index_timestamp() -> None:
1047+
"""Test groupby size with DatetimeIndex."""
1048+
idx = pd.DatetimeIndex(["2023-10-01", "2023-10-02", "2023-10-01"], name="date")
1049+
sub_idx = pd.DatetimeIndex(["2023-10-01", "2023-10-02"], name="date")
1050+
df = pd.DataFrame({"a": [1, 2, 3]}, index=idx)
1051+
check(
1052+
assert_type(
1053+
df.groupby(sub_idx, as_index=False).size(),
1054+
pd.DataFrame,
1055+
),
1056+
pd.DataFrame,
1057+
)
1058+
check(
1059+
assert_type(
1060+
df.groupby(sub_idx, as_index=True).size(),
1061+
"pd.Series[int]",
1062+
),
1063+
pd.Series,
1064+
)
1065+
1066+
10461067
def test_types_groupby_size() -> None:
10471068
"""Test for GH886."""
10481069
data = [

0 commit comments

Comments
 (0)