Skip to content

Commit bc2939b

Browse files
committed
Updated with reviewer suggestions and added axis normalizing
1 parent f8f1166 commit bc2939b

File tree

2 files changed

+58
-56
lines changed

2 files changed

+58
-56
lines changed

pandas/core/apply.py

Lines changed: 55 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
import pandas.core.common as com
5050
from pandas.core.construction import ensure_wrapped_if_datetimelike
5151
from pandas.core.util.numba_ import (
52-
get_jit_arguments,
5352
prepare_function_arguments,
5453
)
5554

@@ -195,7 +194,7 @@ def map(
195194
"""
196195
Elementwise map for the Numba engine. Currently not supported.
197196
"""
198-
raise NotImplementedError("Numba map is not implemented yet.")
197+
raise NotImplementedError("The Numba engine is not implemented for the map method yet.")
199198

200199
@staticmethod
201200
def apply(
@@ -210,35 +209,27 @@ def apply(
210209
Apply `func` along the given axis using Numba.
211210
"""
212211

213-
if is_list_like(func):
214-
raise NotImplementedError(
215-
"the 'numba' engine doesn't support lists of callables yet"
216-
)
217-
218-
if isinstance(func, str):
219-
raise NotImplementedError(
220-
"the 'numba' engine doesn't support using "
221-
"a string as the callable function"
222-
)
212+
NumbaExecutionEngine.check_numba_support(func)
223213

224-
elif isinstance(func, np.ufunc):
225-
raise NotImplementedError(
226-
"the 'numba' engine doesn't support "
227-
"using a numpy ufunc as the callable function"
228-
)
214+
# normalize axis values
215+
if axis in (0, "index"):
216+
axis = 0
217+
else:
218+
axis = 1
229219

230220
# check for data typing
231221
if not isinstance(data, np.ndarray):
232-
if len(data.columns) == 0 and len(data.index) == 0:
222+
if data.empty:
233223
return data.copy() # mimic apply_empty_result()
224+
NumbaExecutionEngine.validate_values_for_numba_raw_false(
225+
data,
226+
decorator if isinstance(decorator, dict) else {}
227+
)
228+
234229
return NumbaExecutionEngine.apply_raw_false(
235230
data, func, args, kwargs, decorator, axis
236231
)
237232

238-
engine_kwargs: dict[str, bool] | None = (
239-
decorator if isinstance(decorator, dict) else None
240-
)
241-
242233
looper_args, looper_kwargs = prepare_function_arguments(
243234
func,
244235
args,
@@ -249,14 +240,33 @@ def apply(
249240
# incompatible type "Callable[..., Any] | str | list[Callable
250241
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
251242
# list[Callable[..., Any] | str]]"; expected "Hashable"
252-
nb_looper = generate_apply_looper(
243+
numba_looper = generate_apply_looper(
253244
func,
254-
**get_jit_arguments(engine_kwargs),
245+
decorator,
255246
)
256-
result = nb_looper(data, axis, *looper_args)
247+
result = numba_looper(data, axis, *looper_args)
257248
# If we made the result 2-D, squeeze it back to 1-D
258249
return np.squeeze(result)
259250

251+
@staticmethod
252+
def check_numba_support(func):
253+
if is_list_like(func):
254+
raise NotImplementedError(
255+
"the 'numba' engine doesn't support lists of callables yet"
256+
)
257+
258+
elif isinstance(func, str):
259+
raise NotImplementedError(
260+
"the 'numba' engine doesn't support using "
261+
"a string as the callable function"
262+
)
263+
264+
elif isinstance(func, np.ufunc):
265+
raise NotImplementedError(
266+
"the 'numba' engine doesn't support "
267+
"using a numpy ufunc as the callable function"
268+
)
269+
260270
@staticmethod
261271
def apply_raw_false(
262272
data: Series | DataFrame,
@@ -271,21 +281,8 @@ def apply_raw_false(
271281
Series,
272282
)
273283

274-
engine_kwargs: dict[str, bool] = (
275-
decorator if isinstance(decorator, dict) else {}
276-
)
277-
278-
if engine_kwargs.get("parallel", False):
279-
raise NotImplementedError(
280-
"Parallel apply is not supported when raw=False and engine='numba'"
281-
)
282-
if not data.index.is_unique or not data.columns.is_unique:
283-
raise NotImplementedError(
284-
"The index/columns must be unique when raw=False and engine='numba'"
285-
)
286-
NumbaExecutionEngine.validate_values_for_numba(data)
287284
results = NumbaExecutionEngine.apply_with_numba(
288-
data, func, args, kwargs, engine_kwargs, axis
285+
data, func, args, kwargs, decorator, axis
289286
)
290287

291288
if results:
@@ -301,21 +298,30 @@ def apply_raw_false(
301298
return DataFrame() if isinstance(data, DataFrame) else Series()
302299

303300
@staticmethod
304-
def validate_values_for_numba(obj: Series | DataFrame) -> None:
301+
def validate_values_for_numba_raw_false(data: Series | DataFrame, engine_kwargs: dict[str, bool]) -> None:
305302
from pandas import Series
306303

307-
if isinstance(obj, Series):
308-
if not is_numeric_dtype(obj.dtype):
304+
if engine_kwargs.get("parallel", False):
305+
raise NotImplementedError(
306+
"Parallel apply is not supported when raw=False and engine='numba'"
307+
)
308+
if not data.index.is_unique or not data.columns.is_unique:
309+
raise NotImplementedError(
310+
"The index/columns must be unique when raw=False and engine='numba'"
311+
)
312+
313+
if isinstance(data, Series):
314+
if not is_numeric_dtype(data.dtype):
309315
raise ValueError(
310-
f"Series must have a numeric dtype. Found '{obj.dtype}' instead"
316+
f"Series must have a numeric dtype. Found '{data.dtype}' instead"
311317
)
312-
if is_extension_array_dtype(obj.dtype):
318+
if is_extension_array_dtype(data.dtype):
313319
raise ValueError(
314320
"Series is backed by an extension array, "
315321
"which is not supported by the numba engine."
316322
)
317323
else:
318-
for colname, dtype in obj.dtypes.items():
324+
for colname, dtype in data.dtypes.items():
319325
if not is_numeric_dtype(dtype):
320326
raise ValueError(
321327
f"Column {colname} must have a numeric dtype. "
@@ -330,15 +336,15 @@ def validate_values_for_numba(obj: Series | DataFrame) -> None:
330336
@staticmethod
331337
@functools.cache
332338
def generate_numba_apply_func(
333-
func, axis, nogil=True, nopython=True, parallel=False
339+
func, axis, decorator: Callable
334340
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
335341
numba = import_optional_dependency("numba")
336342
from pandas import Series
337343
from pandas.core._numba.extensions import maybe_cast_str
338344

339345
jitted_udf = numba.extending.register_jitable(func)
340346

341-
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
347+
@decorator # type: ignore
342348
def numba_func(values, col_names_index, index, *args):
343349
results = {}
344350
for i in range(values.shape[1 - axis]):
@@ -363,14 +369,14 @@ def numba_func(values, col_names_index, index, *args):
363369

364370
@staticmethod
365371
def apply_with_numba(
366-
data, func, args, kwargs, engine_kwargs, axis=0
372+
data, func, args, kwargs, decorator, axis=0
367373
) -> dict[int, Any]:
368374
func = cast(Callable, func)
369375
args, kwargs = prepare_function_arguments(
370376
func, args, kwargs, num_required_args=1
371377
)
372378
nb_func = NumbaExecutionEngine.generate_numba_apply_func(
373-
func, axis, **get_jit_arguments(engine_kwargs)
379+
func, axis, decorator
374380
)
375381

376382
from pandas.core._numba.extensions import set_numba_data

pandas/core/frame.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10621,17 +10621,13 @@ def apply(
1062110621
"""
1062210622
if engine == "numba":
1062310623
numba = import_optional_dependency("numba")
10624-
if engine_kwargs is not None:
10625-
numba_jit = numba.jit(**engine_kwargs)
10626-
else:
10627-
numba_jit = numba.jit()
10628-
numba_jit.__pandas_udf__ = NumbaExecutionEngine
10629-
engine = numba_jit
10624+
engine = numba.jit(**engine_kwargs or {})
10625+
engine.__pandas_udf__ = NumbaExecutionEngine
1063010626

1063110627
if engine is None or isinstance(engine, str):
1063210628
from pandas.core.apply import frame_apply
1063310629

10634-
if engine not in ["python"] and engine is not None:
10630+
if engine not in ["python", None]:
1063510631
raise ValueError(f"Unknown engine '{engine}'")
1063610632

1063710633
op = frame_apply(

0 commit comments

Comments
 (0)