|
48 | 48 | notna,
|
49 | 49 | )
|
50 | 50 |
|
| 51 | + |
| 52 | +from pandas.core.util.numba_ import GLOBAL_USE_NUMBA |
| 53 | +from pandas.core import nanops_numba |
| 54 | + |
51 | 55 | if TYPE_CHECKING:
|
52 | 56 | from collections.abc import Callable
|
53 | 57 |
|
@@ -97,6 +101,38 @@ def _f(*args, **kwargs):
|
97 | 101 | return cast(F, _f)
|
98 | 102 |
|
99 | 103 |
|
| 104 | +class numba_switch: |
| 105 | + def __init__(self, name=None, **kwargs) -> None: |
| 106 | + self.name = name |
| 107 | + self.kwargs = kwargs |
| 108 | + |
| 109 | + def __call__(self, alt: F) -> F: |
| 110 | + nb_name = self.name or alt.__name__ |
| 111 | + |
| 112 | + try: |
| 113 | + nb_func = getattr(nanops_numba, nb_name) |
| 114 | + except (AttributeError, NameError): # pragma: no cover |
| 115 | + nb_func = None |
| 116 | + |
| 117 | + @functools.wraps(alt) |
| 118 | + def f( |
| 119 | + values: np.ndarray, |
| 120 | + *, |
| 121 | + axis: AxisInt | None = None, |
| 122 | + skipna: bool = True, |
| 123 | + **kwds, |
| 124 | + ): |
| 125 | + disallowed = values.dtype == "O" |
| 126 | + if GLOBAL_USE_NUMBA and not disallowed: |
| 127 | + result = nb_func(values, skipna=skipna, axis=axis, **kwds) |
| 128 | + else: |
| 129 | + result = alt(values, axis=axis, skipna=skipna, **kwds) |
| 130 | + |
| 131 | + return result |
| 132 | + |
| 133 | + return cast(F, f) |
| 134 | + |
| 135 | + |
100 | 136 | class bottleneck_switch:
|
101 | 137 | def __init__(self, name=None, **kwargs) -> None:
|
102 | 138 | self.name = name
|
@@ -593,6 +629,7 @@ def nanall(
|
593 | 629 | return values.all(axis) # type: ignore[return-value]
|
594 | 630 |
|
595 | 631 |
|
| 632 | +@numba_switch() |
596 | 633 | @disallow("M8")
|
597 | 634 | @_datetimelike_compat
|
598 | 635 | @maybe_operate_rowwise
|
@@ -660,7 +697,7 @@ def _mask_datetimelike_result(
|
660 | 697 | return result
|
661 | 698 |
|
662 | 699 |
|
663 |
| -@bottleneck_switch() |
| 700 | +@numba_switch() |
664 | 701 | @_datetimelike_compat
|
665 | 702 | def nanmean(
|
666 | 703 | values: np.ndarray,
|
@@ -910,7 +947,7 @@ def _get_counts_nanvar(
|
910 | 947 | return count, d
|
911 | 948 |
|
912 | 949 |
|
913 |
| -@bottleneck_switch(ddof=1) |
| 950 | +@numba_switch(ddof=1) |
914 | 951 | def nanstd(
|
915 | 952 | values,
|
916 | 953 | *,
|
@@ -957,7 +994,7 @@ def nanstd(
|
957 | 994 |
|
958 | 995 |
|
959 | 996 | @disallow("M8", "m8")
|
960 |
| -@bottleneck_switch(ddof=1) |
| 997 | +@numba_switch(ddof=1) |
961 | 998 | def nanvar(
|
962 | 999 | values: np.ndarray,
|
963 | 1000 | *,
|
@@ -1035,6 +1072,7 @@ def nanvar(
|
1035 | 1072 | return result
|
1036 | 1073 |
|
1037 | 1074 |
|
| 1075 | +@numba_switch() |
1038 | 1076 | @disallow("M8", "m8")
|
1039 | 1077 | def nansem(
|
1040 | 1078 | values: np.ndarray,
|
@@ -1089,7 +1127,7 @@ def nansem(
|
1089 | 1127 |
|
1090 | 1128 |
|
1091 | 1129 | def _nanminmax(meth, fill_value_typ):
|
1092 |
| - @bottleneck_switch(name=f"nan{meth}") |
| 1130 | + @numba_switch(name=f"nan{meth}") |
1093 | 1131 | @_datetimelike_compat
|
1094 | 1132 | def reduction(
|
1095 | 1133 | values: np.ndarray,
|
|
0 commit comments