Skip to content

Commit 053b7e7

Browse files
GH456 PR Feedback
1 parent 3bba101 commit 053b7e7

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

pandas-stubs/core/groupby/base.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class OutputKey:
1010
label: Hashable
1111
position: int
1212

13-
reduction_kernels: TypeAlias = Literal[
13+
ReductionKernelType: TypeAlias = Literal[
1414
"all",
1515
"any",
1616
"corrwith",
@@ -37,7 +37,7 @@ reduction_kernels: TypeAlias = Literal[
3737
"var",
3838
]
3939

40-
transformation_kernels: TypeAlias = Literal[
40+
TransformationKernelType: TypeAlias = Literal[
4141
"bfill",
4242
"cumcount",
4343
"cummax",
@@ -53,4 +53,4 @@ transformation_kernels: TypeAlias = Literal[
5353
"shift",
5454
]
5555

56-
transform_kernel_allowlist: TypeAlias = reduction_kernels | transformation_kernels
56+
TransformReductionListType: TypeAlias = ReductionKernelType | TransformationKernelType

pandas-stubs/core/groupby/generic.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ from typing import (
1919
from matplotlib.axes import Axes as PlotAxes
2020
import numpy as np
2121
from pandas.core.frame import DataFrame
22-
from pandas.core.groupby.base import transform_kernel_allowlist
22+
from pandas.core.groupby.base import TransformReductionListType
2323
from pandas.core.groupby.groupby import (
2424
GroupBy,
2525
GroupByPlot,
@@ -110,7 +110,7 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
110110
) -> UnknownSeries: ...
111111
@overload
112112
def transform(
113-
self, func: transform_kernel_allowlist, *args, **kwargs
113+
self, func: TransformReductionListType, *args, **kwargs
114114
) -> UnknownSeries: ...
115115
def filter(
116116
self, func: Callable | str, dropna: bool = ..., *args, **kwargs
@@ -256,7 +256,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
256256
) -> DataFrame: ...
257257
@overload
258258
def transform(
259-
self, func: transform_kernel_allowlist, *args, **kwargs
259+
self, func: TransformReductionListType, *args, **kwargs
260260
) -> DataFrame: ...
261261
def filter(
262262
self, func: Callable, dropna: bool = ..., *args, **kwargs

tests/test_series.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,7 @@ def test_types_groupby_agg() -> None:
10811081

10821082
def sum_sr(s: pd.Series[int]) -> int:
10831083
# type of `sum` not well inferred by mypy
1084-
return sum(s)
1084+
return s.sum()
10851085

10861086
check(
10871087
assert_type(s.groupby(level=0).agg(sum_sr), "pd.Series[int]"),
@@ -1133,7 +1133,11 @@ def func(s: pd.Series[int]) -> float:
11331133
return s.astype(float).min()
11341134

11351135
s = pd.Series([1, 2, 3, 4])
1136-
s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min())
1136+
check(
1137+
assert_type(s.groupby([1, 1, 2, 2]).agg(func), "pd.Series[float]"),
1138+
pd.Series,
1139+
np.floating,
1140+
)
11371141
check(
11381142
assert_type(s.groupby(level=0).aggregate(func), "pd.Series[float]"),
11391143
pd.Series,
@@ -1155,7 +1159,7 @@ def func(s: pd.Series[int]) -> float:
11551159

11561160
def sum_sr(s: pd.Series[int]) -> int:
11571161
# type of `sum` not well inferred by mypy
1158-
return sum(s)
1162+
return s.sum()
11591163

11601164
check(
11611165
assert_type(s.groupby(level=0).aggregate(sum_sr), "pd.Series[int]"),

0 commit comments

Comments
 (0)