Skip to content

Commit 780d4d8

Browse files
committed
TYP: Shape-typed generic.reshape, and added the missing copy kwarg
1 parent 44c2263 commit 780d4d8

File tree

2 files changed

+103
-16
lines changed

2 files changed

+103
-16
lines changed

numpy/__init__.pyi

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import builtins
22
import sys
3-
import os
43
import mmap
54
import ctypes as ct
65
import array as _array
@@ -210,7 +209,7 @@ from typing import (
210209
# library include `typing_extensions` stubs:
211210
# https://github.com/python/typeshed/blob/main/stdlib/typing_extensions.pyi
212211
from _typeshed import StrOrBytesPath, SupportsFlush, SupportsLenAndGetItem, SupportsWrite
213-
from typing_extensions import CapsuleType, Generic, LiteralString, Protocol, Self, TypeVar, Unpack, deprecated, overload
212+
from typing_extensions import CapsuleType, Generic, LiteralString, Never, Protocol, Self, TypeVar, Unpack, deprecated, overload
214213

215214
from numpy import (
216215
core,
@@ -1757,11 +1756,11 @@ _IntegralArrayT = TypeVar("_IntegralArrayT", bound=NDArray[integer[Any] | np.boo
17571756
_RealArrayT = TypeVar("_RealArrayT", bound=NDArray[floating[Any] | integer[Any] | timedelta64 | np.bool | object_])
17581757
_NumericArrayT = TypeVar("_NumericArrayT", bound=NDArray[number[Any] | timedelta64 | object_])
17591758

1760-
_Shape1D: TypeAlias = tuple[int]
17611759
_Shape2D: TypeAlias = tuple[int, int]
17621760

17631761
_ShapeType = TypeVar("_ShapeType", bound=_Shape)
17641762
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=_Shape)
1763+
_Shape1NType = TypeVar("_Shape1NType", bound=tuple[L[1], Unpack[tuple[L[1], ...]]]) # (1,) | (1, 1) | (1, 1, 1) | ...
17651764
_Shape2DType_co = TypeVar("_Shape2DType_co", covariant=True, bound=_Shape2D)
17661765
_NumberType = TypeVar("_NumberType", bound=number[Any])
17671766

@@ -3204,10 +3203,88 @@ class generic(_ArrayOrScalarCommon):
32043203
def flatten(self, order: _OrderKACF = ...) -> NDArray[Self]: ...
32053204
def ravel(self, order: _OrderKACF = ...) -> NDArray[Self]: ...
32063205

3207-
@overload
3208-
def reshape(self, shape: _ShapeLike, /, *, order: _OrderACF = ...) -> NDArray[Self]: ...
3209-
@overload
3210-
def reshape(self, *shape: SupportsIndex, order: _OrderACF = ...) -> NDArray[Self]: ...
3206+
@overload # (() | [])
3207+
def reshape(
3208+
self,
3209+
shape: tuple[()] | list[Never],
3210+
/,
3211+
*,
3212+
order: _OrderACF = "C",
3213+
copy: builtins.bool | None = None,
3214+
) -> Self: ...
3215+
@overload # ((1, *(1, ...))@_ShapeType)
3216+
def reshape(
3217+
self,
3218+
shape: _Shape1NType,
3219+
/,
3220+
*,
3221+
order: _OrderACF = "C",
3222+
copy: builtins.bool | None = None,
3223+
) -> ndarray[_Shape1NType, dtype[Self]]: ...
3224+
@overload # (Sequence[index, ...]) # not recommended
3225+
def reshape(
3226+
self,
3227+
shape: Sequence[SupportsIndex],
3228+
/,
3229+
*,
3230+
order: _OrderACF = "C",
3231+
copy: builtins.bool | None = None,
3232+
) -> Self | ndarray[tuple[L[1], ...], dtype[Self]]: ...
3233+
@overload # _(index)
3234+
def reshape(
3235+
self,
3236+
size1: SupportsIndex,
3237+
/,
3238+
*,
3239+
order: _OrderACF = "C",
3240+
copy: builtins.bool | None = None,
3241+
) -> ndarray[tuple[L[1]], dtype[Self]]: ...
3242+
@overload # _(index, index)
3243+
def reshape(
3244+
self,
3245+
size1: SupportsIndex,
3246+
size2: SupportsIndex,
3247+
/,
3248+
*,
3249+
order: _OrderACF = "C",
3250+
copy: builtins.bool | None = None,
3251+
) -> ndarray[tuple[L[1], L[1]], dtype[Self]]: ...
3252+
@overload # _(index, index, index)
3253+
def reshape(
3254+
self,
3255+
size1: SupportsIndex,
3256+
size2: SupportsIndex,
3257+
size3: SupportsIndex,
3258+
/,
3259+
*,
3260+
order: _OrderACF = "C",
3261+
copy: builtins.bool | None = None,
3262+
) -> ndarray[tuple[L[1], L[1], L[1]], dtype[Self]]: ...
3263+
@overload # _(index, index, index, index)
3264+
def reshape(
3265+
self,
3266+
size1: SupportsIndex,
3267+
size2: SupportsIndex,
3268+
size3: SupportsIndex,
3269+
size4: SupportsIndex,
3270+
/,
3271+
*,
3272+
order: _OrderACF = "C",
3273+
copy: builtins.bool | None = None,
3274+
) -> ndarray[tuple[L[1], L[1], L[1], L[1]], dtype[Self]]: ...
3275+
@overload # _(index, index, index, index, index, *index) # ndim >= 5
3276+
def reshape(
3277+
self,
3278+
size1: SupportsIndex,
3279+
size2: SupportsIndex,
3280+
size3: SupportsIndex,
3281+
size4: SupportsIndex,
3282+
size5: SupportsIndex,
3283+
/,
3284+
*sizes6_: SupportsIndex,
3285+
order: _OrderACF = "C",
3286+
copy: builtins.bool | None = None,
3287+
) -> ndarray[tuple[L[1], L[1], L[1], L[1], L[1], Unpack[tuple[L[1], ...]]], dtype[Self]]: ...
32113288

32123289
def squeeze(self, axis: None | L[0] | tuple[()] = ...) -> Self: ...
32133290
def transpose(self, axes: None | tuple[()] = ..., /) -> Self: ...

numpy/typing/tests/data/reveal/scalars.pyi

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Any, Literal
1+
from typing import Any, Literal, TypeAlias
2+
from typing_extensions import Unpack, assert_type
23

34
import numpy as np
45
import numpy.typing as npt
56

6-
from typing_extensions import assert_type
7+
_1: TypeAlias = Literal[1]
78

89
b: np.bool
910
u8: np.uint64
@@ -109,13 +110,22 @@ assert_type(c16.flatten(), npt.NDArray[np.complex128])
109110
assert_type(U.flatten(), npt.NDArray[np.str_])
110111
assert_type(S.flatten(), npt.NDArray[np.bytes_])
111112

112-
assert_type(b.reshape(1), npt.NDArray[np.bool])
113-
assert_type(i8.reshape(1), npt.NDArray[np.int64])
114-
assert_type(u8.reshape(1), npt.NDArray[np.uint64])
115-
assert_type(f8.reshape(1), npt.NDArray[np.float64])
116-
assert_type(c16.reshape(1), npt.NDArray[np.complex128])
117-
assert_type(U.reshape(1), npt.NDArray[np.str_])
118-
assert_type(S.reshape(1), npt.NDArray[np.bytes_])
113+
assert_type(b.reshape(()), np.bool)
114+
assert_type(i8.reshape([]), np.int64)
115+
assert_type(b.reshape(1), np.ndarray[tuple[_1], np.dtype[np.bool]])
116+
assert_type(i8.reshape(-1), np.ndarray[tuple[_1], np.dtype[np.int64]])
117+
assert_type(u8.reshape(1, 1), np.ndarray[tuple[_1, _1], np.dtype[np.uint64]])
118+
assert_type(f8.reshape(1, -1), np.ndarray[tuple[_1, _1], np.dtype[np.float64]])
119+
assert_type(c16.reshape(1, 1, 1), np.ndarray[tuple[_1, _1, _1], np.dtype[np.complex128]])
120+
assert_type(U.reshape(1, 1, 1, 1), np.ndarray[tuple[_1, _1, _1, _1], np.dtype[np.str_]])
121+
assert_type(
122+
S.reshape(1, 1, 1, 1, 1),
123+
np.ndarray[
124+
# len(shape) >= 5
125+
tuple[_1, _1, _1, _1, _1, Unpack[tuple[_1, ...]]],
126+
np.dtype[np.bytes_],
127+
],
128+
)
119129

120130
assert_type(i8.astype(float), Any)
121131
assert_type(i8.astype(np.float64), np.float64)

0 commit comments

Comments
 (0)