Skip to content

Commit 0727fbb

Browse files
committed
Add a numba_switch decorator to nanops and replace most of the bottleneck switches
1 parent 3ca9712 commit 0727fbb

File tree

1 file changed

+52
-4
lines changed

1 file changed

+52
-4
lines changed

pandas/core/nanops.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
notna,
4949
)
5050

51+
from pandas.core import nanops_numba
52+
5153
if TYPE_CHECKING:
5254
from collections.abc import Callable
5355

@@ -66,6 +68,18 @@ def set_use_bottleneck(v: bool = True) -> None:
6668
set_use_bottleneck(get_option("compute.use_bottleneck"))
6769

6870

71+
_USE_NUMBA = True
72+
73+
74+
def set_use_numba(v: bool = True) -> None:
75+
# set/unset to use bottleneck
76+
global _USE_NUMBA
77+
_USE_NUMBA = v
78+
79+
80+
# set_use_numba(get_option("compute.use_numba"))
81+
82+
6983
class disallow:
7084
def __init__(self, *dtypes: Dtype) -> None:
7185
super().__init__()
@@ -97,6 +111,38 @@ def _f(*args, **kwargs):
97111
return cast(F, _f)
98112

99113

114+
class numba_switch:
115+
def __init__(self, name=None, **kwargs) -> None:
116+
self.name = name
117+
self.kwargs = kwargs
118+
119+
def __call__(self, alt: F) -> F:
120+
nb_name = self.name or alt.__name__
121+
122+
try:
123+
nb_func = getattr(nanops_numba, nb_name)
124+
except (AttributeError, NameError): # pragma: no cover
125+
nb_func = None
126+
127+
@functools.wraps(alt)
128+
def f(
129+
values: np.ndarray,
130+
*,
131+
axis: AxisInt | None = None,
132+
skipna: bool = True,
133+
**kwds,
134+
):
135+
disallowed = values.dtype == "O"
136+
if _USE_NUMBA and not disallowed:
137+
result = nb_func(values, skipna=skipna, axis=axis, **kwds)
138+
else:
139+
result = alt(values, axis=axis, skipna=skipna, **kwds)
140+
141+
return result
142+
143+
return cast(F, f)
144+
145+
100146
class bottleneck_switch:
101147
def __init__(self, name=None, **kwargs) -> None:
102148
self.name = name
@@ -593,6 +639,7 @@ def nanall(
593639
return values.all(axis) # type: ignore[return-value]
594640

595641

642+
@numba_switch()
596643
@disallow("M8")
597644
@_datetimelike_compat
598645
@maybe_operate_rowwise
@@ -660,7 +707,7 @@ def _mask_datetimelike_result(
660707
return result
661708

662709

663-
@bottleneck_switch()
710+
@numba_switch()
664711
@_datetimelike_compat
665712
def nanmean(
666713
values: np.ndarray,
@@ -910,7 +957,7 @@ def _get_counts_nanvar(
910957
return count, d
911958

912959

913-
@bottleneck_switch(ddof=1)
960+
@numba_switch(ddof=1)
914961
def nanstd(
915962
values,
916963
*,
@@ -957,7 +1004,7 @@ def nanstd(
9571004

9581005

9591006
@disallow("M8", "m8")
960-
@bottleneck_switch(ddof=1)
1007+
@numba_switch(ddof=1)
9611008
def nanvar(
9621009
values: np.ndarray,
9631010
*,
@@ -1035,6 +1082,7 @@ def nanvar(
10351082
return result
10361083

10371084

1085+
@numba_switch()
10381086
@disallow("M8", "m8")
10391087
def nansem(
10401088
values: np.ndarray,
@@ -1089,7 +1137,7 @@ def nansem(
10891137

10901138

10911139
def _nanminmax(meth, fill_value_typ):
1092-
@bottleneck_switch(name=f"nan{meth}")
1140+
@numba_switch(name=f"nan{meth}")
10931141
@_datetimelike_compat
10941142
def reduction(
10951143
values: np.ndarray,

0 commit comments

Comments
 (0)