Skip to content

Commit 2aae933

Browse files
committed
add prepare_function_arguments
1 parent 96581a3 commit 2aae933

File tree

4 files changed

+78
-10
lines changed

4 files changed

+78
-10
lines changed

pandas/core/_numba/executor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414

1515
from pandas.compat._optional import import_optional_dependency
1616

17+
from pandas.core.util.numba_ import jit_user_function
18+
1719

1820
@functools.cache
1921
def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
2022
if TYPE_CHECKING:
2123
import numba
2224
else:
2325
numba = import_optional_dependency("numba")
24-
nb_compat_func = numba.extending.register_jitable(func)
26+
nb_compat_func = jit_user_function(func)
2527

2628
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
2729
def nb_looper(values, axis, *args):

pandas/core/apply.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@
5151
from pandas.core._numba.executor import generate_apply_looper
5252
import pandas.core.common as com
5353
from pandas.core.construction import ensure_wrapped_if_datetimelike
54-
from pandas.core.util.numba_ import get_jit_arguments
54+
from pandas.core.util.numba_ import (
55+
get_jit_arguments,
56+
prepare_function_arguments,
57+
)
5558

5659
if TYPE_CHECKING:
5760
from collections.abc import (
@@ -973,15 +976,16 @@ def wrapper(*args, **kwargs):
973976
return wrapper
974977

975978
if engine == "numba":
979+
args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs)
976980
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
977981
# incompatible type "Callable[..., Any] | str | list[Callable
978982
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
979983
# list[Callable[..., Any] | str]]"; expected "Hashable"
980984
nb_looper = generate_apply_looper(
981985
self.func, # type: ignore[arg-type]
982-
**get_jit_arguments(engine_kwargs, self.kwargs),
986+
**get_jit_arguments(engine_kwargs, kwargs),
983987
)
984-
result = nb_looper(self.values, self.axis, *self.args)
988+
result = nb_looper(self.values, self.axis, *args)
985989
# If we made the result 2-D, squeeze it back to 1-D
986990
result = np.squeeze(result)
987991
else:
@@ -1135,9 +1139,10 @@ def numba_func(values, col_names, df_index, *args):
11351139
return numba_func
11361140

11371141
def apply_with_numba(self) -> dict[int, Any]:
1142+
args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs)
11381143
nb_func = self.generate_numba_apply_func(
11391144
cast(Callable, self.func),
1140-
**get_jit_arguments(self.engine_kwargs, self.kwargs),
1145+
**get_jit_arguments(self.engine_kwargs, kwargs),
11411146
)
11421147
from pandas.core._numba.extensions import set_numba_data
11431148

@@ -1152,7 +1157,7 @@ def apply_with_numba(self) -> dict[int, Any]:
11521157
# Convert from numba dict to regular dict
11531158
# Our isinstance checks in the df constructor don't pass for numbas typed dict
11541159
with set_numba_data(index) as index, set_numba_data(columns) as columns:
1155-
res = dict(nb_func(self.values, columns, index, *self.args))
1160+
res = dict(nb_func(self.values, columns, index, *args))
11561161
return res
11571162

11581163
@property
@@ -1279,9 +1284,10 @@ def numba_func(values, col_names_index, index, *args):
12791284
return numba_func
12801285

12811286
def apply_with_numba(self) -> dict[int, Any]:
1287+
args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs)
12821288
nb_func = self.generate_numba_apply_func(
12831289
cast(Callable, self.func),
1284-
**get_jit_arguments(self.engine_kwargs, self.kwargs),
1290+
**get_jit_arguments(self.engine_kwargs, kwargs),
12851291
)
12861292

12871293
from pandas.core._numba.extensions import set_numba_data
@@ -1292,7 +1298,7 @@ def apply_with_numba(self) -> dict[int, Any]:
12921298
set_numba_data(self.obj.index) as index,
12931299
set_numba_data(self.columns) as columns,
12941300
):
1295-
res = dict(nb_func(self.values, columns, index, *self.args))
1301+
res = dict(nb_func(self.values, columns, index, *args))
12961302

12971303
return res
12981304

pandas/core/util/numba_.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import inspect
56
import types
67
from typing import (
78
TYPE_CHECKING,
@@ -97,3 +98,47 @@ def jit_user_function(func: Callable) -> Callable:
9798
numba_func = numba.extending.register_jitable(func)
9899

99100
return numba_func
101+
102+
103+
_sentinel = object()
104+
105+
106+
def prepare_function_arguments(
107+
func: Callable, args: tuple, kwargs: dict
108+
) -> tuple[tuple, dict]:
109+
"""
110+
Prepare arguments for jitted function. As numba functions do not support kwargs,
111+
we try to move kwargs into args if possible.
112+
113+
Parameters
114+
----------
115+
func : function
116+
user defined function
117+
args : tuple
118+
user input positional arguments
119+
kwargs : dict
120+
user input keyword arguments
121+
122+
Returns
123+
-------
124+
tuple[tuple, dict]
125+
args, kwargs
126+
127+
"""
128+
if not kwargs:
129+
return args, kwargs
130+
131+
# the udf should have this pattern: def udf(value, *args, **kwargs):...
132+
signature = inspect.signature(func)
133+
arguments = signature.bind(_sentinel, *args, **kwargs)
134+
arguments.apply_defaults()
135+
# Ref: https://peps.python.org/pep-0362/
136+
# Arguments which could be passed as part of either *args or **kwargs
137+
# will be included only in the BoundArguments.args attribute.
138+
args = arguments.args
139+
kwargs = arguments.kwargs
140+
141+
assert args[0] is _sentinel
142+
args = args[1:]
143+
144+
return args, kwargs

pandas/tests/apply/test_frame_apply.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,35 @@ def test_apply(float_frame, engine, request):
6464
@pytest.mark.parametrize("axis", [0, 1])
6565
@pytest.mark.parametrize("raw", [True, False])
6666
def test_apply_args(float_frame, axis, raw, engine):
67-
# GH:58712
6867
result = float_frame.apply(
6968
lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine
7069
)
7170
expected = float_frame + 1
7271
tm.assert_frame_equal(result, expected)
7372

73+
# GH:58712
74+
result = float_frame.apply(
75+
lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
76+
)
77+
expected = float_frame + 3
78+
tm.assert_frame_equal(result, expected)
79+
7480
if engine == "numba":
81+
# keyword-only arguments are not supported in numba
82+
with pytest.raises(
83+
pd.errors.NumbaUtilError,
84+
match="numba does not support kwargs with nopython=True",
85+
):
86+
float_frame.apply(
87+
lambda x, a, *, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
88+
)
89+
7590
with pytest.raises(
7691
pd.errors.NumbaUtilError,
7792
match="numba does not support kwargs with nopython=True",
7893
):
7994
float_frame.apply(
80-
lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
95+
lambda *x, b: x[0] + x[1] + b, args=(1,), b=2, engine=engine, raw=raw
8196
)
8297

8398

0 commit comments

Comments
 (0)