Skip to content

Commit c5d52f3

Browse files
committed
TYP: Transparent numpy.shape shape-type annotations.
1 parent 89b6820 commit c5d52f3

File tree

2 files changed

+55
-8
lines changed

2 files changed

+55
-8
lines changed

numpy/_core/fromnumeric.pyi

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
from collections.abc import Sequence
2-
from typing import Any, NoReturn, overload, TypeVar, Literal, SupportsIndex
2+
from typing import (
3+
Any,
4+
Literal,
5+
NoReturn,
6+
Protocol,
7+
SupportsIndex,
8+
TypeAlias,
9+
TypeVar,
10+
overload,
11+
type_check_only,
12+
)
13+
from typing_extensions import Never
314

415
import numpy as np
516
from numpy import (
@@ -29,7 +40,6 @@ from numpy._typing import (
2940
_ArrayLike,
3041
NDArray,
3142
_ShapeLike,
32-
_Shape,
3343
_ArrayLikeBool_co,
3444
_ArrayLikeUInt_co,
3545
_ArrayLikeInt_co,
@@ -46,7 +56,21 @@ from numpy._typing import (
4656

4757
_SCT = TypeVar("_SCT", bound=generic)
4858
_SCT_uifcO = TypeVar("_SCT_uifcO", bound=number[Any] | object_)
49-
_ArrayType = TypeVar("_ArrayType", bound=NDArray[Any])
59+
_ArrayType = TypeVar("_ArrayType", bound=np.ndarray[Any, Any])
60+
_ShapeType = TypeVar("_ShapeType", bound=tuple[int, ...])
61+
_ShapeType_co = TypeVar("_ShapeType_co", bound=tuple[int, ...], covariant=True)
62+
63+
@type_check_only
64+
class _SupportsShape(Protocol[_ShapeType_co]):
65+
# NOTE: it matters that `self` is positional only
66+
@property
67+
def shape(self, /) -> _ShapeType_co: ...
68+
69+
# a "sequence" that isn't a string, bytes, bytearray, or memoryview
70+
_T = TypeVar("_T")
71+
_PyArray: TypeAlias = list[_T] | tuple[_T, ...]
72+
# `int` also covers `bool`
73+
_PyScalar: TypeAlias = int | float | complex | bytes | str
5074

5175
__all__: list[str]
5276

@@ -373,7 +397,24 @@ def nonzero(a: np.generic | np.ndarray[tuple[()], Any]) -> NoReturn: ...
373397
@overload
374398
def nonzero(a: _ArrayLike[Any]) -> tuple[NDArray[intp], ...]: ...
375399

376-
def shape(a: ArrayLike) -> _Shape: ...
400+
# this prevents `Any` from being returned with Pyright
401+
@overload
402+
def shape(a: _SupportsShape[Never]) -> tuple[int, ...]: ...
403+
@overload
404+
def shape(a: _SupportsShape[_ShapeType]) -> _ShapeType: ...
405+
@overload
406+
def shape(a: _PyScalar) -> tuple[()]: ...
407+
# `collections.abc.Sequence` can't be used hesre, since `bytes` and `str` are
408+
# subtypes of it, which would make the return types incompatible.
409+
@overload
410+
def shape(a: _PyArray[_PyScalar]) -> tuple[int]: ...
411+
@overload
412+
def shape(a: _PyArray[_PyArray[_PyScalar]]) -> tuple[int, int]: ...
413+
# this overload will be skipped by typecheckers that don't support PEP 688
414+
@overload
415+
def shape(a: memoryview | bytearray) -> tuple[int]: ...
416+
@overload
417+
def shape(a: ArrayLike) -> tuple[int, ...]: ...
377418

378419
@overload
379420
def compress(

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,17 @@ assert_type(np.nonzero(AR_0d), NoReturn)
139139
assert_type(np.nonzero(AR_1d), tuple[npt.NDArray[np.intp], ...])
140140
assert_type(np.nonzero(AR_nd), tuple[npt.NDArray[np.intp], ...])
141141

142-
assert_type(np.shape(b), tuple[int, ...])
143-
assert_type(np.shape(f4), tuple[int, ...])
144-
assert_type(np.shape(f), tuple[int, ...])
142+
assert_type(np.shape(b), tuple[()])
143+
assert_type(np.shape(f), tuple[()])
144+
assert_type(np.shape([1]), tuple[int])
145+
assert_type(np.shape([[2]]), tuple[int, int])
146+
assert_type(np.shape([[[3]]]), tuple[int, ...])
145147
assert_type(np.shape(AR_b), tuple[int, ...])
146-
assert_type(np.shape(AR_f4), tuple[int, ...])
148+
assert_type(np.shape(AR_nd), tuple[int, ...])
149+
# these fail on mypy, but it works as expected with pyright/pylance
150+
# assert_type(np.shape(AR_0d), tuple[()])
151+
# assert_type(np.shape(AR_1d), tuple[int])
152+
# assert_type(np.shape(AR_2d), tuple[int, int])
147153

148154
assert_type(np.compress([True], b), npt.NDArray[np.bool])
149155
assert_type(np.compress([True], f4), npt.NDArray[np.float32])

0 commit comments

Comments
 (0)