diff --git a/tests/test_groupby.py b/tests/test_groupby.py index 95c7cdefa..f69aa3979 100644 --- a/tests/test_groupby.py +++ b/tests/test_groupby.py @@ -31,9 +31,12 @@ RollingGroupby, ) +from pandas.errors import Pandas4Warning + from tests import ( TYPE_CHECKING_INVALID_USAGE, check, + pytest_warns_bounded, ) if TYPE_CHECKING: @@ -925,3 +928,171 @@ def test_frame_groupby_aggregate() -> None: check(assert_type(df.groupby("b").agg(a=("a", "mean")), DataFrame), DataFrame) check(assert_type(df.groupby("b").agg(**dico), DataFrame), DataFrame) + + +def test_frame_groupby_transform_reduction_kernels() -> None: + """Test DataFrameGroupBy.transform with ReductionKernelType literals.""" + check(assert_type(GB_DF.transform("all"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("any"), DataFrame), DataFrame) + with pytest_warns_bounded(Pandas4Warning, "corrwith is deprecated", lower="2.99"): + check(assert_type(GB_DF.transform("corrwith", other=DF), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("count"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("first"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("idxmax"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("idxmin"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("last"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("max"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("mean"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("median"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("min"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("nunique"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("prod"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("quantile"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("sem"), DataFrame), DataFrame) + # TODO: pandas-dev/pandas-stubs#1671, size, cumcount, ngroup return Series at runtime on DataFrameGroupBy + check(assert_type(GB_DF.transform("skew"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("std"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("sum"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("var"), DataFrame), DataFrame) + + +def test_frame_groupby_transform_transformation_kernels() -> None: + """Test DataFrameGroupBy.transform with TransformationKernelType literals.""" + check(assert_type(GB_DF.transform("bfill"), DataFrame), DataFrame) + # TODO: pandas-dev/pandas-stubs#1671, cumcount and ngroup return Series at runtime on DataFrameGroupBy + check(assert_type(GB_DF.transform("cummax"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("cummin"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("cumprod"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("cumsum"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("diff"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("ffill"), DataFrame), DataFrame) + # TODO: pandas-dev/pandas-stubs#1671, fillna is not a valid function name for transform(name) at runtime + check(assert_type(GB_DF.transform("pct_change"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("rank"), DataFrame), DataFrame) + check(assert_type(GB_DF.transform("shift"), DataFrame), DataFrame) + + +def test_series_groupby_transform_reduction_kernels() -> None: + """Test SeriesGroupBy.transform with ReductionKernelType literals.""" + check(assert_type(GB_S.transform("all"), Series), Series) + check(assert_type(GB_S.transform("any"), Series), Series) + # TODO: pandas-dev/pandas-stubs#1671, corrwith does not exist on SeriesGroupBy + check(assert_type(GB_S.transform("count"), Series), Series) + check(assert_type(GB_S.transform("first"), Series), Series) + check(assert_type(GB_S.transform("idxmax"), Series), Series) + check(assert_type(GB_S.transform("idxmin"), Series), Series) + check(assert_type(GB_S.transform("last"), Series), Series) + check(assert_type(GB_S.transform("max"), Series), Series) + check(assert_type(GB_S.transform("mean"), Series), Series) + check(assert_type(GB_S.transform("median"), Series), Series) + check(assert_type(GB_S.transform("min"), Series), Series) + check(assert_type(GB_S.transform("nunique"), Series), Series) + check(assert_type(GB_S.transform("prod"), Series), Series) + check(assert_type(GB_S.transform("quantile"), Series), Series) + check(assert_type(GB_S.transform("sem"), Series), Series) + check(assert_type(GB_S.transform("size"), Series), Series) + check(assert_type(GB_S.transform("skew"), Series), Series) + check(assert_type(GB_S.transform("std"), Series), Series) + check(assert_type(GB_S.transform("sum"), Series), Series) + check(assert_type(GB_S.transform("var"), Series), Series) + + +def test_series_groupby_transform_transformation_kernels() -> None: + """Test SeriesGroupBy.transform with TransformationKernelType literals.""" + check(assert_type(GB_S.transform("bfill"), Series), Series) + check(assert_type(GB_S.transform("cumcount"), Series), Series) + check(assert_type(GB_S.transform("cummax"), Series), Series) + check(assert_type(GB_S.transform("cummin"), Series), Series) + check(assert_type(GB_S.transform("cumprod"), Series), Series) + check(assert_type(GB_S.transform("cumsum"), Series), Series) + check(assert_type(GB_S.transform("diff"), Series), Series) + check(assert_type(GB_S.transform("ffill"), Series), Series) + # TODO: pandas-dev/pandas-stubs#1671, fillna is not a valid function name for transform(name) at runtime + check(assert_type(GB_S.transform("ngroup"), Series), Series) + check(assert_type(GB_S.transform("pct_change"), Series), Series) + check(assert_type(GB_S.transform("rank"), Series), Series) + check(assert_type(GB_S.transform("shift"), Series), Series) + + +def test_frame_groupby_agg_reduction_kernels() -> None: + """Test DataFrameGroupBy.agg with ReductionKernelType literals.""" + check(assert_type(GB_DF.agg("all"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("any"), DataFrame), DataFrame) + with pytest_warns_bounded(Pandas4Warning, "corrwith is deprecated", lower="2.99"): + check(assert_type(GB_DF.agg("corrwith", other=DF), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("count"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("first"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("idxmax"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("idxmin"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("last"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("max"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("mean"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("median"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("min"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("nunique"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("prod"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("quantile"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("sem"), DataFrame), DataFrame) + check(assert_type(GB_DF.aggregate("size"), Series), Series) + check(assert_type(GB_DF.agg("skew"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("std"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("sum"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("var"), DataFrame), DataFrame) + + +def test_frame_groupby_agg_transformation_kernels() -> None: + """Test DataFrameGroupBy.agg with TransformationKernelType literals.""" + check(assert_type(GB_DF.agg("bfill"), DataFrame), DataFrame) + # TODO: pandas-dev/pandas-stubs#1671, cumcount and ngroup return Series at runtime on DataFrameGroupBy + check(assert_type(GB_DF.agg("cummax"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("cummin"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("cumprod"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("cumsum"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("diff"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("ffill"), DataFrame), DataFrame) + # TODO: pandas-dev/pandas-stubs#1671, fillna is not a valid function for DataFrameGroupBy at runtime + check(assert_type(GB_DF.agg("pct_change"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("rank"), DataFrame), DataFrame) + check(assert_type(GB_DF.agg("shift"), DataFrame), DataFrame) + + +def test_series_groupby_agg_reduction_kernels() -> None: + """Test SeriesGroupBy.agg with ReductionKernelType literals.""" + check(assert_type(GB_S.agg("all"), Series), Series) + check(assert_type(GB_S.agg("any"), Series), Series) + # TODO: pandas-dev/pandas-stubs#1671, corrwith does not exist on SeriesGroupBy + check(assert_type(GB_S.agg("count"), Series), Series) + check(assert_type(GB_S.agg("first"), Series), Series) + check(assert_type(GB_S.agg("idxmax"), Series), Series) + check(assert_type(GB_S.agg("idxmin"), Series), Series) + check(assert_type(GB_S.agg("last"), Series), Series) + check(assert_type(GB_S.agg("max"), Series), Series) + check(assert_type(GB_S.agg("mean"), Series), Series) + check(assert_type(GB_S.agg("median"), Series), Series) + check(assert_type(GB_S.agg("min"), Series), Series) + check(assert_type(GB_S.agg("nunique"), Series), Series) + check(assert_type(GB_S.agg("prod"), Series), Series) + check(assert_type(GB_S.agg("quantile"), Series), Series) + check(assert_type(GB_S.agg("sem"), Series), Series) + check(assert_type(GB_S.agg("size"), Series), Series) + check(assert_type(GB_S.agg("skew"), Series), Series) + check(assert_type(GB_S.agg("std"), Series), Series) + check(assert_type(GB_S.agg("sum"), Series), Series) + check(assert_type(GB_S.agg("var"), Series), Series) + + +def test_series_groupby_agg_transformation_kernels() -> None: + """Test SeriesGroupBy.agg with TransformationKernelType literals.""" + check(assert_type(GB_S.agg("bfill"), Series), Series) + check(assert_type(GB_S.agg("cumcount"), Series), Series) + check(assert_type(GB_S.agg("cummax"), Series), Series) + check(assert_type(GB_S.agg("cummin"), Series), Series) + check(assert_type(GB_S.agg("cumprod"), Series), Series) + check(assert_type(GB_S.agg("cumsum"), Series), Series) + check(assert_type(GB_S.agg("diff"), Series), Series) + check(assert_type(GB_S.agg("ffill"), Series), Series) + # TODO: pandas-dev/pandas-stubs#1671, fillna does not exist on SeriesGroupBy at runtime + check(assert_type(GB_S.agg("ngroup"), Series), Series) + check(assert_type(GB_S.agg("pct_change"), Series), Series) + check(assert_type(GB_S.agg("rank"), Series), Series) + check(assert_type(GB_S.agg("shift"), Series), Series)