49
49
import pandas .core .common as com
50
50
from pandas .core .construction import ensure_wrapped_if_datetimelike
51
51
from pandas .core .util .numba_ import (
52
- get_jit_arguments ,
53
52
prepare_function_arguments ,
54
53
)
55
54
@@ -195,7 +194,7 @@ def map(
195
194
"""
196
195
Elementwise map for the Numba engine. Currently not supported.
197
196
"""
198
- raise NotImplementedError ("Numba map is not implemented yet." )
197
+ raise NotImplementedError ("The Numba engine is not implemented for the map method yet." )
199
198
200
199
@staticmethod
201
200
def apply (
@@ -210,35 +209,27 @@ def apply(
210
209
Apply `func` along the given axis using Numba.
211
210
"""
212
211
213
- if is_list_like (func ):
214
- raise NotImplementedError (
215
- "the 'numba' engine doesn't support lists of callables yet"
216
- )
217
-
218
- if isinstance (func , str ):
219
- raise NotImplementedError (
220
- "the 'numba' engine doesn't support using "
221
- "a string as the callable function"
222
- )
212
+ NumbaExecutionEngine .check_numba_support (func )
223
213
224
- elif isinstance ( func , np . ufunc ):
225
- raise NotImplementedError (
226
- "the 'numba' engine doesn't support "
227
- "using a numpy ufunc as the callable function"
228
- )
214
+ # normalize axis values
215
+ if axis in ( 0 , "index" ):
216
+ axis = 0
217
+ else :
218
+ axis = 1
229
219
230
220
# check for data typing
231
221
if not isinstance (data , np .ndarray ):
232
- if len ( data .columns ) == 0 and len ( data . index ) == 0 :
222
+ if data .empty :
233
223
return data .copy () # mimic apply_empty_result()
224
+ NumbaExecutionEngine .validate_values_for_numba_raw_false (
225
+ data ,
226
+ decorator if isinstance (decorator , dict ) else {}
227
+ )
228
+
234
229
return NumbaExecutionEngine .apply_raw_false (
235
230
data , func , args , kwargs , decorator , axis
236
231
)
237
232
238
- engine_kwargs : dict [str , bool ] | None = (
239
- decorator if isinstance (decorator , dict ) else None
240
- )
241
-
242
233
looper_args , looper_kwargs = prepare_function_arguments (
243
234
func ,
244
235
args ,
@@ -249,14 +240,33 @@ def apply(
249
240
# incompatible type "Callable[..., Any] | str | list[Callable
250
241
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
251
242
# list[Callable[..., Any] | str]]"; expected "Hashable"
252
- nb_looper = generate_apply_looper (
243
+ numba_looper = generate_apply_looper (
253
244
func ,
254
- ** get_jit_arguments ( engine_kwargs ) ,
245
+ decorator ,
255
246
)
256
- result = nb_looper (data , axis , * looper_args )
247
+ result = numba_looper (data , axis , * looper_args )
257
248
# If we made the result 2-D, squeeze it back to 1-D
258
249
return np .squeeze (result )
259
250
251
+ @staticmethod
252
+ def check_numba_support (func ):
253
+ if is_list_like (func ):
254
+ raise NotImplementedError (
255
+ "the 'numba' engine doesn't support lists of callables yet"
256
+ )
257
+
258
+ elif isinstance (func , str ):
259
+ raise NotImplementedError (
260
+ "the 'numba' engine doesn't support using "
261
+ "a string as the callable function"
262
+ )
263
+
264
+ elif isinstance (func , np .ufunc ):
265
+ raise NotImplementedError (
266
+ "the 'numba' engine doesn't support "
267
+ "using a numpy ufunc as the callable function"
268
+ )
269
+
260
270
@staticmethod
261
271
def apply_raw_false (
262
272
data : Series | DataFrame ,
@@ -271,21 +281,8 @@ def apply_raw_false(
271
281
Series ,
272
282
)
273
283
274
- engine_kwargs : dict [str , bool ] = (
275
- decorator if isinstance (decorator , dict ) else {}
276
- )
277
-
278
- if engine_kwargs .get ("parallel" , False ):
279
- raise NotImplementedError (
280
- "Parallel apply is not supported when raw=False and engine='numba'"
281
- )
282
- if not data .index .is_unique or not data .columns .is_unique :
283
- raise NotImplementedError (
284
- "The index/columns must be unique when raw=False and engine='numba'"
285
- )
286
- NumbaExecutionEngine .validate_values_for_numba (data )
287
284
results = NumbaExecutionEngine .apply_with_numba (
288
- data , func , args , kwargs , engine_kwargs , axis
285
+ data , func , args , kwargs , decorator , axis
289
286
)
290
287
291
288
if results :
@@ -301,21 +298,30 @@ def apply_raw_false(
301
298
return DataFrame () if isinstance (data , DataFrame ) else Series ()
302
299
303
300
@staticmethod
304
- def validate_values_for_numba ( obj : Series | DataFrame ) -> None :
301
+ def validate_values_for_numba_raw_false ( data : Series | DataFrame , engine_kwargs : dict [ str , bool ] ) -> None :
305
302
from pandas import Series
306
303
307
- if isinstance (obj , Series ):
308
- if not is_numeric_dtype (obj .dtype ):
304
+ if engine_kwargs .get ("parallel" , False ):
305
+ raise NotImplementedError (
306
+ "Parallel apply is not supported when raw=False and engine='numba'"
307
+ )
308
+ if not data .index .is_unique or not data .columns .is_unique :
309
+ raise NotImplementedError (
310
+ "The index/columns must be unique when raw=False and engine='numba'"
311
+ )
312
+
313
+ if isinstance (data , Series ):
314
+ if not is_numeric_dtype (data .dtype ):
309
315
raise ValueError (
310
- f"Series must have a numeric dtype. Found '{ obj .dtype } ' instead"
316
+ f"Series must have a numeric dtype. Found '{ data .dtype } ' instead"
311
317
)
312
- if is_extension_array_dtype (obj .dtype ):
318
+ if is_extension_array_dtype (data .dtype ):
313
319
raise ValueError (
314
320
"Series is backed by an extension array, "
315
321
"which is not supported by the numba engine."
316
322
)
317
323
else :
318
- for colname , dtype in obj .dtypes .items ():
324
+ for colname , dtype in data .dtypes .items ():
319
325
if not is_numeric_dtype (dtype ):
320
326
raise ValueError (
321
327
f"Column { colname } must have a numeric dtype. "
@@ -330,15 +336,15 @@ def validate_values_for_numba(obj: Series | DataFrame) -> None:
330
336
@staticmethod
331
337
@functools .cache
332
338
def generate_numba_apply_func (
333
- func , axis , nogil = True , nopython = True , parallel = False
339
+ func , axis , decorator : Callable
334
340
) -> Callable [[npt .NDArray , Index , Index ], dict [int , Any ]]:
335
341
numba = import_optional_dependency ("numba" )
336
342
from pandas import Series
337
343
from pandas .core ._numba .extensions import maybe_cast_str
338
344
339
345
jitted_udf = numba .extending .register_jitable (func )
340
346
341
- @numba . jit ( nogil = nogil , nopython = nopython , parallel = parallel )
347
+ @decorator # type: ignore
342
348
def numba_func (values , col_names_index , index , * args ):
343
349
results = {}
344
350
for i in range (values .shape [1 - axis ]):
@@ -363,14 +369,14 @@ def numba_func(values, col_names_index, index, *args):
363
369
364
370
@staticmethod
365
371
def apply_with_numba (
366
- data , func , args , kwargs , engine_kwargs , axis = 0
372
+ data , func , args , kwargs , decorator , axis = 0
367
373
) -> dict [int , Any ]:
368
374
func = cast (Callable , func )
369
375
args , kwargs = prepare_function_arguments (
370
376
func , args , kwargs , num_required_args = 1
371
377
)
372
378
nb_func = NumbaExecutionEngine .generate_numba_apply_func (
373
- func , axis , ** get_jit_arguments ( engine_kwargs )
379
+ func , axis , decorator
374
380
)
375
381
376
382
from pandas .core ._numba .extensions import set_numba_data
0 commit comments