Skip to content

Commit d14b8f6

Browse files
committed
TYP: add a few type annotations to numpy.array_api.Array
This fixes the majority of the complaints for `$ mypy numpy/array_api`. The comment indicating that one fix is blocked by lack of support in Mypy for `NotImplemented` is responsible for another several dozen errors. [skip ci] Original NumPy Commit: 5c04e06dacb714fb8381db003eb0f22c08c91417
1 parent c76538d commit d14b8f6

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

array_api_strict/_array_object.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class Array:
5555
functions, such as asarray().
5656
5757
"""
58+
_array: np.ndarray
5859

5960
# Use a custom constructor instead of __init__, as manually initializing
6061
# this class is not supported API.
@@ -124,6 +125,9 @@ def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
124125
# spec in places where it either deviates from or is more strict than
125126
# NumPy behavior
126127

128+
# NOTE: no valid type annotation possible. E.g `Union[Array,
129+
# NotImplemented]` is forbidden, see https://github.com/python/mypy/issues/363
130+
# Maybe change returned object to `Literal['NotImplemented']`?
127131
def _check_allowed_dtypes(self, other, dtype_category, op):
128132
"""
129133
Helper function for operators to only allow specific input dtypes
@@ -200,7 +204,7 @@ def _promote_scalar(self, scalar):
200204
return Array._new(np.array(scalar, self.dtype))
201205

202206
@staticmethod
203-
def _normalize_two_args(x1, x2):
207+
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
204208
"""
205209
Normalize inputs to two arg functions to fix type promotion rules
206210

0 commit comments

Comments
 (0)