Skip to content

Commit 1a84695

Browse files
committed
add type annotations to binary functions
1 parent 5ddba5b commit 1a84695

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

array_api_strict/_elementwise_functions.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
_real_numeric_dtypes,
1111
_numeric_dtypes,
1212
_result_type,
13-
_dtype_categories as _dtype_dtype_categories,
13+
_dtype_categories,
1414
)
1515
from ._array_object import Array
1616
from ._flags import requires_api_version
@@ -51,6 +51,22 @@ def inner(x1: Array, x2: Array, /) -> Array:
5151
return inner
5252

5353

54+
55+
# static type annotation for ArrayOrPythonScalar arguments given a category
56+
# NB: keep the keys in sync with the _dtype_categories dict
57+
_annotations = {
58+
"all": "bool | int | float | complex | Array",
59+
"real numeric": "int | float | Array",
60+
"numeric": "int | float | complex | Array",
61+
"integer": "int | Array",
62+
"integer or boolean": "int | bool | Array",
63+
"boolean": "bool | Array",
64+
"real floating-point": "float | Array",
65+
"complex floating-point": "complex | Array",
66+
"floating-point": "float | complex | Array",
67+
}
68+
69+
5470
# func_name: dtype_category (must match that from _dtypes.py)
5571
_binary_funcs = {
5672
"add": "numeric",
@@ -97,7 +113,7 @@ def inner(x1: Array, x2: Array, /) -> Array:
97113
# create and attach functions to the module
98114
for func_name, dtype_category in _binary_funcs.items():
99115
# sanity check
100-
assert dtype_category in _dtype_dtype_categories
116+
assert dtype_category in _dtype_categories
101117

102118
numpy_name = _numpy_renames.get(func_name, func_name)
103119
np_func = getattr(np, numpy_name)
@@ -106,6 +122,8 @@ def inner(x1: Array, x2: Array, /) -> Array:
106122
func.__name__ = func_name
107123

108124
func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name)
125+
func.__annotations__['x1'] = _annotations[dtype_category]
126+
func.__annotations__['x2'] = _annotations[dtype_category]
109127

110128
vars()[func_name] = func
111129

@@ -117,15 +135,15 @@ def inner(x1: Array, x2: Array, /) -> Array:
117135
nextafter = requires_api_version('2024.12')(nextafter) # noqa: F821
118136

119137

120-
def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
138+
def bitwise_left_shift(x1: int | Array, x2: int | Array, /) -> Array:
121139
is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0
122140
if is_negative:
123141
raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
124142
return _bitwise_left_shift(x1, x2) # noqa: F821
125143
bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821
126144

127145

128-
def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
146+
def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array:
129147
is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0
130148
if is_negative:
131149
raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")

0 commit comments

Comments
 (0)