Skip to content

Commit 9add8bb

Browse files
committed
Move numba tests to test_numba.py
1 parent d10238a commit 9add8bb

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

pandas/tests/groupby/aggregate/test_numba.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,23 @@ def test_multifunc_numba_vs_cython_frame(agg_kwargs):
186186
tm.assert_frame_equal(result, expected)
187187

188188

189+
@pytest.mark.parametrize("func", ["sum", "mean"])
190+
def test_multifunc_numba_vs_cython_frame_noskipna(func):
191+
pytest.importorskip("numba")
192+
data = DataFrame(
193+
{
194+
0: ["a", "a", "b", "b", "a"],
195+
1: [1.0, np.nan, 3.0, 4.0, 5.0],
196+
2: [1, 2, 3, 4, 5],
197+
},
198+
columns=[0, 1, 2],
199+
)
200+
grouped = data.groupby(0)
201+
result = grouped.agg(func, skipna=False, engine="numba")
202+
expected = grouped.agg(func, skipna=False, engine="cython")
203+
tm.assert_frame_equal(result, expected)
204+
205+
189206
@pytest.mark.parametrize(
190207
"agg_kwargs,expected_func",
191208
[

pandas/tests/groupby/test_reductions.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -466,10 +466,6 @@ def test_mean_skipna(values, dtype, result_dtype, skipna):
466466
)
467467
result = df.groupby("cat")["val"].mean(skipna=skipna)
468468
tm.assert_series_equal(result, expected)
469-
if dtype == "float64":
470-
# For float64, test the numba version as well
471-
result = df.groupby("cat")["val"].mean(skipna=skipna, engine="numba")
472-
tm.assert_series_equal(result, expected)
473469

474470

475471
@pytest.mark.parametrize(
@@ -496,10 +492,6 @@ def test_sum_skipna(values, dtype, skipna):
496492
)
497493
result = df.groupby("cat")["val"].sum(skipna=skipna)
498494
tm.assert_series_equal(result, expected)
499-
if dtype == "float64":
500-
# For float64, test the numba version as well
501-
result = df.groupby("cat")["val"].sum(skipna=skipna, engine="numba")
502-
tm.assert_series_equal(result, expected)
503495

504496

505497
def test_sum_skipna_object(skipna):

0 commit comments

Comments
 (0)