48
48
notna ,
49
49
)
50
50
51
+ from pandas .core import nanops_numba
52
+
51
53
if TYPE_CHECKING :
52
54
from collections .abc import Callable
53
55
@@ -66,6 +68,18 @@ def set_use_bottleneck(v: bool = True) -> None:
66
68
set_use_bottleneck (get_option ("compute.use_bottleneck" ))
67
69
68
70
71
+ _USE_NUMBA = True
72
+
73
+
74
+ def set_use_numba (v : bool = True ) -> None :
75
+ # set/unset to use bottleneck
76
+ global _USE_NUMBA
77
+ _USE_NUMBA = v
78
+
79
+
80
+ # set_use_numba(get_option("compute.use_numba"))
81
+
82
+
69
83
class disallow :
70
84
def __init__ (self , * dtypes : Dtype ) -> None :
71
85
super ().__init__ ()
@@ -97,6 +111,38 @@ def _f(*args, **kwargs):
97
111
return cast (F , _f )
98
112
99
113
114
+ class numba_switch :
115
+ def __init__ (self , name = None , ** kwargs ) -> None :
116
+ self .name = name
117
+ self .kwargs = kwargs
118
+
119
+ def __call__ (self , alt : F ) -> F :
120
+ nb_name = self .name or alt .__name__
121
+
122
+ try :
123
+ nb_func = getattr (nanops_numba , nb_name )
124
+ except (AttributeError , NameError ): # pragma: no cover
125
+ nb_func = None
126
+
127
+ @functools .wraps (alt )
128
+ def f (
129
+ values : np .ndarray ,
130
+ * ,
131
+ axis : AxisInt | None = None ,
132
+ skipna : bool = True ,
133
+ ** kwds ,
134
+ ):
135
+ disallowed = values .dtype == "O"
136
+ if _USE_NUMBA and not disallowed :
137
+ result = nb_func (values , skipna = skipna , axis = axis , ** kwds )
138
+ else :
139
+ result = alt (values , axis = axis , skipna = skipna , ** kwds )
140
+
141
+ return result
142
+
143
+ return cast (F , f )
144
+
145
+
100
146
class bottleneck_switch :
101
147
def __init__ (self , name = None , ** kwargs ) -> None :
102
148
self .name = name
@@ -593,6 +639,7 @@ def nanall(
593
639
return values .all (axis ) # type: ignore[return-value]
594
640
595
641
642
+ @numba_switch ()
596
643
@disallow ("M8" )
597
644
@_datetimelike_compat
598
645
@maybe_operate_rowwise
@@ -660,7 +707,7 @@ def _mask_datetimelike_result(
660
707
return result
661
708
662
709
663
- @bottleneck_switch ()
710
+ @numba_switch ()
664
711
@_datetimelike_compat
665
712
def nanmean (
666
713
values : np .ndarray ,
@@ -910,7 +957,7 @@ def _get_counts_nanvar(
910
957
return count , d
911
958
912
959
913
- @bottleneck_switch (ddof = 1 )
960
+ @numba_switch (ddof = 1 )
914
961
def nanstd (
915
962
values ,
916
963
* ,
@@ -957,7 +1004,7 @@ def nanstd(
957
1004
958
1005
959
1006
@disallow ("M8" , "m8" )
960
- @bottleneck_switch (ddof = 1 )
1007
+ @numba_switch (ddof = 1 )
961
1008
def nanvar (
962
1009
values : np .ndarray ,
963
1010
* ,
@@ -1035,6 +1082,7 @@ def nanvar(
1035
1082
return result
1036
1083
1037
1084
1085
+ @numba_switch ()
1038
1086
@disallow ("M8" , "m8" )
1039
1087
def nansem (
1040
1088
values : np .ndarray ,
@@ -1089,7 +1137,7 @@ def nansem(
1089
1137
1090
1138
1091
1139
def _nanminmax (meth , fill_value_typ ):
1092
- @bottleneck_switch (name = f"nan{ meth } " )
1140
+ @numba_switch (name = f"nan{ meth } " )
1093
1141
@_datetimelike_compat
1094
1142
def reduction (
1095
1143
values : np .ndarray ,
0 commit comments