Skip to content

Commit 4081c27

Browse files
committed
allow args and kwargs in groupby.apply
1 parent 69cf85e commit 4081c27

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from typing import (
1111
Generic,
1212
Literal,
1313
NamedTuple,
14+
Protocol,
1415
TypeVar,
1516
final,
1617
overload,
@@ -208,26 +209,35 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
208209

209210
_TT = TypeVar("_TT", bound=Literal[True, False])
210211

212+
class DFCallable1(Protocol):
213+
def __call__(self, df: DataFrame, /, *args, **kwargs) -> Scalar | list | dict: ...
214+
215+
class DFCallable2(Protocol):
216+
def __call__(self, df: DataFrame, /, *args, **kwargs) -> DataFrame | Series: ...
217+
218+
class DFCallable3(Protocol):
219+
def __call__(self, df: Iterable, /, *args, **kwargs) -> float: ...
220+
211221
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
212222
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
213223
@overload # type: ignore[override]
214224
def apply(
215225
self,
216-
func: Callable[[DataFrame], Scalar | list | dict],
226+
func: DFCallable1,
217227
*args,
218228
**kwargs,
219229
) -> Series: ...
220230
@overload
221231
def apply(
222232
self,
223-
func: Callable[[DataFrame], Series | DataFrame],
233+
func: DFCallable2,
224234
*args,
225235
**kwargs,
226236
) -> DataFrame: ...
227237
@overload
228238
def apply( # pyright: ignore[reportOverlappingOverload]
229239
self,
230-
func: Callable[[Iterable], float],
240+
func: DFCallable3,
231241
*args,
232242
**kwargs,
233243
) -> DataFrame: ...

tests/test_groupby.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,3 +1102,23 @@ def test_dataframe_value_counts() -> None:
11021102
Series,
11031103
np.int64,
11041104
)
1105+
1106+
1107+
def test_dataframe_apply_kwargs() -> None:
1108+
# GH 1266
1109+
df = DataFrame({"group": ["A", "A", "B", "B", "C"], "value": [10, 15, 10, 25, 30]})
1110+
1111+
def add_constant_to_mean(group: DataFrame, constant: int) -> DataFrame:
1112+
mean_val = group["value"].mean()
1113+
group["adjusted"] = mean_val + constant
1114+
return group
1115+
1116+
check(
1117+
assert_type(
1118+
df.groupby("group", group_keys=False)[["group", "value"]].apply(
1119+
add_constant_to_mean, constant=5
1120+
),
1121+
DataFrame,
1122+
),
1123+
DataFrame,
1124+
)

0 commit comments

Comments
 (0)