|
48 | 48 | notna,
|
49 | 49 | )
|
50 | 50 |
|
| 51 | +from pandas.core.util.numba_ import GLOBAL_USE_NUMBA |
| 52 | + |
| 53 | +if GLOBAL_USE_NUMBA: |
| 54 | + from pandas.core import nanops_numba |
| 55 | + |
| 56 | + |
51 | 57 | if TYPE_CHECKING:
|
52 | 58 | from collections.abc import Callable
|
53 | 59 |
|
@@ -97,6 +103,38 @@ def _f(*args, **kwargs):
|
97 | 103 | return cast(F, _f)
|
98 | 104 |
|
99 | 105 |
|
| 106 | +class numba_switch: |
| 107 | + def __init__(self, name=None, **kwargs) -> None: |
| 108 | + self.name = name |
| 109 | + self.kwargs = kwargs |
| 110 | + |
| 111 | + def __call__(self, alt: F) -> F: |
| 112 | + nb_name = self.name or alt.__name__ |
| 113 | + |
| 114 | + try: |
| 115 | + nb_func = getattr(nanops_numba, nb_name) |
| 116 | + except (AttributeError, NameError): # pragma: no cover |
| 117 | + return alt |
| 118 | + |
| 119 | + @functools.wraps(alt) |
| 120 | + def f( |
| 121 | + values: np.ndarray, |
| 122 | + *, |
| 123 | + axis: AxisInt | None = None, |
| 124 | + skipna: bool = True, |
| 125 | + **kwds, |
| 126 | + ): |
| 127 | + disallowed = values.dtype == "O" |
| 128 | + if not disallowed: |
| 129 | + result = nb_func(values, skipna=skipna, axis=axis, **kwds) |
| 130 | + else: |
| 131 | + result = alt(values, axis=axis, skipna=skipna, **kwds) |
| 132 | + |
| 133 | + return result |
| 134 | + |
| 135 | + return cast(F, f) |
| 136 | + |
| 137 | + |
100 | 138 | class bottleneck_switch:
|
101 | 139 | def __init__(self, name=None, **kwargs) -> None:
|
102 | 140 | self.name = name
|
@@ -593,6 +631,7 @@ def nanall(
|
593 | 631 | return values.all(axis) # type: ignore[return-value]
|
594 | 632 |
|
595 | 633 |
|
| 634 | +@numba_switch() |
596 | 635 | @disallow("M8")
|
597 | 636 | @_datetimelike_compat
|
598 | 637 | @maybe_operate_rowwise
|
@@ -658,7 +697,7 @@ def _mask_datetimelike_result(
|
658 | 697 | return result
|
659 | 698 |
|
660 | 699 |
|
661 |
| -@bottleneck_switch() |
| 700 | +@numba_switch() |
662 | 701 | @_datetimelike_compat
|
663 | 702 | def nanmean(
|
664 | 703 | values: np.ndarray,
|
@@ -908,7 +947,7 @@ def _get_counts_nanvar(
|
908 | 947 | return count, d
|
909 | 948 |
|
910 | 949 |
|
911 |
| -@bottleneck_switch(ddof=1) |
| 950 | +@numba_switch(ddof=1) |
912 | 951 | def nanstd(
|
913 | 952 | values,
|
914 | 953 | *,
|
@@ -955,7 +994,7 @@ def nanstd(
|
955 | 994 |
|
956 | 995 |
|
957 | 996 | @disallow("M8", "m8")
|
958 |
| -@bottleneck_switch(ddof=1) |
| 997 | +@numba_switch(ddof=1) |
959 | 998 | def nanvar(
|
960 | 999 | values: np.ndarray,
|
961 | 1000 | *,
|
@@ -1033,6 +1072,7 @@ def nanvar(
|
1033 | 1072 | return result
|
1034 | 1073 |
|
1035 | 1074 |
|
| 1075 | +@numba_switch() |
1036 | 1076 | @disallow("M8", "m8")
|
1037 | 1077 | def nansem(
|
1038 | 1078 | values: np.ndarray,
|
@@ -1087,7 +1127,7 @@ def nansem(
|
1087 | 1127 |
|
1088 | 1128 |
|
1089 | 1129 | def _nanminmax(meth, fill_value_typ):
|
1090 |
| - @bottleneck_switch(name=f"nan{meth}") |
| 1130 | + @numba_switch(name=f"nan{meth}") |
1091 | 1131 | @_datetimelike_compat
|
1092 | 1132 | def reduction(
|
1093 | 1133 | values: np.ndarray,
|
|
0 commit comments