Skip to content

Commit 2d02a3d

Browse files
committed
TYP: Add method annotations in ndarray
1 parent 7bfdc8c commit 2d02a3d

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

numpy/__init__.pyi

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ from typing import (
210210
# library include `typing_extensions` stubs:
211211
# https://github.com/python/typeshed/blob/main/stdlib/typing_extensions.pyi
212212
from _typeshed import StrOrBytesPath, SupportsFlush, SupportsLenAndGetItem, SupportsWrite
213-
from typing_extensions import CapsuleType, Generic, LiteralString, Protocol, Self, TypeVar, deprecated, overload
213+
from typing_extensions import CapsuleType, Generic, LiteralString, Protocol, Self, TypeVar, Unpack, deprecated, overload
214214

215215
from numpy import (
216216
core,
@@ -1792,6 +1792,8 @@ _ArrayComplex_co: TypeAlias = NDArray[np.bool | integer[Any] | floating[Any] | c
17921792
_ArrayNumber_co: TypeAlias = NDArray[np.bool | number[Any]]
17931793
_ArrayTD64_co: TypeAlias = NDArray[np.bool | integer[Any] | timedelta64]
17941794

1795+
_ArrayIndexLike: TypeAlias = SupportsIndex | slice | EllipsisType | _ArrayLikeInt_co | None
1796+
17951797
# Introduce an alias for `dtype` to avoid naming conflicts.
17961798
_dtype: TypeAlias = dtype[_ScalarType]
17971799

@@ -1906,26 +1908,20 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
19061908
) -> ndarray[_ShapeType, _DType]: ...
19071909

19081910
@overload
1909-
def __getitem__(self, key: (
1910-
NDArray[integer[Any]]
1911-
| NDArray[np.bool]
1912-
| tuple[NDArray[integer[Any]] | NDArray[np.bool], ...]
1913-
)) -> ndarray[_Shape, _DType_co]: ...
1911+
def __getitem__(self, key: _ArrayInt_co | tuple[_ArrayInt_co, ...], /) -> ndarray[_Shape, _DType_co]: ...
1912+
@overload
1913+
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> Any: ...
19141914
@overload
1915-
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...]) -> Any: ...
1915+
def __getitem__(self, key: _ArrayIndexLike | tuple[_ArrayIndexLike, ...], /) -> ndarray[_Shape, _DType_co]: ...
19161916
@overload
1917-
def __getitem__(self, key: (
1918-
None
1919-
| slice
1920-
| EllipsisType
1921-
| SupportsIndex
1922-
| _ArrayLikeInt_co
1923-
| tuple[None | slice | EllipsisType | _ArrayLikeInt_co | SupportsIndex, ...]
1924-
)) -> ndarray[_Shape, _DType_co]: ...
1917+
def __getitem__(self: NDArray[void], key: str, /) -> ndarray[_ShapeType_co, np.dtype[Any]]: ...
1918+
@overload
1919+
def __getitem__(self: NDArray[void], key: list[str], /) -> ndarray[_ShapeType_co, _dtype[void]]: ...
1920+
19251921
@overload
1926-
def __getitem__(self: NDArray[void], key: str) -> NDArray[Any]: ...
1922+
def __setitem__(self: NDArray[void], key: str | list[str], value: ArrayLike, /) -> None: ...
19271923
@overload
1928-
def __getitem__(self: NDArray[void], key: list[str]) -> ndarray[_ShapeType_co, _dtype[void]]: ...
1924+
def __setitem__(self, key: _ArrayIndexLike | tuple[_ArrayIndexLike, ...], value: ArrayLike, /) -> None: ...
19291925

19301926
@property
19311927
def ctypes(self) -> _ctypes[int]: ...
@@ -2272,9 +2268,16 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
22722268
def __complex__(self: NDArray[number[Any] | np.bool | object_], /) -> complex: ...
22732269

22742270
def __len__(self) -> int: ...
2275-
def __setitem__(self, key, value): ...
2276-
def __iter__(self) -> Any: ...
2277-
def __contains__(self, key) -> builtins.bool: ...
2271+
def __contains__(self, value: object, /) -> builtins.bool: ...
2272+
2273+
@overload # == 1-d & object_
2274+
def __iter__(self: ndarray[tuple[int], dtype[object_]], /) -> Iterator[Any]: ...
2275+
@overload # == 1-d
2276+
def __iter__(self: ndarray[tuple[int], dtype[_SCT]], /) -> Iterator[_SCT]: ...
2277+
@overload # >= 2-d
2278+
def __iter__(self: ndarray[tuple[int, int, Unpack[tuple[int, ...]]], dtype[_SCT]], /) -> Iterator[NDArray[_SCT]]: ...
2279+
@overload # ?-d
2280+
def __iter__(self, /) -> Iterator[Any]: ...
22782281

22792282
# The last overload is for catching recursive objects whose
22802283
# nesting is too deep.

0 commit comments

Comments
 (0)