Skip to content

Commit 2400d28

Browse files
committed
add more tests
1 parent 33be300 commit 2400d28

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

pandas/tests/window/test_apply.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,17 @@ def test_center_reindex_frame(raw):
316316
)
317317
frame_rs = frame.rolling(window=25, min_periods=minp, center=True).apply(f, raw=raw)
318318
tm.assert_frame_equal(frame_xp, frame_rs)
319+
320+
def test_apply_numba_with_kwargs():
321+
# 58995
322+
def func(sr, a=0):
323+
return sr.sum() + a
324+
325+
data = DataFrame(range(10))
326+
327+
result = data.rolling(5).apply(func, engine="numba", raw=True, kwargs={"a": 1})
328+
expected = data.rolling(5).sum() + 1
329+
tm.assert_frame_equal(result, expected)
330+
331+
result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,))
332+
tm.assert_frame_equal(result, expected)

pandas/tests/window/test_expanding.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,3 +691,18 @@ def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype):
691691
op2 = getattr(expanding2, kernel)
692692
expected = op2(*arg2, numeric_only=numeric_only)
693693
tm.assert_series_equal(result, expected)
694+
695+
696+
def test_apply_numba_with_kwargs():
697+
# 58995
698+
def func(sr, a=0):
699+
return sr.sum() + a
700+
701+
data = DataFrame(range(10))
702+
703+
result = data.expanding().apply(func, engine="numba", raw=True, kwargs={"a": 1})
704+
expected = data.expanding().sum() + 1
705+
tm.assert_frame_equal(result, expected)
706+
707+
result = data.expanding().apply(func, engine="numba", raw=True, args=(1,))
708+
tm.assert_frame_equal(result, expected)

pandas/tests/window/test_groupby.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,26 @@ def test_datelike_on_not_monotonic_within_each_group(self):
10241024
with pytest.raises(ValueError, match="Each group within B must be monotonic."):
10251025
df.groupby("A").rolling("365D", on="B")
10261026

1027+
def test_groupby_rolling_apply_numba_with_kwargs(self, roll_frame):
1028+
def func(sr, a=0):
1029+
return sr.sum() + a
1030+
1031+
# 58995
1032+
result = (
1033+
roll_frame.groupby("A")
1034+
.rolling(5)
1035+
.apply(func, engine="numba", raw=True, kwargs={"a": 1})
1036+
)
1037+
expected = roll_frame.groupby("A").rolling(5).sum() + 1
1038+
tm.assert_frame_equal(result, expected)
1039+
1040+
result = (
1041+
roll_frame.groupby("A")
1042+
.rolling(5)
1043+
.apply(func, engine="numba", raw=True, args=(1,))
1044+
)
1045+
tm.assert_frame_equal(result, expected)
1046+
10271047

10281048
class TestExpanding:
10291049
@pytest.fixture
@@ -1134,6 +1154,26 @@ def test_expanding_apply(self, raw, frame):
11341154
expected.index = expected_index
11351155
tm.assert_frame_equal(result, expected)
11361156

1157+
def test_groupby_expanding_apply_numba_with_kwargs(self, roll_frame):
1158+
# 58995
1159+
def func(sr, a=0):
1160+
return sr.sum() + a
1161+
1162+
result = (
1163+
roll_frame.groupby("A")
1164+
.expanding()
1165+
.apply(func, engine="numba", raw=True, kwargs={"a": 1})
1166+
)
1167+
expected = roll_frame.groupby("A").expanding().sum() + 1
1168+
tm.assert_frame_equal(result, expected)
1169+
1170+
result = (
1171+
roll_frame.groupby("A")
1172+
.expanding()
1173+
.apply(func, engine="numba", raw=True, args=(1,))
1174+
)
1175+
tm.assert_frame_equal(result, expected)
1176+
11371177

11381178
class TestEWM:
11391179
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)