Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,11 @@ Groupby/resample/rolling
- Bug in :meth:`DataFrame.groupby` where a ``ValueError`` would be raised when grouping by a categorical column with read-only categories and ``sort=False`` (:issue:`33410`)
- Bug in :meth:`GroupBy.first` and :meth:`GroupBy.last` where None is not preserved in object dtype (:issue:`32800`)

Groupby/rolling
^^^^^^^^^^^^^^^^^^^^^^^^

- Bug in :meth:`GroupBy.rolling.apply` ignores args and kwargs parameters (:issue:`33433`)

Reshaping
^^^^^^^^^

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,8 @@ def apply(
use_numba_cache=engine == "numba",
raw=raw,
original_func=func,
args=args,
kwargs=kwargs,
)

def _generate_cython_apply_func(self, args, kwargs, raw, offset, func):
Expand Down
30 changes: 29 additions & 1 deletion pandas/tests/window/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pandas.util._test_decorators as td

from pandas import DataFrame, Series, Timestamp, date_range
from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range
import pandas._testing as tm


Expand Down Expand Up @@ -138,3 +138,31 @@ def test_invalid_kwargs_nopython():
Series(range(1)).rolling(1).apply(
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
)


def test_rolling_apply_args_kwargs():
# GH 33433
def foo(x, par):
return np.sum(x + par)

df = DataFrame({"gr": [1, 1], "a": [1, 2]})

idx = Index(["gr", "a"])
expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx)

result = df.rolling(1).apply(foo, kwargs={"par": 10})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you parameterize over the kwargs & args?

pytest.mark.parameterize('args_kwargs', [[None, {"par": 10}], [(10,), None]])

tm.assert_frame_equal(result, expected)

result = df.rolling(1).apply(foo, args=(10,))
tm.assert_frame_equal(result, expected)

midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None])
expected = Series([11.0, 12.0], index=midx, name="a")

gb_rolling = df.groupby("gr")["a"].rolling(1)

result = gb_rolling.apply(foo, kwargs={"par": 10})
tm.assert_series_equal(result, expected)

result = gb_rolling.apply(foo, args=(10,))
tm.assert_series_equal(result, expected)