1
1
from collections .abc import Sequence
2
- from typing import Any , NoReturn , overload , TypeVar , Literal , SupportsIndex
2
+ from typing import (
3
+ Any ,
4
+ Literal ,
5
+ NoReturn ,
6
+ Protocol ,
7
+ SupportsIndex ,
8
+ TypeAlias ,
9
+ TypeVar ,
10
+ overload ,
11
+ type_check_only ,
12
+ )
13
+ from typing_extensions import Never
3
14
4
15
import numpy as np
5
16
from numpy import (
@@ -29,7 +40,6 @@ from numpy._typing import (
29
40
_ArrayLike ,
30
41
NDArray ,
31
42
_ShapeLike ,
32
- _Shape ,
33
43
_ArrayLikeBool_co ,
34
44
_ArrayLikeUInt_co ,
35
45
_ArrayLikeInt_co ,
@@ -46,7 +56,21 @@ from numpy._typing import (
46
56
47
57
_SCT = TypeVar ("_SCT" , bound = generic )
48
58
_SCT_uifcO = TypeVar ("_SCT_uifcO" , bound = number [Any ] | object_ )
49
- _ArrayType = TypeVar ("_ArrayType" , bound = NDArray [Any ])
59
+ _ArrayType = TypeVar ("_ArrayType" , bound = np .ndarray [Any , Any ])
60
+ _ShapeType = TypeVar ("_ShapeType" , bound = tuple [int , ...])
61
+ _ShapeType_co = TypeVar ("_ShapeType_co" , bound = tuple [int , ...], covariant = True )
62
+
63
+ @type_check_only
64
+ class _SupportsShape (Protocol [_ShapeType_co ]):
65
+ # NOTE: it matters that `self` is positional only
66
+ @property
67
+ def shape (self , / ) -> _ShapeType_co : ...
68
+
69
+ # a "sequence" that isn't a string, bytes, bytearray, or memoryview
70
+ _T = TypeVar ("_T" )
71
+ _PyArray : TypeAlias = list [_T ] | tuple [_T , ...]
72
+ # `int` also covers `bool`
73
+ _PyScalar : TypeAlias = int | float | complex | bytes | str
50
74
51
75
__all__ : list [str ]
52
76
@@ -379,7 +403,24 @@ def nonzero(a: np.generic | np.ndarray[tuple[()], Any]) -> NoReturn: ...
379
403
@overload
380
404
def nonzero (a : _ArrayLike [Any ]) -> tuple [NDArray [intp ], ...]: ...
381
405
382
- def shape (a : ArrayLike ) -> _Shape : ...
406
+ # this prevents `Any` from being returned with Pyright
407
+ @overload
408
+ def shape (a : _SupportsShape [Never ]) -> tuple [int , ...]: ...
409
+ @overload
410
+ def shape (a : _SupportsShape [_ShapeType ]) -> _ShapeType : ...
411
+ @overload
412
+ def shape (a : _PyScalar ) -> tuple [()]: ...
413
+ # `collections.abc.Sequence` can't be used hesre, since `bytes` and `str` are
414
+ # subtypes of it, which would make the return types incompatible.
415
+ @overload
416
+ def shape (a : _PyArray [_PyScalar ]) -> tuple [int ]: ...
417
+ @overload
418
+ def shape (a : _PyArray [_PyArray [_PyScalar ]]) -> tuple [int , int ]: ...
419
+ # this overload will be skipped by typecheckers that don't support PEP 688
420
+ @overload
421
+ def shape (a : memoryview | bytearray ) -> tuple [int ]: ...
422
+ @overload
423
+ def shape (a : ArrayLike ) -> tuple [int , ...]: ...
383
424
384
425
@overload
385
426
def compress (
0 commit comments