File tree Expand file tree Collapse file tree 4 files changed +36
-4
lines changed
Expand file tree Collapse file tree 4 files changed +36
-4
lines changed Original file line number Diff line number Diff line change 4949import pandas .core .common as com
5050from pandas .core .construction import ensure_wrapped_if_datetimelike
5151from pandas .core .util .numba_ import (
52+ extract_numba_options ,
5253 get_jit_arguments ,
5354 prepare_function_arguments ,
5455)
@@ -224,9 +225,9 @@ def apply(
224225 if not isinstance (data , np .ndarray ):
225226 if data .empty :
226227 return data .copy () # mimic apply_empty_result()
227- engine_kwargs = (
228- decorator . engine_kwargs if hasattr (decorator , "engine_kwargs" ) else {}
229- )
228+
229+ engine_kwargs = extract_numba_options (decorator )
230+
230231 NumbaExecutionEngine .validate_values_for_numba_raw_false (
231232 data , get_jit_arguments (engine_kwargs )
232233 )
Original file line number Diff line number Diff line change @@ -10623,7 +10623,6 @@ def apply(
1062310623 numba = import_optional_dependency ("numba" )
1062410624 numba_jit = numba .jit (** engine_kwargs or {})
1062510625 numba_jit .__pandas_udf__ = NumbaExecutionEngine
10626- numba_jit .engine_kwargs = engine_kwargs
1062710626 engine = numba_jit
1062810627
1062910628 if engine is None or isinstance (engine , str ):
Original file line number Diff line number Diff line change @@ -148,3 +148,24 @@ def prepare_function_arguments(
148148
149149 args = args [num_required_args :]
150150 return args , kwargs
151+
152+
153+ def extract_numba_options (decorator ):
154+ """
155+ Extract targetoptions from a numba.jit decorator
156+ """
157+ try :
158+ closure = decorator .__closure__
159+ if closure is None :
160+ return {}
161+ freevars = decorator .__code__ .co_freevars
162+ if "targetoptions" not in freevars :
163+ return {}
164+ idx = freevars .index ("targetoptions" )
165+ cell = closure [idx ]
166+ targetoptions = cell .cell_contents
167+ if isinstance (targetoptions , dict ):
168+ return targetoptions
169+ return {}
170+ except Exception :
171+ return {}
Original file line number Diff line number Diff line change 22import pytest
33
44from pandas .compat import is_platform_arm
5+ from pandas .core .util .numba_ import extract_numba_options
56import pandas .util ._test_decorators as td
67
78import pandas as pd
@@ -127,3 +128,13 @@ def test_numba_unsupported_dtypes(apply_axis):
127128 "which is not supported by the numba engine." ,
128129 ):
129130 df ["c" ].to_frame ().apply (f , engine = "numba" , axis = apply_axis )
131+
132+
133+ @pytest .mark .parametrize ("jit_args" , [
134+ {"parallel" : True , "nogil" : True },
135+ {"parallel" : False , "nogil" : False },
136+ ])
137+ def test_extract_numba_options_from_user_decorated_function (jit_args ):
138+ extracted = extract_numba_options (numba .jit (** jit_args ))
139+ for k , v in jit_args .items ():
140+ assert extracted .get (k ) == v
You can’t perform that action at this time.
0 commit comments