4949import pandas .core .common as com
5050from pandas .core .construction import ensure_wrapped_if_datetimelike
5151from pandas .core .util .numba_ import (
52- get_jit_arguments ,
5352 prepare_function_arguments ,
5453)
5554
@@ -195,7 +194,7 @@ def map(
195194 """
196195 Elementwise map for the Numba engine. Currently not supported.
197196 """
198- raise NotImplementedError ("Numba map is not implemented yet." )
197+ raise NotImplementedError ("The Numba engine is not implemented for the map method yet." )
199198
200199 @staticmethod
201200 def apply (
@@ -210,35 +209,27 @@ def apply(
210209 Apply `func` along the given axis using Numba.
211210 """
212211
213- if is_list_like (func ):
214- raise NotImplementedError (
215- "the 'numba' engine doesn't support lists of callables yet"
216- )
217-
218- if isinstance (func , str ):
219- raise NotImplementedError (
220- "the 'numba' engine doesn't support using "
221- "a string as the callable function"
222- )
212+ NumbaExecutionEngine .check_numba_support (func )
223213
224- elif isinstance ( func , np . ufunc ):
225- raise NotImplementedError (
226- "the 'numba' engine doesn't support "
227- "using a numpy ufunc as the callable function"
228- )
214+ # normalize axis values
215+ if axis in ( 0 , "index" ):
216+ axis = 0
217+ else :
218+ axis = 1
229219
230220 # check for data typing
231221 if not isinstance (data , np .ndarray ):
232- if len ( data .columns ) == 0 and len ( data . index ) == 0 :
222+ if data .empty :
233223 return data .copy () # mimic apply_empty_result()
224+ NumbaExecutionEngine .validate_values_for_numba_raw_false (
225+ data ,
226+ decorator if isinstance (decorator , dict ) else {}
227+ )
228+
234229 return NumbaExecutionEngine .apply_raw_false (
235230 data , func , args , kwargs , decorator , axis
236231 )
237232
238- engine_kwargs : dict [str , bool ] | None = (
239- decorator if isinstance (decorator , dict ) else None
240- )
241-
242233 looper_args , looper_kwargs = prepare_function_arguments (
243234 func ,
244235 args ,
@@ -249,14 +240,33 @@ def apply(
249240 # incompatible type "Callable[..., Any] | str | list[Callable
250241 # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
251242 # list[Callable[..., Any] | str]]"; expected "Hashable"
252- nb_looper = generate_apply_looper (
243+ numba_looper = generate_apply_looper (
253244 func ,
254- ** get_jit_arguments ( engine_kwargs ) ,
245+ decorator ,
255246 )
256- result = nb_looper (data , axis , * looper_args )
247+ result = numba_looper (data , axis , * looper_args )
257248 # If we made the result 2-D, squeeze it back to 1-D
258249 return np .squeeze (result )
259250
251+ @staticmethod
252+ def check_numba_support (func ):
253+ if is_list_like (func ):
254+ raise NotImplementedError (
255+ "the 'numba' engine doesn't support lists of callables yet"
256+ )
257+
258+ elif isinstance (func , str ):
259+ raise NotImplementedError (
260+ "the 'numba' engine doesn't support using "
261+ "a string as the callable function"
262+ )
263+
264+ elif isinstance (func , np .ufunc ):
265+ raise NotImplementedError (
266+ "the 'numba' engine doesn't support "
267+ "using a numpy ufunc as the callable function"
268+ )
269+
260270 @staticmethod
261271 def apply_raw_false (
262272 data : Series | DataFrame ,
@@ -271,21 +281,8 @@ def apply_raw_false(
271281 Series ,
272282 )
273283
274- engine_kwargs : dict [str , bool ] = (
275- decorator if isinstance (decorator , dict ) else {}
276- )
277-
278- if engine_kwargs .get ("parallel" , False ):
279- raise NotImplementedError (
280- "Parallel apply is not supported when raw=False and engine='numba'"
281- )
282- if not data .index .is_unique or not data .columns .is_unique :
283- raise NotImplementedError (
284- "The index/columns must be unique when raw=False and engine='numba'"
285- )
286- NumbaExecutionEngine .validate_values_for_numba (data )
287284 results = NumbaExecutionEngine .apply_with_numba (
288- data , func , args , kwargs , engine_kwargs , axis
285+ data , func , args , kwargs , decorator , axis
289286 )
290287
291288 if results :
@@ -301,21 +298,30 @@ def apply_raw_false(
301298 return DataFrame () if isinstance (data , DataFrame ) else Series ()
302299
303300 @staticmethod
304- def validate_values_for_numba ( obj : Series | DataFrame ) -> None :
301+ def validate_values_for_numba_raw_false ( data : Series | DataFrame , engine_kwargs : dict [ str , bool ] ) -> None :
305302 from pandas import Series
306303
307- if isinstance (obj , Series ):
308- if not is_numeric_dtype (obj .dtype ):
304+ if engine_kwargs .get ("parallel" , False ):
305+ raise NotImplementedError (
306+ "Parallel apply is not supported when raw=False and engine='numba'"
307+ )
308+ if not data .index .is_unique or not data .columns .is_unique :
309+ raise NotImplementedError (
310+ "The index/columns must be unique when raw=False and engine='numba'"
311+ )
312+
313+ if isinstance (data , Series ):
314+ if not is_numeric_dtype (data .dtype ):
309315 raise ValueError (
310- f"Series must have a numeric dtype. Found '{ obj .dtype } ' instead"
316+ f"Series must have a numeric dtype. Found '{ data .dtype } ' instead"
311317 )
312- if is_extension_array_dtype (obj .dtype ):
318+ if is_extension_array_dtype (data .dtype ):
313319 raise ValueError (
314320 "Series is backed by an extension array, "
315321 "which is not supported by the numba engine."
316322 )
317323 else :
318- for colname , dtype in obj .dtypes .items ():
324+ for colname , dtype in data .dtypes .items ():
319325 if not is_numeric_dtype (dtype ):
320326 raise ValueError (
321327 f"Column { colname } must have a numeric dtype. "
@@ -330,15 +336,15 @@ def validate_values_for_numba(obj: Series | DataFrame) -> None:
330336 @staticmethod
331337 @functools .cache
332338 def generate_numba_apply_func (
333- func , axis , nogil = True , nopython = True , parallel = False
339+ func , axis , decorator : Callable
334340 ) -> Callable [[npt .NDArray , Index , Index ], dict [int , Any ]]:
335341 numba = import_optional_dependency ("numba" )
336342 from pandas import Series
337343 from pandas .core ._numba .extensions import maybe_cast_str
338344
339345 jitted_udf = numba .extending .register_jitable (func )
340346
341- @numba . jit ( nogil = nogil , nopython = nopython , parallel = parallel )
347+ @decorator # type: ignore
342348 def numba_func (values , col_names_index , index , * args ):
343349 results = {}
344350 for i in range (values .shape [1 - axis ]):
@@ -363,14 +369,14 @@ def numba_func(values, col_names_index, index, *args):
363369
364370 @staticmethod
365371 def apply_with_numba (
366- data , func , args , kwargs , engine_kwargs , axis = 0
372+ data , func , args , kwargs , decorator , axis = 0
367373 ) -> dict [int , Any ]:
368374 func = cast (Callable , func )
369375 args , kwargs = prepare_function_arguments (
370376 func , args , kwargs , num_required_args = 1
371377 )
372378 nb_func = NumbaExecutionEngine .generate_numba_apply_func (
373- func , axis , ** get_jit_arguments ( engine_kwargs )
379+ func , axis , decorator
374380 )
375381
376382 from pandas .core ._numba .extensions import set_numba_data
0 commit comments