@@ -231,10 +231,8 @@ def apply(
231
231
if not isinstance (data , np .ndarray ):
232
232
if len (data .columns ) == 0 and len (data .index ) == 0 :
233
233
return data .copy () # mimic apply_empty_result()
234
- # TODO:
235
- # Rewrite FrameApply.apply_series_numba() logic without FrameApply object
236
- raise NotImplementedError (
237
- "raw=False is not yet supported in NumbaExecutionEngine."
234
+ return NumbaExecutionEngine .apply_raw_false (
235
+ data , func , args , kwargs , decorator , axis
238
236
)
239
237
240
238
engine_kwargs : dict [str , bool ] | None = (
@@ -259,6 +257,128 @@ def apply(
259
257
# If we made the result 2-D, squeeze it back to 1-D
260
258
return np .squeeze (result )
261
259
260
+ @staticmethod
261
def apply_raw_false(
    data: Series | DataFrame,
    func,
    args: tuple,
    kwargs: dict,
    decorator: Callable,
    axis: int | str,
):
    """
    Apply ``func`` with the Numba engine, passing a Series per column or row.

    Parameters
    ----------
    data : Series or DataFrame
        Object to apply over. Its ``index`` and ``columns`` must both be
        unique.
    func : callable
        User function, forwarded to ``apply_with_numba``.
    args : tuple
        Extra positional arguments forwarded to ``apply_with_numba``.
    kwargs : dict
        Extra keyword arguments forwarded to ``apply_with_numba``.
    decorator : dict or callable
        When this is a dict it is used as the engine keyword arguments;
        any other value is treated as "no engine options".
    axis : {0, "index", 1, "columns"}
        ``0``/``"index"`` and ``1``/``"columns"`` are treated as aliases.
        Any other value is forwarded unchanged.

    Returns
    -------
    DataFrame or Series
        A DataFrame when the per-label results are Series, otherwise a
        Series built from the ``{label: result}`` mapping. When there are
        no results, an empty DataFrame (for DataFrame input) or an empty
        Series.

    Raises
    ------
    NotImplementedError
        If ``parallel`` is requested in the engine options, or if the index
        or the columns contain duplicates.
    """
    from pandas import (
        DataFrame,
        Series,
    )

    engine_kwargs: dict[str, bool] = decorator if isinstance(decorator, dict) else {}

    if engine_kwargs.get("parallel", False):
        raise NotImplementedError(
            "Parallel apply is not supported when raw=False and engine='numba'"
        )
    if not data.index.is_unique or not data.columns.is_unique:
        raise NotImplementedError(
            "The index/columns must be unique when raw=False and engine='numba'"
        )

    # Resolve the string aliases once so that the result layout chosen below
    # and the iteration direction used downstream always agree. Previously
    # ``axis="columns"` was laid out as if it were ``axis=0``.
    # Unrecognised values are passed through untouched.
    axis_number = {0: 0, "index": 0, 1: 1, "columns": 1}.get(axis, axis)

    NumbaExecutionEngine.validate_values_for_numba(data)
    results = NumbaExecutionEngine.apply_with_numba(
        data, func, args, kwargs, engine_kwargs, axis_number
    )

    if not results:
        return DataFrame() if isinstance(data, DataFrame) else Series()

    # Inspect one result to decide between a 2-D and a 1-D output.
    sample = next(iter(results.values()))
    if isinstance(sample, Series):
        # Row-wise results become rows of the output; column-wise results
        # become its columns.
        orient = "index" if axis_number == 1 else "columns"
        return DataFrame.from_dict(results, orient=orient)
    return Series(results)
302
+
303
+ @staticmethod
304
def validate_values_for_numba(df: DataFrame) -> None:
    """
    Check that every column of ``df`` has a dtype usable by the Numba engine.

    Parameters
    ----------
    df : DataFrame
        Frame whose ``dtypes`` are inspected column by column.

    Raises
    ------
    ValueError
        For the first column whose dtype is not numeric, or is numeric but
        is an extension array dtype.
    """
    for label, col_dtype in df.dtypes.items():
        if is_numeric_dtype(col_dtype):
            if not is_extension_array_dtype(col_dtype):
                # Plain numeric column: nothing to report.
                continue
            problem = (
                f"Column {label} uses extension array dtype, "
                "not supported by Numba."
            )
        else:
            # NOTE(review): the space before the closing quote is kept
            # exactly as in the original message - confirm it is intended.
            problem = f"Column {label} must have numeric dtype. Found '{col_dtype} '."
        raise ValueError(problem)
315
+
316
+ @staticmethod
317
+ @functools .cache
318
def generate_numba_apply_func(
    func, axis, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
    """
    Build a Numba-jitted function that applies ``func`` to one Series at a time.

    Parameters
    ----------
    func : callable
        User function; registered with ``numba.extending.register_jitable``.
    axis : {0, "index", 1, "columns"}
        ``0``/``"index"`` slices ``values`` by column (``values[:, i]``).
        Any other value slices it by row (``values[i]``).
    nogil, nopython, parallel : bool
        Forwarded to ``numba.jit``.

    Returns
    -------
    callable
        ``numba_func(values, col_names_index, index, *args)``. For each
        position ``i`` it builds a Series from a slice of ``values`` with
        ``col_names_index`` as its index and ``index[i]`` as its name,
        calls the jitted ``func`` on it with ``*args``, and stores the
        outcome under the key ``index[i]``. It returns the mapping of all
        outcomes.
    """
    numba = import_optional_dependency("numba")
    from pandas import Series
    from pandas.core._numba.extensions import maybe_cast_str

    jitted_udf = numba.extending.register_jitable(func)

    # Resolve the axis in plain Python, before jitting. The previous
    # ``values.shape[1 - axis]`` inside the jitted loop could not work for
    # the string aliases that the branch right below it explicitly accepted.
    column_wise = axis == 0 or axis == "index"
    if isinstance(axis, str):
        loop_dim = 1 if column_wise else 0
    else:
        # Integer axes keep the original arithmetic unchanged.
        loop_dim = 1 - axis

    @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
    def numba_func(values, col_names_index, index, *args):
        results = {}
        for i in range(values.shape[loop_dim]):
            if column_wise:
                arr = values[:, i]
            else:
                # The copy is kept from the original implementation;
                # presumably the row slice must not share memory with
                # ``values`` - TODO confirm.
                arr = values[i].copy()
            result_key = index[i]
            ser = Series(
                arr,
                index=col_names_index,
                name=maybe_cast_str(result_key),
            )
            results[result_key] = jitted_udf(ser, *args)

        return results

    return numba_func
349
+
350
+ @staticmethod
351
def apply_with_numba(
    data, func, args, kwargs, engine_kwargs, axis=0
) -> dict[int, Any]:
    """
    Run ``func`` over ``data`` through the jitted apply function.

    Parameters
    ----------
    data : DataFrame
        Its ``values``, ``index`` and ``columns`` are handed to the jitted
        function.
    func : callable
        User function to apply.
    args, kwargs
        Extra arguments for ``func``; passed through
        ``prepare_function_arguments`` before use.
    engine_kwargs : dict
        Options passed to ``get_jit_arguments`` to obtain the jit settings.
    axis : {0, "index"} or other, default 0
        ``0``/``"index"`` keys the results by ``data.columns`` and gives
        each Series ``data.index``. Any other value does the opposite.

    Returns
    -------
    dict
        Plain ``dict`` copy of the mapping returned by the jitted function.
    """
    udf = cast(Callable, func)
    args, kwargs = prepare_function_arguments(
        udf, args, kwargs, num_required_args=1
    )
    jit_options = get_jit_arguments(engine_kwargs)
    compiled = NumbaExecutionEngine.generate_numba_apply_func(
        udf, axis, **jit_options
    )

    from pandas.core._numba.extensions import set_numba_data

    column_wise = axis == 0 or axis == "index"
    # ``series_labels`` index each Series handed to the UDF;
    # ``result_labels`` key the returned mapping.
    if column_wise:
        series_labels, result_labels = data.index, data.columns
    else:
        series_labels, result_labels = data.columns, data.index

    with (
        set_numba_data(result_labels) as numba_result_labels,
        set_numba_data(series_labels) as numba_series_labels,
    ):
        typed_results = compiled(
            data.values, numba_series_labels, numba_result_labels, *args
        )
        # Convert from numba dict to regular dict: our isinstance checks in
        # the df constructor don't pass for numba's typed dict.
        plain_results = dict(typed_results)

    return plain_results
381
+
262
382
263
383
def frame_apply (
264
384
obj : DataFrame ,
0 commit comments