Skip to content

Commit 24f75e2

Browse files
committed
Add a numba_switch decorator to nanops and replace most of the bottleneck switches
1 parent 17b90c5 commit 24f75e2

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

pandas/core/nanops.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
notna,
4949
)
5050

51+
52+
from pandas.core.util.numba_ import GLOBAL_USE_NUMBA
53+
from pandas.core import nanops_numba
54+
5155
if TYPE_CHECKING:
5256
from collections.abc import Callable
5357

@@ -97,6 +101,38 @@ def _f(*args, **kwargs):
97101
return cast(F, _f)
98102

99103

104+
class numba_switch:
105+
def __init__(self, name=None, **kwargs) -> None:
106+
self.name = name
107+
self.kwargs = kwargs
108+
109+
def __call__(self, alt: F) -> F:
110+
nb_name = self.name or alt.__name__
111+
112+
try:
113+
nb_func = getattr(nanops_numba, nb_name)
114+
except (AttributeError, NameError): # pragma: no cover
115+
nb_func = None
116+
117+
@functools.wraps(alt)
118+
def f(
119+
values: np.ndarray,
120+
*,
121+
axis: AxisInt | None = None,
122+
skipna: bool = True,
123+
**kwds,
124+
):
125+
disallowed = values.dtype == "O"
126+
if GLOBAL_USE_NUMBA and not disallowed:
127+
result = nb_func(values, skipna=skipna, axis=axis, **kwds)
128+
else:
129+
result = alt(values, axis=axis, skipna=skipna, **kwds)
130+
131+
return result
132+
133+
return cast(F, f)
134+
135+
100136
class bottleneck_switch:
101137
def __init__(self, name=None, **kwargs) -> None:
102138
self.name = name
@@ -593,6 +629,7 @@ def nanall(
593629
return values.all(axis) # type: ignore[return-value]
594630

595631

632+
@numba_switch()
596633
@disallow("M8")
597634
@_datetimelike_compat
598635
@maybe_operate_rowwise
@@ -660,7 +697,7 @@ def _mask_datetimelike_result(
660697
return result
661698

662699

663-
@bottleneck_switch()
700+
@numba_switch()
664701
@_datetimelike_compat
665702
def nanmean(
666703
values: np.ndarray,
@@ -910,7 +947,7 @@ def _get_counts_nanvar(
910947
return count, d
911948

912949

913-
@bottleneck_switch(ddof=1)
950+
@numba_switch(ddof=1)
914951
def nanstd(
915952
values,
916953
*,
@@ -957,7 +994,7 @@ def nanstd(
957994

958995

959996
@disallow("M8", "m8")
960-
@bottleneck_switch(ddof=1)
997+
@numba_switch(ddof=1)
961998
def nanvar(
962999
values: np.ndarray,
9631000
*,
@@ -1035,6 +1072,7 @@ def nanvar(
10351072
return result
10361073

10371074

1075+
@numba_switch()
10381076
@disallow("M8", "m8")
10391077
def nansem(
10401078
values: np.ndarray,
@@ -1089,7 +1127,7 @@ def nansem(
10891127

10901128

10911129
def _nanminmax(meth, fill_value_typ):
1092-
@bottleneck_switch(name=f"nan{meth}")
1130+
@numba_switch(name=f"nan{meth}")
10931131
@_datetimelike_compat
10941132
def reduction(
10951133
values: np.ndarray,

0 commit comments

Comments
 (0)