Skip to content

Commit 2703f86

Browse files
committed
Implement and refactor raw=False logic into NumbaExecutionEngine.apply
1 parent 65b9d32 commit 2703f86

File tree

2 files changed

+125
-5
lines changed

2 files changed

+125
-5
lines changed

pandas/core/apply.py

Lines changed: 124 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,8 @@ def apply(
231231
if not isinstance(data, np.ndarray):
232232
if len(data.columns) == 0 and len(data.index) == 0:
233233
return data.copy() # mimic apply_empty_result()
234-
# TODO:
235-
# Rewrite FrameApply.apply_series_numba() logic without FrameApply object
236-
raise NotImplementedError(
237-
"raw=False is not yet supported in NumbaExecutionEngine."
234+
return NumbaExecutionEngine.apply_raw_false(
235+
data, func, args, kwargs, decorator, axis
238236
)
239237

240238
engine_kwargs: dict[str, bool] | None = (
@@ -259,6 +257,128 @@ def apply(
259257
# If we made the result 2-D, squeeze it back to 1-D
260258
return np.squeeze(result)
261259

260+
@staticmethod
def apply_raw_false(
    data: Series | DataFrame,
    func,
    args: tuple,
    kwargs: dict,
    decorator: Callable,
    axis: int | str,
):
    """
    Apply ``func`` to each column/row of ``data`` as a Series, using numba.

    Parameters
    ----------
    data : Series or DataFrame
        Object whose columns (or rows) are passed to ``func``. Both
        ``data.index`` and ``data.columns`` are accessed, so in practice a
        DataFrame is expected.
    func : callable
        User function, forwarded to ``apply_with_numba``.
    args : tuple
        Positional arguments forwarded to ``apply_with_numba``.
    kwargs : dict
        Keyword arguments forwarded to ``apply_with_numba``.
    decorator : dict or callable
        When this is a dict it is used as the numba engine keyword arguments;
        any other value results in an empty set of engine keyword arguments.
    axis : int or str
        ``0``/``"index"`` or ``1``/``"columns"``.

    Returns
    -------
    DataFrame or Series
        A DataFrame when ``func`` returned Series objects, otherwise a Series
        built from the per-column/row results. An empty DataFrame/Series
        (matching the type of ``data``) when there are no results.

    Raises
    ------
    NotImplementedError
        If ``parallel`` is requested, or if the index or the columns of
        ``data`` are not unique.
    """
    from pandas import (
        DataFrame,
        Series,
    )

    # This expression always yields a dict, never None.
    engine_kwargs: dict[str, bool] = (
        decorator if isinstance(decorator, dict) else {}
    )

    if engine_kwargs.get("parallel", False):
        raise NotImplementedError(
            "Parallel apply is not supported when raw=False and engine='numba'"
        )
    if not data.index.is_unique or not data.columns.is_unique:
        raise NotImplementedError(
            "The index/columns must be unique when raw=False and engine='numba'"
        )
    NumbaExecutionEngine.validate_values_for_numba(data)
    results = NumbaExecutionEngine.apply_with_numba(
        data, func, args, kwargs, engine_kwargs, axis
    )

    if not results:
        return DataFrame() if isinstance(data, DataFrame) else Series()

    sample = next(iter(results.values()))
    if isinstance(sample, Series):
        # Accept the string spelling too: apply_with_numba treats "index" as
        # axis 0, so "columns" must be treated as axis 1 here, otherwise the
        # per-row results would be laid out as columns.
        row_wise = axis == 1 or axis == "columns"
        return DataFrame.from_dict(
            results, orient="index" if row_wise else "columns"
        )
    return Series(results)
302+
303+
@staticmethod
304+
def validate_values_for_numba(df: DataFrame) -> None:
305+
for colname, dtype in df.dtypes.items():
306+
if not is_numeric_dtype(dtype):
307+
raise ValueError(
308+
f"Column {colname} must have numeric dtype. Found '{dtype}'."
309+
)
310+
if is_extension_array_dtype(dtype):
311+
raise ValueError(
312+
f"Column {colname} uses extension array dtype, "
313+
"not supported by Numba."
314+
)
315+
316+
@staticmethod
@functools.cache
def generate_numba_apply_func(
    func, axis, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
    """
    Build (and cache) a numba-jitted function applying ``func`` slice by slice.

    Parameters
    ----------
    func : callable
        User function; registered with ``numba.extending.register_jitable``.
    axis : int or str
        ``0``/``"index"`` hands each column ``values[:, i]`` to ``func``;
        any other value hands each row ``values[i]`` (copied) to ``func``.
    nogil, nopython, parallel : bool
        Forwarded to ``numba.jit``.

    Returns
    -------
    callable
        ``numba_func(values, col_names_index, index, *args)``, which wraps
        each slice in a Series indexed by ``col_names_index`` and named after
        ``index[i]``, and returns a dict mapping ``index[i]`` to the result
        of ``func`` for that slice.
    """
    numba = import_optional_dependency("numba")
    from pandas import Series
    from pandas.core._numba.extensions import maybe_cast_str

    # Resolve the axis once, outside the compiled function. The string
    # spellings ("index"/"columns") are accepted by apply_with_numba, and
    # doing arithmetic such as ``1 - axis`` on them would fail.
    along_index = axis == 0 or axis == "index"
    # Dimension of ``values`` that is iterated over: columns when applying
    # along the index, rows otherwise.
    loop_dim = 1 if along_index else 0

    jitted_udf = numba.extending.register_jitable(func)

    @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
    def numba_func(values, col_names_index, index, *args):
        results = {}
        for i in range(values.shape[loop_dim]):
            if along_index:
                arr = values[:, i]
            else:
                arr = values[i].copy()
            result_key = index[i]
            ser = Series(
                arr,
                index=col_names_index,
                name=maybe_cast_str(result_key),
            )
            results[result_key] = jitted_udf(ser, *args)

        return results

    return numba_func
349+
350+
@staticmethod
def apply_with_numba(
    data, func, args, kwargs, engine_kwargs, axis=0
) -> dict[int, Any]:
    """
    Run ``func`` over the columns/rows of ``data`` through a jitted loop.

    Parameters
    ----------
    data : DataFrame
        Source of ``data.values``, ``data.index`` and ``data.columns``.
    func : callable
        User function handed to ``generate_numba_apply_func``.
    args, kwargs
        Extra arguments, first passed through ``prepare_function_arguments``.
    engine_kwargs : dict
        Passed to ``get_jit_arguments`` to obtain the jit options.
    axis : int or str, default 0
        ``0``/``"index"`` uses ``data.index`` to label each slice and
        ``data.columns`` for the result keys; any other value swaps them.

    Returns
    -------
    dict
        Plain ``dict`` built from the mapping returned by the jitted function.
    """
    udf = cast(Callable, func)
    args, kwargs = prepare_function_arguments(
        udf, args, kwargs, num_required_args=1
    )
    jit_options = get_jit_arguments(engine_kwargs)
    compiled = NumbaExecutionEngine.generate_numba_apply_func(
        udf, axis, **jit_options
    )

    from pandas.core._numba.extensions import set_numba_data

    by_index = axis == 0 or axis == "index"
    if by_index:
        slice_labels, key_labels = data.index, data.columns
    else:
        slice_labels, key_labels = data.columns, data.index

    with set_numba_data(key_labels) as nb_keys:
        with set_numba_data(slice_labels) as nb_labels:
            # Convert from numba dict to regular dict
            # Our isinstance checks in the df constructor don't pass for
            # numbas typed dict
            plain_results = dict(compiled(data.values, nb_labels, nb_keys, *args))

    return plain_results
381+
262382

263383
def frame_apply(
264384
obj: DataFrame,

pandas/core/frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10648,7 +10648,7 @@ def apply(
1064810648
)
1064910649
return op.apply().__finalize__(self, method="apply")
1065010650

10651-
if hasattr(engine, "__pandas_udf__"):
10651+
elif hasattr(engine, "__pandas_udf__"):
1065210652
if result_type is not None:
1065310653
raise NotImplementedError(
1065410654
f"{result_type=} only implemented for the default engine"

0 commit comments

Comments
 (0)