Skip to content

Commit 96581a3

Browse files
committed
add *args for raw=False as well; merge tests together
1 parent c026845 commit 96581a3

File tree

3 files changed

+22
-32
lines changed

3 files changed

+22
-32
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ Other
498498
- Bug in :class:`DataFrame` when passing a ``dict`` with a NA scalar and ``columns`` that would always return ``np.nan`` (:issue:`57205`)
499499
- Bug in :func:`eval` where the names of the :class:`Series` were not preserved when using ``engine="numexpr"``. (:issue:`10239`)
500500
- Bug in :func:`unique` on :class:`Index` not always returning :class:`Index` (:issue:`57043`)
501-
- Bug in :meth:`DataFrame.apply` where passing ``raw=True`` and ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`)
501+
- Bug in :meth:`DataFrame.apply` where passing ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`)
502502
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which caused an exception when using NumPy attributes via ``@`` notation, e.g., ``df.eval("@np.floor(a)")``. (:issue:`58041`)
503503
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which did not allow the use of the ``tan`` function. (:issue:`55091`)
504504
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning :class:`RangeIndex` columns (:issue:`57293`)

pandas/core/apply.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,21 +1122,22 @@ def generate_numba_apply_func(
11221122
# Currently the parallel argument doesn't get passed through here
11231123
# (it's disabled) since the dicts in numba aren't thread-safe.
11241124
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
1125-
def numba_func(values, col_names, df_index):
1125+
def numba_func(values, col_names, df_index, *args):
11261126
results = {}
11271127
for j in range(values.shape[1]):
11281128
# Create the series
11291129
ser = Series(
11301130
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
11311131
)
1132-
results[j] = jitted_udf(ser)
1132+
results[j] = jitted_udf(ser, *args)
11331133
return results
11341134

11351135
return numba_func
11361136

11371137
def apply_with_numba(self) -> dict[int, Any]:
11381138
nb_func = self.generate_numba_apply_func(
1139-
cast(Callable, self.func), **self.engine_kwargs
1139+
cast(Callable, self.func),
1140+
**get_jit_arguments(self.engine_kwargs, self.kwargs),
11401141
)
11411142
from pandas.core._numba.extensions import set_numba_data
11421143

@@ -1151,7 +1152,7 @@ def apply_with_numba(self) -> dict[int, Any]:
11511152
# Convert from numba dict to regular dict
11521153
# Our isinstance checks in the df constructor don't pass for numbas typed dict
11531154
with set_numba_data(index) as index, set_numba_data(columns) as columns:
1154-
res = dict(nb_func(self.values, columns, index))
1155+
res = dict(nb_func(self.values, columns, index, *self.args))
11551156
return res
11561157

11571158
@property
@@ -1259,7 +1260,7 @@ def generate_numba_apply_func(
12591260
jitted_udf = numba.extending.register_jitable(func)
12601261

12611262
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
1262-
def numba_func(values, col_names_index, index):
1263+
def numba_func(values, col_names_index, index, *args):
12631264
results = {}
12641265
# Currently the parallel argument doesn't get passed through here
12651266
# (it's disabled) since the dicts in numba aren't thread-safe.
@@ -1271,15 +1272,16 @@ def numba_func(values, col_names_index, index):
12711272
index=col_names_index,
12721273
name=maybe_cast_str(index[i]),
12731274
)
1274-
results[i] = jitted_udf(ser)
1275+
results[i] = jitted_udf(ser, *args)
12751276

12761277
return results
12771278

12781279
return numba_func
12791280

12801281
def apply_with_numba(self) -> dict[int, Any]:
12811282
nb_func = self.generate_numba_apply_func(
1282-
cast(Callable, self.func), **self.engine_kwargs
1283+
cast(Callable, self.func),
1284+
**get_jit_arguments(self.engine_kwargs, self.kwargs),
12831285
)
12841286

12851287
from pandas.core._numba.extensions import set_numba_data
@@ -1290,7 +1292,7 @@ def apply_with_numba(self) -> dict[int, Any]:
12901292
set_numba_data(self.obj.index) as index,
12911293
set_numba_data(self.columns) as columns,
12921294
):
1293-
res = dict(nb_func(self.values, columns, index))
1295+
res = dict(nb_func(self.values, columns, index, *self.args))
12941296

12951297
return res
12961298

pandas/tests/apply/test_frame_apply.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,23 @@ def test_apply(float_frame, engine, request):
6363

6464
@pytest.mark.parametrize("axis", [0, 1])
6565
@pytest.mark.parametrize("raw", [True, False])
66-
def test_apply_args(float_frame, axis, raw, engine, request):
67-
if engine == "numba" and raw is False:
68-
mark = pytest.mark.xfail(reason="numba engine doesn't support args")
69-
request.node.add_marker(mark)
66+
def test_apply_args(float_frame, axis, raw, engine):
67+
# GH:58712
7068
result = float_frame.apply(
7169
lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine
7270
)
7371
expected = float_frame + 1
7472
tm.assert_frame_equal(result, expected)
7573

74+
if engine == "numba":
75+
with pytest.raises(
76+
pd.errors.NumbaUtilError,
77+
match="numba does not support kwargs with nopython=True",
78+
):
79+
float_frame.apply(
80+
lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
81+
)
82+
7683

7784
def test_apply_categorical_func():
7885
# GH 9573
@@ -1718,22 +1725,3 @@ def test_agg_dist_like_and_nonunique_columns():
17181725
result = df.agg({"A": "count"})
17191726
expected = df["A"].count()
17201727
tm.assert_series_equal(result, expected)
1721-
1722-
1723-
def test_numba_raw_apply_with_args(engine):
1724-
if engine == "numba":
1725-
# GH:58712
1726-
df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
1727-
result = df.apply(
1728-
lambda x, a, b: x + a + b, args=(1, 2), engine=engine, raw=True
1729-
)
1730-
# note: result is always float dtype,
1731-
# see core._numba.executor.py:generate_apply_looper
1732-
expected = df + 3.0
1733-
tm.assert_frame_equal(result, expected)
1734-
1735-
with pytest.raises(
1736-
pd.errors.NumbaUtilError,
1737-
match="numba does not support kwargs with nopython=True",
1738-
):
1739-
df.apply(lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=True)

0 commit comments

Comments
 (0)