Skip to content

Commit b08c361

Browse files
committed
Updated engine_kwargs extraction logic and added test
1 parent f59fb52 commit b08c361

File tree

4 files changed

+36
-4
lines changed

4 files changed

+36
-4
lines changed

pandas/core/apply.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import pandas.core.common as com
5050
from pandas.core.construction import ensure_wrapped_if_datetimelike
5151
from pandas.core.util.numba_ import (
52+
extract_numba_options,
5253
get_jit_arguments,
5354
prepare_function_arguments,
5455
)
@@ -224,9 +225,9 @@ def apply(
224225
if not isinstance(data, np.ndarray):
225226
if data.empty:
226227
return data.copy() # mimic apply_empty_result()
227-
engine_kwargs = (
228-
decorator.engine_kwargs if hasattr(decorator, "engine_kwargs") else {}
229-
)
228+
229+
engine_kwargs = extract_numba_options(decorator)
230+
230231
NumbaExecutionEngine.validate_values_for_numba_raw_false(
231232
data, get_jit_arguments(engine_kwargs)
232233
)

pandas/core/frame.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10623,7 +10623,6 @@ def apply(
1062310623
numba = import_optional_dependency("numba")
1062410624
numba_jit = numba.jit(**engine_kwargs or {})
1062510625
numba_jit.__pandas_udf__ = NumbaExecutionEngine
10626-
numba_jit.engine_kwargs = engine_kwargs
1062710626
engine = numba_jit
1062810627

1062910628
if engine is None or isinstance(engine, str):

pandas/core/util/numba_.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,24 @@ def prepare_function_arguments(
148148

149149
args = args[num_required_args:]
150150
return args, kwargs
151+
152+
153+
def extract_numba_options(decorator):
154+
"""
155+
Extract targetoptions from a numba.jit decorator
156+
"""
157+
try:
158+
closure = decorator.__closure__
159+
if closure is None:
160+
return {}
161+
freevars = decorator.__code__.co_freevars
162+
if "targetoptions" not in freevars:
163+
return {}
164+
idx = freevars.index("targetoptions")
165+
cell = closure[idx]
166+
targetoptions = cell.cell_contents
167+
if isinstance(targetoptions, dict):
168+
return targetoptions
169+
return {}
170+
except Exception:
171+
return {}

pandas/tests/apply/test_numba.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from pandas.compat import is_platform_arm
5+
from pandas.core.util.numba_ import extract_numba_options
56
import pandas.util._test_decorators as td
67

78
import pandas as pd
@@ -127,3 +128,13 @@ def test_numba_unsupported_dtypes(apply_axis):
127128
"which is not supported by the numba engine.",
128129
):
129130
df["c"].to_frame().apply(f, engine="numba", axis=apply_axis)
131+
132+
133+
@pytest.mark.parametrize("jit_args", [
134+
{"parallel": True, "nogil": True},
135+
{"parallel": False, "nogil": False},
136+
])
137+
def test_extract_numba_options_from_user_decorated_function(jit_args):
138+
extracted = extract_numba_options(numba.jit(**jit_args))
139+
for k, v in jit_args.items():
140+
assert extracted.get(k) == v

0 commit comments

Comments
 (0)