Skip to content

Commit 1221d34

Browse files
authored
Support args and kwargs for rolling aggregations (dask#11856)
1 parent 4399aef commit 1221d34

File tree

2 files changed

+50
-33
lines changed

2 files changed

+50
-33
lines changed

dask/dataframe/dask_expr/_rolling.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -287,56 +287,56 @@ def _single_agg(self, expr_cls, how_args=(), how_kwargs=None):
287287
)
288288

289289
@derived_from(pd_Rolling)
290-
def cov(self):
291-
return self._single_agg(RollingCov)
290+
def cov(self, *args, **kwargs):
291+
return self._single_agg(RollingCov, how_args=args, how_kwargs=kwargs)
292292

293293
@derived_from(pd_Rolling)
294294
def apply(self, func, *args, **kwargs):
295295
return self._single_agg(RollingApply, how_args=(func, *args), how_kwargs=kwargs)
296296

297297
@derived_from(pd_Rolling)
298-
def count(self):
299-
return self._single_agg(RollingCount)
298+
def count(self, *args, **kwargs):
299+
return self._single_agg(RollingCount, how_args=args, how_kwargs=kwargs)
300300

301301
@derived_from(pd_Rolling)
302-
def sum(self):
303-
return self._single_agg(RollingSum)
302+
def sum(self, *args, **kwargs):
303+
return self._single_agg(RollingSum, how_args=args, how_kwargs=kwargs)
304304

305305
@derived_from(pd_Rolling)
306-
def mean(self):
307-
return self._single_agg(RollingMean)
306+
def mean(self, *args, **kwargs):
307+
return self._single_agg(RollingMean, how_args=args, how_kwargs=kwargs)
308308

309309
@derived_from(pd_Rolling)
310-
def min(self):
311-
return self._single_agg(RollingMin)
310+
def min(self, *args, **kwargs):
311+
return self._single_agg(RollingMin, how_args=args, how_kwargs=kwargs)
312312

313313
@derived_from(pd_Rolling)
314-
def max(self):
315-
return self._single_agg(RollingMax)
314+
def max(self, *args, **kwargs):
315+
return self._single_agg(RollingMax, how_args=args, how_kwargs=kwargs)
316316

317317
@derived_from(pd_Rolling)
318-
def var(self):
319-
return self._single_agg(RollingVar)
318+
def var(self, *args, **kwargs):
319+
return self._single_agg(RollingVar, how_args=args, how_kwargs=kwargs)
320320

321321
@derived_from(pd_Rolling)
322-
def std(self):
323-
return self._single_agg(RollingStd)
322+
def std(self, *args, **kwargs):
323+
return self._single_agg(RollingStd, how_args=args, how_kwargs=kwargs)
324324

325325
@derived_from(pd_Rolling)
326-
def median(self):
327-
return self._single_agg(RollingMedian)
326+
def median(self, *args, **kwargs):
327+
return self._single_agg(RollingMedian, how_args=args, how_kwargs=kwargs)
328328

329329
@derived_from(pd_Rolling)
330-
def quantile(self, q):
331-
return self._single_agg(RollingQuantile, how_args=(q,))
330+
def quantile(self, q, *args, **kwargs):
331+
return self._single_agg(RollingQuantile, how_args=(q, *args), how_kwargs=kwargs)
332332

333333
@derived_from(pd_Rolling)
334-
def skew(self):
335-
return self._single_agg(RollingSkew)
334+
def skew(self, *args, **kwargs):
335+
return self._single_agg(RollingSkew, how_args=args, how_kwargs=kwargs)
336336

337337
@derived_from(pd_Rolling)
338-
def kurt(self):
339-
return self._single_agg(RollingKurt)
338+
def kurt(self, *args, **kwargs):
339+
return self._single_agg(RollingKurt, how_args=args, how_kwargs=kwargs)
340340

341341
@derived_from(pd_Rolling)
342342
def agg(self, func, *args, **kwargs):

dask/dataframe/dask_expr/tests/test_rolling.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
@pytest.fixture
1515
def pdf():
1616
idx = pd.date_range("2000-01-01", periods=12, freq="min")
17-
pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx)
17+
pdf = pd.DataFrame(
18+
{"foo": range(len(idx)), "bar": idx},
19+
index=idx,
20+
)
1821
pdf["bar"] = 1
1922
yield pdf
2023

@@ -34,30 +37,44 @@ def df(pdf, request):
3437
("min", ()),
3538
("max", ()),
3639
("var", ()),
40+
("var", (2,)), # ddof
3741
("std", ()),
42+
("std", (2,)), # ddof
3843
("median", ()),
3944
("skew", ()),
4045
("quantile", (0.5,)),
4146
("kurt", ()),
4247
],
4348
)
49+
@pytest.mark.parametrize("numeric_only", [True, False])
4450
@pytest.mark.parametrize("window,min_periods", ((1, None), (3, 2), (3, 3)))
4551
@pytest.mark.parametrize("center", (True, False))
4652
@pytest.mark.parametrize("df", (1, 2), indirect=True)
47-
def test_rolling_apis(df, pdf, window, api, how_args, min_periods, center):
53+
def test_rolling_apis(
54+
df, pdf, window, api, how_args, min_periods, center, numeric_only
55+
):
4856
args = (window,)
49-
kwargs = dict(min_periods=min_periods, center=center)
50-
51-
result = getattr(df.rolling(*args, **kwargs), api)(*how_args)
52-
expected = getattr(pdf.rolling(*args, **kwargs), api)(*how_args)
57+
kwargs = dict(
58+
min_periods=min_periods,
59+
center=center,
60+
)
61+
how_kwargs = dict(
62+
numeric_only=numeric_only,
63+
)
64+
result = getattr(df.rolling(*args, **kwargs), api)(*how_args, **how_kwargs)
65+
expected = getattr(pdf.rolling(*args, **kwargs), api)(*how_args, **how_kwargs)
5366
assert_eq(result, expected)
5467

55-
result = getattr(df.rolling(*args, **kwargs), api)(*how_args)["foo"]
56-
expected = getattr(pdf.rolling(*args, **kwargs), api)(*how_args)["foo"]
68+
result = getattr(df.rolling(*args, **kwargs), api)(*how_args, **how_kwargs)["foo"]
69+
expected = getattr(pdf.rolling(*args, **kwargs), api)(*how_args, **how_kwargs)[
70+
"foo"
71+
]
5772
assert_eq(result, expected)
5873

5974
q = result.simplify()
60-
eq = getattr(df["foo"].rolling(*args, **kwargs), api)(*how_args).simplify()
75+
eq = getattr(df["foo"].rolling(*args, **kwargs), api)(
76+
*how_args, **how_kwargs
77+
).simplify()
6178
assert q._name == eq._name
6279

6380

0 commit comments

Comments
 (0)