File tree Expand file tree Collapse file tree 4 files changed +36
-4
lines changed Expand file tree Collapse file tree 4 files changed +36
-4
lines changed Original file line number Diff line number Diff line change 49
49
import pandas .core .common as com
50
50
from pandas .core .construction import ensure_wrapped_if_datetimelike
51
51
from pandas .core .util .numba_ import (
52
+ extract_numba_options ,
52
53
get_jit_arguments ,
53
54
prepare_function_arguments ,
54
55
)
@@ -224,9 +225,9 @@ def apply(
224
225
if not isinstance (data , np .ndarray ):
225
226
if data .empty :
226
227
return data .copy () # mimic apply_empty_result()
227
- engine_kwargs = (
228
- decorator . engine_kwargs if hasattr (decorator , "engine_kwargs" ) else {}
229
- )
228
+
229
+ engine_kwargs = extract_numba_options (decorator )
230
+
230
231
NumbaExecutionEngine .validate_values_for_numba_raw_false (
231
232
data , get_jit_arguments (engine_kwargs )
232
233
)
Original file line number Diff line number Diff line change @@ -10623,7 +10623,6 @@ def apply(
10623
10623
numba = import_optional_dependency ("numba" )
10624
10624
numba_jit = numba .jit (** engine_kwargs or {})
10625
10625
numba_jit .__pandas_udf__ = NumbaExecutionEngine
10626
- numba_jit .engine_kwargs = engine_kwargs
10627
10626
engine = numba_jit
10628
10627
10629
10628
if engine is None or isinstance (engine , str ):
Original file line number Diff line number Diff line change @@ -148,3 +148,24 @@ def prepare_function_arguments(
148
148
149
149
args = args [num_required_args :]
150
150
return args , kwargs
151
+
152
+
153
+ def extract_numba_options (decorator ):
154
+ """
155
+ Extract targetoptions from a numba.jit decorator
156
+ """
157
+ try :
158
+ closure = decorator .__closure__
159
+ if closure is None :
160
+ return {}
161
+ freevars = decorator .__code__ .co_freevars
162
+ if "targetoptions" not in freevars :
163
+ return {}
164
+ idx = freevars .index ("targetoptions" )
165
+ cell = closure [idx ]
166
+ targetoptions = cell .cell_contents
167
+ if isinstance (targetoptions , dict ):
168
+ return targetoptions
169
+ return {}
170
+ except Exception :
171
+ return {}
Original file line number Diff line number Diff line change 2
2
import pytest
3
3
4
4
from pandas .compat import is_platform_arm
5
+ from pandas .core .util .numba_ import extract_numba_options
5
6
import pandas .util ._test_decorators as td
6
7
7
8
import pandas as pd
@@ -127,3 +128,13 @@ def test_numba_unsupported_dtypes(apply_axis):
127
128
"which is not supported by the numba engine." ,
128
129
):
129
130
df ["c" ].to_frame ().apply (f , engine = "numba" , axis = apply_axis )
131
+
132
+
133
+ @pytest .mark .parametrize ("jit_args" , [
134
+ {"parallel" : True , "nogil" : True },
135
+ {"parallel" : False , "nogil" : False },
136
+ ])
137
+ def test_extract_numba_options_from_user_decorated_function (jit_args ):
138
+ extracted = extract_numba_options (numba .jit (** jit_args ))
139
+ for k , v in jit_args .items ():
140
+ assert extracted .get (k ) == v
You can’t perform that action at this time.
0 commit comments