Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@
RollingGroupby,
)

from pandas.errors import Pandas4Warning

from tests import (
TYPE_CHECKING_INVALID_USAGE,
check,
pytest_warns_bounded,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -925,3 +928,171 @@ def test_frame_groupby_aggregate() -> None:

check(assert_type(df.groupby("b").agg(a=("a", "mean")), DataFrame), DataFrame)
check(assert_type(df.groupby("b").agg(**dico), DataFrame), DataFrame)


def test_frame_groupby_transform_reduction_kernels() -> None:
"""Test DataFrameGroupBy.transform with ReductionKernelType literals."""
check(assert_type(GB_DF.transform("all"), DataFrame), DataFrame)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that, for DataFrame.transform, any str is currently accepted by the stubs. Shall we switch to a Literal and also add one or two negative tests under TYPE_CHECKING_INVALID_USAGE?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to make the switch; I was just worried that there may be some functions missing from the list, in which case we would break things for users. What do you think? Should we take the risk?

check(assert_type(GB_DF.transform("any"), DataFrame), DataFrame)
with pytest_warns_bounded(Pandas4Warning, "corrwith is deprecated", lower="2.99"):
check(assert_type(GB_DF.transform("corrwith", other=DF), DataFrame), DataFrame)
Comment on lines +937 to +938
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't support 2.x anymore, do we?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In pytest_warns_bounded the version check is strict, so if you put 3.0 it will not trigger for version 3.0; that is why I am forced to stick with 2.99. (We had this issue in a different PR.)

check(assert_type(GB_DF.transform("count"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("first"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("idxmax"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("idxmin"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("last"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("max"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("mean"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("median"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("min"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("nunique"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("prod"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("quantile"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("sem"), DataFrame), DataFrame)
# TODO: pandas-dev/pandas-stubs#1671, size, cumcount, ngroup return Series at runtime on DataFrameGroupBy
check(assert_type(GB_DF.transform("skew"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("std"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("sum"), DataFrame), DataFrame)
check(assert_type(GB_DF.transform("var"), DataFrame), DataFrame)


def test_frame_groupby_transform_transformation_kernels() -> None:
    """Exercise DataFrameGroupBy.transform with each transformation kernel name.

    Every call passes the kernel as a string literal and asserts that the
    result is typed as DataFrame; the value and the same type are then handed
    to ``check``.
    """
    grouped = GB_DF
    check(assert_type(grouped.transform("bfill"), DataFrame), DataFrame)
    # TODO: pandas-dev/pandas-stubs#1671 - "cumcount" and "ngroup" are left out
    # here because they return Series at runtime on DataFrameGroupBy.
    check(assert_type(grouped.transform("cummax"), DataFrame), DataFrame)
    check(assert_type(grouped.transform("cummin"), DataFrame), DataFrame)
    check(assert_type(grouped.transform("cumprod"), DataFrame), DataFrame)
    check(assert_type(grouped.transform("cumsum"), DataFrame), DataFrame)
    check(assert_type(grouped.transform("diff"), DataFrame), DataFrame)
    check(assert_type(grouped.transform("ffill"), DataFrame), DataFrame)
    # TODO: pandas-dev/pandas-stubs#1671 - "fillna" is left out because it is
    # not a valid function name for transform(name) at runtime.
    check(assert_type(grouped.transform("pct_change"), DataFrame), DataFrame)
    check(assert_type(grouped.transform("rank"), DataFrame), DataFrame)
    check(assert_type(grouped.transform("shift"), DataFrame), DataFrame)


def test_series_groupby_transform_reduction_kernels() -> None:
    """Exercise SeriesGroupBy.transform with each reduction kernel name.

    Every call passes the kernel as a string literal and asserts that the
    result is typed as Series; the value and the same type are then handed
    to ``check``.
    """
    gb_series = GB_S
    check(assert_type(gb_series.transform("all"), Series), Series)
    check(assert_type(gb_series.transform("any"), Series), Series)
    # TODO: pandas-dev/pandas-stubs#1671 - "corrwith" is left out because it
    # does not exist on SeriesGroupBy.
    check(assert_type(gb_series.transform("count"), Series), Series)
    check(assert_type(gb_series.transform("first"), Series), Series)
    check(assert_type(gb_series.transform("idxmax"), Series), Series)
    check(assert_type(gb_series.transform("idxmin"), Series), Series)
    check(assert_type(gb_series.transform("last"), Series), Series)
    check(assert_type(gb_series.transform("max"), Series), Series)
    check(assert_type(gb_series.transform("mean"), Series), Series)
    check(assert_type(gb_series.transform("median"), Series), Series)
    check(assert_type(gb_series.transform("min"), Series), Series)
    check(assert_type(gb_series.transform("nunique"), Series), Series)
    check(assert_type(gb_series.transform("prod"), Series), Series)
    check(assert_type(gb_series.transform("quantile"), Series), Series)
    check(assert_type(gb_series.transform("sem"), Series), Series)
    check(assert_type(gb_series.transform("size"), Series), Series)
    check(assert_type(gb_series.transform("skew"), Series), Series)
    check(assert_type(gb_series.transform("std"), Series), Series)
    check(assert_type(gb_series.transform("sum"), Series), Series)
    check(assert_type(gb_series.transform("var"), Series), Series)


def test_series_groupby_transform_transformation_kernels() -> None:
    """Exercise SeriesGroupBy.transform with each transformation kernel name.

    Every call passes the kernel as a string literal and asserts that the
    result is typed as Series; the value and the same type are then handed
    to ``check``.
    """
    gb_series = GB_S
    check(assert_type(gb_series.transform("bfill"), Series), Series)
    check(assert_type(gb_series.transform("cumcount"), Series), Series)
    check(assert_type(gb_series.transform("cummax"), Series), Series)
    check(assert_type(gb_series.transform("cummin"), Series), Series)
    check(assert_type(gb_series.transform("cumprod"), Series), Series)
    check(assert_type(gb_series.transform("cumsum"), Series), Series)
    check(assert_type(gb_series.transform("diff"), Series), Series)
    check(assert_type(gb_series.transform("ffill"), Series), Series)
    # TODO: pandas-dev/pandas-stubs#1671 - "fillna" is left out because it is
    # not a valid function name for transform(name) at runtime.
    check(assert_type(gb_series.transform("ngroup"), Series), Series)
    check(assert_type(gb_series.transform("pct_change"), Series), Series)
    check(assert_type(gb_series.transform("rank"), Series), Series)
    check(assert_type(gb_series.transform("shift"), Series), Series)


def test_frame_groupby_agg_reduction_kernels() -> None:
    """Test DataFrameGroupBy.agg with ReductionKernelType literals.

    Each call passes one kernel name as a string literal and asserts the
    static return type, then hands the value and the same type to ``check``.
    All kernels are asserted as DataFrame except "size", which is asserted
    as Series.
    """
    check(assert_type(GB_DF.agg("all"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("any"), DataFrame), DataFrame)
    # "corrwith" is expected to emit a Pandas4Warning; lower="2.99" is the
    # version bound passed to pytest_warns_bounded for that expectation.
    with pytest_warns_bounded(Pandas4Warning, "corrwith is deprecated", lower="2.99"):
        check(assert_type(GB_DF.agg("corrwith", other=DF), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("count"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("first"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("idxmax"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("idxmin"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("last"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("max"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("mean"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("median"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("min"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("nunique"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("prod"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("quantile"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("sem"), DataFrame), DataFrame)
    # "size" is the one reduction kernel asserted as Series rather than
    # DataFrame. Spelled .agg like every other call in this test.
    check(assert_type(GB_DF.agg("size"), Series), Series)
    check(assert_type(GB_DF.agg("skew"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("std"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("sum"), DataFrame), DataFrame)
    check(assert_type(GB_DF.agg("var"), DataFrame), DataFrame)


def test_frame_groupby_agg_transformation_kernels() -> None:
    """Exercise DataFrameGroupBy.agg with each transformation kernel name.

    Every call passes the kernel as a string literal and asserts that the
    result is typed as DataFrame; the value and the same type are then handed
    to ``check``.
    """
    grouped = GB_DF
    check(assert_type(grouped.agg("bfill"), DataFrame), DataFrame)
    # TODO: pandas-dev/pandas-stubs#1671 - "cumcount" and "ngroup" are left out
    # here because they return Series at runtime on DataFrameGroupBy.
    check(assert_type(grouped.agg("cummax"), DataFrame), DataFrame)
    check(assert_type(grouped.agg("cummin"), DataFrame), DataFrame)
    check(assert_type(grouped.agg("cumprod"), DataFrame), DataFrame)
    check(assert_type(grouped.agg("cumsum"), DataFrame), DataFrame)
    check(assert_type(grouped.agg("diff"), DataFrame), DataFrame)
    check(assert_type(grouped.agg("ffill"), DataFrame), DataFrame)
    # TODO: pandas-dev/pandas-stubs#1671 - "fillna" is left out because it is
    # not a valid function for DataFrameGroupBy at runtime.
    check(assert_type(grouped.agg("pct_change"), DataFrame), DataFrame)
    check(assert_type(grouped.agg("rank"), DataFrame), DataFrame)
    check(assert_type(grouped.agg("shift"), DataFrame), DataFrame)


def test_series_groupby_agg_reduction_kernels() -> None:
    """Exercise SeriesGroupBy.agg with each reduction kernel name.

    Every call passes the kernel as a string literal and asserts that the
    result is typed as Series; the value and the same type are then handed
    to ``check``.
    """
    gb_series = GB_S
    check(assert_type(gb_series.agg("all"), Series), Series)
    check(assert_type(gb_series.agg("any"), Series), Series)
    # TODO: pandas-dev/pandas-stubs#1671 - "corrwith" is left out because it
    # does not exist on SeriesGroupBy.
    check(assert_type(gb_series.agg("count"), Series), Series)
    check(assert_type(gb_series.agg("first"), Series), Series)
    check(assert_type(gb_series.agg("idxmax"), Series), Series)
    check(assert_type(gb_series.agg("idxmin"), Series), Series)
    check(assert_type(gb_series.agg("last"), Series), Series)
    check(assert_type(gb_series.agg("max"), Series), Series)
    check(assert_type(gb_series.agg("mean"), Series), Series)
    check(assert_type(gb_series.agg("median"), Series), Series)
    check(assert_type(gb_series.agg("min"), Series), Series)
    check(assert_type(gb_series.agg("nunique"), Series), Series)
    check(assert_type(gb_series.agg("prod"), Series), Series)
    check(assert_type(gb_series.agg("quantile"), Series), Series)
    check(assert_type(gb_series.agg("sem"), Series), Series)
    check(assert_type(gb_series.agg("size"), Series), Series)
    check(assert_type(gb_series.agg("skew"), Series), Series)
    check(assert_type(gb_series.agg("std"), Series), Series)
    check(assert_type(gb_series.agg("sum"), Series), Series)
    check(assert_type(gb_series.agg("var"), Series), Series)


def test_series_groupby_agg_transformation_kernels() -> None:
    """Exercise SeriesGroupBy.agg with each transformation kernel name.

    Every call passes the kernel as a string literal and asserts that the
    result is typed as Series; the value and the same type are then handed
    to ``check``.
    """
    gb_series = GB_S
    check(assert_type(gb_series.agg("bfill"), Series), Series)
    check(assert_type(gb_series.agg("cumcount"), Series), Series)
    check(assert_type(gb_series.agg("cummax"), Series), Series)
    check(assert_type(gb_series.agg("cummin"), Series), Series)
    check(assert_type(gb_series.agg("cumprod"), Series), Series)
    check(assert_type(gb_series.agg("cumsum"), Series), Series)
    check(assert_type(gb_series.agg("diff"), Series), Series)
    check(assert_type(gb_series.agg("ffill"), Series), Series)
    # TODO: pandas-dev/pandas-stubs#1671 - "fillna" is left out because it
    # does not exist on SeriesGroupBy at runtime.
    check(assert_type(gb_series.agg("ngroup"), Series), Series)
    check(assert_type(gb_series.agg("pct_change"), Series), Series)
    check(assert_type(gb_series.agg("rank"), Series), Series)
    check(assert_type(gb_series.agg("shift"), Series), Series)