10
10
_real_numeric_dtypes ,
11
11
_numeric_dtypes ,
12
12
_result_type ,
13
- _dtype_categories as _dtype_dtype_categories ,
13
+ _dtype_categories ,
14
14
)
15
15
from ._array_object import Array
16
16
from ._flags import requires_api_version
@@ -51,6 +51,22 @@ def inner(x1: Array, x2: Array, /) -> Array:
51
51
return inner
52
52
53
53
54
+
55
+ # static type annotation for ArrayOrPythonScalar arguments given a category
56
+ # NB: keep the keys in sync with the _dtype_categories dict
57
+ _annotations = {
58
+ "all" : "bool | int | float | complex | Array" ,
59
+ "real numeric" : "int | float | Array" ,
60
+ "numeric" : "int | float | complex | Array" ,
61
+ "integer" : "int | Array" ,
62
+ "integer or boolean" : "int | bool | Array" ,
63
+ "boolean" : "bool | Array" ,
64
+ "real floating-point" : "float | Array" ,
65
+ "complex floating-point" : "complex | Array" ,
66
+ "floating-point" : "float | complex | Array" ,
67
+ }
68
+
69
+
54
70
# func_name: dtype_category (must match that from _dtypes.py)
55
71
_binary_funcs = {
56
72
"add" : "numeric" ,
@@ -97,7 +113,7 @@ def inner(x1: Array, x2: Array, /) -> Array:
97
113
# create and attach functions to the module
98
114
for func_name , dtype_category in _binary_funcs .items ():
99
115
# sanity check
100
- assert dtype_category in _dtype_dtype_categories
116
+ assert dtype_category in _dtype_categories
101
117
102
118
numpy_name = _numpy_renames .get (func_name , func_name )
103
119
np_func = getattr (np , numpy_name )
@@ -106,6 +122,8 @@ def inner(x1: Array, x2: Array, /) -> Array:
106
122
func .__name__ = func_name
107
123
108
124
func .__doc__ = _binary_docstring_template % (numpy_name , numpy_name )
125
+ func .__annotations__ ['x1' ] = _annotations [dtype_category ]
126
+ func .__annotations__ ['x2' ] = _annotations [dtype_category ]
109
127
110
128
vars ()[func_name ] = func
111
129
@@ -117,15 +135,15 @@ def inner(x1: Array, x2: Array, /) -> Array:
117
135
nextafter = requires_api_version ('2024.12' )(nextafter ) # noqa: F821
118
136
119
137
120
- def bitwise_left_shift (x1 : Array , x2 : Array , / ) -> Array :
138
+ def bitwise_left_shift (x1 : int | Array , x2 : int | Array , / ) -> Array :
121
139
is_negative = np .any (x2 ._array < 0 ) if isinstance (x2 , Array ) else x2 < 0
122
140
if is_negative :
123
141
raise ValueError ("bitwise_left_shift(x1, x2) is only defined for x2 >= 0" )
124
142
return _bitwise_left_shift (x1 , x2 ) # noqa: F821
125
143
bitwise_left_shift .__doc__ = _bitwise_left_shift .__doc__ # noqa: F821
126
144
127
145
128
- def bitwise_right_shift (x1 : Array , x2 : Array , / ) -> Array :
146
+ def bitwise_right_shift (x1 : int | Array , x2 : int | Array , / ) -> Array :
129
147
is_negative = np .any (x2 ._array < 0 ) if isinstance (x2 , Array ) else x2 < 0
130
148
if is_negative :
131
149
raise ValueError ("bitwise_left_shift(x1, x2) is only defined for x2 >= 0" )
0 commit comments