Skip to content

Commit 14f8d17

Browse files
committed
Add a numba_switch decorator to nanops and replace most of the bottleneck switches
1 parent c384677 commit 14f8d17

File tree

1 file changed

+44
-4
lines changed

1 file changed

+44
-4
lines changed

pandas/core/nanops.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@
4848
notna,
4949
)
5050

51+
from pandas.core.util.numba_ import GLOBAL_USE_NUMBA
52+
53+
if GLOBAL_USE_NUMBA:
54+
from pandas.core import nanops_numba
55+
56+
5157
if TYPE_CHECKING:
5258
from collections.abc import Callable
5359

@@ -97,6 +103,38 @@ def _f(*args, **kwargs):
97103
return cast(F, _f)
98104

99105

106+
class numba_switch:
107+
def __init__(self, name=None, **kwargs) -> None:
108+
self.name = name
109+
self.kwargs = kwargs
110+
111+
def __call__(self, alt: F) -> F:
112+
nb_name = self.name or alt.__name__
113+
114+
try:
115+
nb_func = getattr(nanops_numba, nb_name)
116+
except (AttributeError, NameError): # pragma: no cover
117+
return alt
118+
119+
@functools.wraps(alt)
120+
def f(
121+
values: np.ndarray,
122+
*,
123+
axis: AxisInt | None = None,
124+
skipna: bool = True,
125+
**kwds,
126+
):
127+
disallowed = values.dtype == "O"
128+
if not disallowed:
129+
result = nb_func(values, skipna=skipna, axis=axis, **kwds)
130+
else:
131+
result = alt(values, axis=axis, skipna=skipna, **kwds)
132+
133+
return result
134+
135+
return cast(F, f)
136+
137+
100138
class bottleneck_switch:
101139
def __init__(self, name=None, **kwargs) -> None:
102140
self.name = name
@@ -593,6 +631,7 @@ def nanall(
593631
return values.all(axis) # type: ignore[return-value]
594632

595633

634+
@numba_switch()
596635
@disallow("M8")
597636
@_datetimelike_compat
598637
@maybe_operate_rowwise
@@ -658,7 +697,7 @@ def _mask_datetimelike_result(
658697
return result
659698

660699

661-
@bottleneck_switch()
700+
@numba_switch()
662701
@_datetimelike_compat
663702
def nanmean(
664703
values: np.ndarray,
@@ -908,7 +947,7 @@ def _get_counts_nanvar(
908947
return count, d
909948

910949

911-
@bottleneck_switch(ddof=1)
950+
@numba_switch(ddof=1)
912951
def nanstd(
913952
values,
914953
*,
@@ -955,7 +994,7 @@ def nanstd(
955994

956995

957996
@disallow("M8", "m8")
958-
@bottleneck_switch(ddof=1)
997+
@numba_switch(ddof=1)
959998
def nanvar(
960999
values: np.ndarray,
9611000
*,
@@ -1033,6 +1072,7 @@ def nanvar(
10331072
return result
10341073

10351074

1075+
@numba_switch()
10361076
@disallow("M8", "m8")
10371077
def nansem(
10381078
values: np.ndarray,
@@ -1087,7 +1127,7 @@ def nansem(
10871127

10881128

10891129
def _nanminmax(meth, fill_value_typ):
1090-
@bottleneck_switch(name=f"nan{meth}")
1130+
@numba_switch(name=f"nan{meth}")
10911131
@_datetimelike_compat
10921132
def reduction(
10931133
values: np.ndarray,

0 commit comments

Comments
 (0)