Skip to content

Commit 1eba10b

Browse files
committed
move the tests to test_numba.py
1 parent f672f9b commit 1eba10b

File tree

4 files changed

+61
-70
lines changed

4 files changed

+61
-70
lines changed

pandas/tests/window/test_apply.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -316,18 +316,3 @@ def test_center_reindex_frame(raw):
316316
)
317317
frame_rs = frame.rolling(window=25, min_periods=minp, center=True).apply(f, raw=raw)
318318
tm.assert_frame_equal(frame_xp, frame_rs)
319-
320-
321-
def test_apply_numba_with_kwargs():
322-
# 58995
323-
def func(sr, a=0):
324-
return sr.sum() + a
325-
326-
data = DataFrame(range(10))
327-
328-
result = data.rolling(5).apply(func, engine="numba", raw=True, kwargs={"a": 1})
329-
expected = data.rolling(5).sum() + 1
330-
tm.assert_frame_equal(result, expected)
331-
332-
result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,))
333-
tm.assert_frame_equal(result, expected)

pandas/tests/window/test_expanding.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -691,18 +691,3 @@ def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype):
691691
op2 = getattr(expanding2, kernel)
692692
expected = op2(*arg2, numeric_only=numeric_only)
693693
tm.assert_series_equal(result, expected)
694-
695-
696-
def test_apply_numba_with_kwargs():
697-
# 58995
698-
def func(sr, a=0):
699-
return sr.sum() + a
700-
701-
data = DataFrame(range(10))
702-
703-
result = data.expanding().apply(func, engine="numba", raw=True, kwargs={"a": 1})
704-
expected = data.expanding().sum() + 1
705-
tm.assert_frame_equal(result, expected)
706-
707-
result = data.expanding().apply(func, engine="numba", raw=True, args=(1,))
708-
tm.assert_frame_equal(result, expected)

pandas/tests/window/test_groupby.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,26 +1024,6 @@ def test_datelike_on_not_monotonic_within_each_group(self):
10241024
with pytest.raises(ValueError, match="Each group within B must be monotonic."):
10251025
df.groupby("A").rolling("365D", on="B")
10261026

1027-
def test_groupby_rolling_apply_numba_with_kwargs(self, roll_frame):
1028-
def func(sr, a=0):
1029-
return sr.sum() + a
1030-
1031-
# 58995
1032-
result = (
1033-
roll_frame.groupby("A")
1034-
.rolling(5)
1035-
.apply(func, engine="numba", raw=True, kwargs={"a": 1})
1036-
)
1037-
expected = roll_frame.groupby("A").rolling(5).sum() + 1
1038-
tm.assert_frame_equal(result, expected)
1039-
1040-
result = (
1041-
roll_frame.groupby("A")
1042-
.rolling(5)
1043-
.apply(func, engine="numba", raw=True, args=(1,))
1044-
)
1045-
tm.assert_frame_equal(result, expected)
1046-
10471027

10481028
class TestExpanding:
10491029
@pytest.fixture
@@ -1154,26 +1134,6 @@ def test_expanding_apply(self, raw, frame):
11541134
expected.index = expected_index
11551135
tm.assert_frame_equal(result, expected)
11561136

1157-
def test_groupby_expanding_apply_numba_with_kwargs(self, roll_frame):
1158-
# 58995
1159-
def func(sr, a=0):
1160-
return sr.sum() + a
1161-
1162-
result = (
1163-
roll_frame.groupby("A")
1164-
.expanding()
1165-
.apply(func, engine="numba", raw=True, kwargs={"a": 1})
1166-
)
1167-
expected = roll_frame.groupby("A").expanding().sum() + 1
1168-
tm.assert_frame_equal(result, expected)
1169-
1170-
result = (
1171-
roll_frame.groupby("A")
1172-
.expanding()
1173-
.apply(func, engine="numba", raw=True, args=(1,))
1174-
)
1175-
tm.assert_frame_equal(result, expected)
1176-
11771137

11781138
class TestEWM:
11791139
@pytest.mark.parametrize(

pandas/tests/window/test_numba.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def arithmetic_numba_supported_operators(request):
3838
return request.param
3939

4040

41+
@pytest.fixture
42+
def roll_frame():
43+
return DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)})
44+
45+
4146
@td.skip_if_no("numba")
4247
@pytest.mark.filterwarnings("ignore")
4348
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@@ -67,6 +72,62 @@ def f(x, *args):
6772
)
6873
tm.assert_series_equal(result, expected)
6974

75+
def test_apply_numba_with_kwargs(self, roll_frame):
76+
# GH 58995
77+
# rolling apply
78+
def func(sr, a=0):
79+
return sr.sum() + a
80+
81+
data = DataFrame(range(10))
82+
83+
result = data.rolling(5).apply(func, engine="numba", raw=True, kwargs={"a": 1})
84+
expected = data.rolling(5).sum() + 1
85+
tm.assert_frame_equal(result, expected)
86+
87+
result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,))
88+
tm.assert_frame_equal(result, expected)
89+
90+
# expanding apply
91+
92+
result = data.expanding().apply(func, engine="numba", raw=True, kwargs={"a": 1})
93+
expected = data.expanding().sum() + 1
94+
tm.assert_frame_equal(result, expected)
95+
96+
result = data.expanding().apply(func, engine="numba", raw=True, args=(1,))
97+
tm.assert_frame_equal(result, expected)
98+
99+
# groupby rolling
100+
result = (
101+
roll_frame.groupby("A")
102+
.rolling(5)
103+
.apply(func, engine="numba", raw=True, kwargs={"a": 1})
104+
)
105+
expected = roll_frame.groupby("A").rolling(5).sum() + 1
106+
tm.assert_frame_equal(result, expected)
107+
108+
result = (
109+
roll_frame.groupby("A")
110+
.rolling(5)
111+
.apply(func, engine="numba", raw=True, args=(1,))
112+
)
113+
tm.assert_frame_equal(result, expected)
114+
# groupby expanding
115+
116+
result = (
117+
roll_frame.groupby("A")
118+
.expanding()
119+
.apply(func, engine="numba", raw=True, kwargs={"a": 1})
120+
)
121+
expected = roll_frame.groupby("A").expanding().sum() + 1
122+
tm.assert_frame_equal(result, expected)
123+
124+
result = (
125+
roll_frame.groupby("A")
126+
.expanding()
127+
.apply(func, engine="numba", raw=True, args=(1,))
128+
)
129+
tm.assert_frame_equal(result, expected)
130+
70131
def test_numba_min_periods(self):
71132
# GH 58868
72133
def last_row(x):

0 commit comments

Comments
 (0)