@@ -7,6 +7,7 @@ from collections.abc import (
77)
88from typing import (
99 Any ,
10+ Concatenate ,
1011 Generic ,
1112 Literal ,
1213 NamedTuple ,
@@ -18,6 +19,7 @@ from typing import (
1819from matplotlib .axes import Axes as PlotAxes
1920import numpy as np
2021from pandas .core .frame import DataFrame
22+ from pandas .core .groupby .base import TransformReductionListType
2123from pandas .core .groupby .groupby import (
2224 GroupBy ,
2325 GroupByPlot ,
@@ -31,6 +33,7 @@ from typing_extensions import (
3133from pandas ._libs .tslibs .timestamps import Timestamp
3234from pandas ._typing import (
3335 S2 ,
36+ S3 ,
3437 AggFuncTypeBase ,
3538 AggFuncTypeFrame ,
3639 ByT ,
@@ -40,6 +43,7 @@ from pandas._typing import (
4043 Level ,
4144 ListLike ,
4245 NsmallestNlargestKeep ,
46+ P ,
4347 Scalar ,
4448 TakeIndexer ,
4549 WindowingEngine ,
@@ -53,10 +57,30 @@ class NamedAgg(NamedTuple):
5357 aggfunc : AggScalar
5458
5559class SeriesGroupBy (GroupBy [Series [S2 ]], Generic [S2 , ByT ]):
60+ @overload
61+ def aggregate (
62+ self ,
63+ func : Callable [Concatenate [Series [S2 ], P ], S3 ],
64+ / ,
65+ * args ,
66+ engine : WindowingEngine = ...,
67+ engine_kwargs : WindowingEngineKwargs = ...,
68+ ** kwargs ,
69+ ) -> Series [S3 ]: ...
70+ @overload
71+ def aggregate (
72+ self ,
73+ func : Callable [[Series ], S3 ],
74+ * args ,
75+ engine : WindowingEngine = ...,
76+ engine_kwargs : WindowingEngineKwargs = ...,
77+ ** kwargs ,
78+ ) -> Series [S3 ]: ...
5679 @overload
5780 def aggregate (
5881 self ,
5982 func : list [AggFuncTypeBase ],
83+ / ,
6084 * args ,
6185 engine : WindowingEngine = ...,
6286 engine_kwargs : WindowingEngineKwargs = ...,
@@ -66,19 +90,33 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
6690 def aggregate (
6791 self ,
6892 func : AggFuncTypeBase | None = ...,
93+ / ,
6994 * args ,
7095 engine : WindowingEngine = ...,
7196 engine_kwargs : WindowingEngineKwargs = ...,
7297 ** kwargs ,
7398 ) -> Series : ...
7499 agg = aggregate
100+ @overload
75101 def transform (
76102 self ,
77- func : Callable | str ,
78- * args ,
103+ func : Callable [Concatenate [Series [S2 ], P ], Series [S3 ]],
104+ / ,
105+ * args : Any ,
79106 engine : WindowingEngine = ...,
80107 engine_kwargs : WindowingEngineKwargs = ...,
81- ** kwargs ,
108+ ** kwargs : Any ,
109+ ) -> Series [S3 ]: ...
110+ @overload
111+ def transform (
112+ self ,
113+ func : Callable ,
114+ * args : Any ,
115+ ** kwargs : Any ,
116+ ) -> Series : ...
117+ @overload
118+ def transform (
119+ self , func : TransformReductionListType , * args , ** kwargs
82120 ) -> Series : ...
83121 def filter (
84122 self , func : Callable | str , dropna : bool = ..., * args , ** kwargs
@@ -206,13 +244,25 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
206244 ** kwargs ,
207245 ) -> DataFrame : ...
208246 agg = aggregate
247+ @overload
209248 def transform (
210249 self ,
211- func : Callable | str ,
212- * args ,
250+ func : Callable [ Concatenate [ DataFrame , P ], DataFrame ] ,
251+ * args : Any ,
213252 engine : WindowingEngine = ...,
214253 engine_kwargs : WindowingEngineKwargs = ...,
215- ** kwargs ,
254+ ** kwargs : Any ,
255+ ) -> DataFrame : ...
256+ @overload
257+ def transform (
258+ self ,
259+ func : Callable ,
260+ * args : Any ,
261+ ** kwargs : Any ,
262+ ) -> DataFrame : ...
263+ @overload
264+ def transform (
265+ self , func : TransformReductionListType , * args , ** kwargs
216266 ) -> DataFrame : ...
217267 def filter (
218268 self , func : Callable , dropna : bool = ..., * args , ** kwargs
0 commit comments