Skip to content

Commit ad82a25

Browse files
authored
Merge pull request numpy#27767 from jorenham/typing/reshape-shape-typing
TYP: Support shape-typing in ``reshape`` and ``resize``
2 parents 8eb0415 + d82e440 commit ad82a25

File tree

6 files changed

+291
-63
lines changed

6 files changed

+291
-63
lines changed

numpy/__init__.pyi

Lines changed: 172 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import builtins
22
import sys
3-
import os
43
import mmap
54
import ctypes as ct
65
import array as _array
@@ -208,7 +207,7 @@ from typing import (
208207
# library include `typing_extensions` stubs:
209208
# https://github.com/python/typeshed/blob/main/stdlib/typing_extensions.pyi
210209
from _typeshed import StrOrBytesPath, SupportsFlush, SupportsLenAndGetItem, SupportsWrite
211-
from typing_extensions import CapsuleType, Generic, LiteralString, Protocol, Self, TypeVar, Unpack, deprecated, overload
210+
from typing_extensions import CapsuleType, Generic, LiteralString, Never, Protocol, Self, TypeVar, Unpack, deprecated, overload
212211

213212
from numpy import (
214213
core,
@@ -1755,12 +1754,25 @@ _IntegralArrayT = TypeVar("_IntegralArrayT", bound=NDArray[integer[Any] | np.boo
17551754
_RealArrayT = TypeVar("_RealArrayT", bound=NDArray[floating[Any] | integer[Any] | timedelta64 | np.bool | object_])
17561755
_NumericArrayT = TypeVar("_NumericArrayT", bound=NDArray[number[Any] | timedelta64 | object_])
17571756

1758-
_Shape1D: TypeAlias = tuple[int]
1759-
_Shape2D: TypeAlias = tuple[int, int]
1760-
1757+
_AnyShapeType = TypeVar(
1758+
"_AnyShapeType",
1759+
tuple[()], # 0-d
1760+
tuple[int], # 1-d
1761+
tuple[int, int], # 2-d
1762+
tuple[int, int, int], # 3-d
1763+
tuple[int, int, int, int], # 4-d
1764+
tuple[int, int, int, int, int], # 5-d
1765+
tuple[int, int, int, int, int, int], # 6-d
1766+
tuple[int, int, int, int, int, int, int], # 7-d
1767+
tuple[int, int, int, int, int, int, int, int], # 8-d
1768+
tuple[int, ...], # N-d
1769+
)
17611770
_ShapeType = TypeVar("_ShapeType", bound=_Shape)
17621771
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=_Shape)
1772+
_Shape2D: TypeAlias = tuple[int, int]
17631773
_Shape2DType_co = TypeVar("_Shape2DType_co", covariant=True, bound=_Shape2D)
1774+
_Shape1NType = TypeVar("_Shape1NType", bound=tuple[L[1], Unpack[tuple[L[1], ...]]]) # (1,) | (1, 1) | (1, 1, 1) | ...
1775+
17641776
_NumberType = TypeVar("_NumberType", bound=number[Any])
17651777

17661778

@@ -2194,21 +2206,86 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
21942206
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DType_co]: ...
21952207
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DType_co]: ...
21962208

2197-
@overload
2209+
# NOTE: reshape also accepts negative integers, so we can't use integer literals
2210+
@overload # (None)
2211+
def reshape(self, shape: None, /, *, order: _OrderACF = "C", copy: builtins.bool | None = None) -> Self: ...
2212+
@overload # (empty_sequence)
2213+
def reshape( # type: ignore[overload-overlap] # mypy false positive
2214+
self,
2215+
shape: Sequence[Never],
2216+
/,
2217+
*,
2218+
order: _OrderACF = "C",
2219+
copy: builtins.bool | None = None,
2220+
) -> ndarray[tuple[()], _DType_co]: ...
2221+
@overload # (() | (int) | (int, int) | ....) # up to 8-d
21982222
def reshape(
21992223
self,
2200-
shape: _ShapeLike,
2224+
shape: _AnyShapeType,
22012225
/,
22022226
*,
2203-
order: _OrderACF = ...,
2204-
copy: None | builtins.bool = ...,
2205-
) -> ndarray[_Shape, _DType_co]: ...
2206-
@overload
2227+
order: _OrderACF = "C",
2228+
copy: builtins.bool | None = None,
2229+
) -> ndarray[_AnyShapeType, _DType_co]: ...
2230+
@overload # (index)
2231+
def reshape(
2232+
self,
2233+
size1: SupportsIndex,
2234+
/,
2235+
*,
2236+
order: _OrderACF = "C",
2237+
copy: builtins.bool | None = None,
2238+
) -> ndarray[tuple[int], _DType_co]: ...
2239+
@overload # (index, index)
2240+
def reshape(
2241+
self,
2242+
size1: SupportsIndex,
2243+
size2: SupportsIndex,
2244+
/,
2245+
*,
2246+
order: _OrderACF = "C",
2247+
copy: builtins.bool | None = None,
2248+
) -> ndarray[tuple[int, int], _DType_co]: ...
2249+
@overload # (index, index, index)
22072250
def reshape(
22082251
self,
2252+
size1: SupportsIndex,
2253+
size2: SupportsIndex,
2254+
size3: SupportsIndex,
2255+
/,
2256+
*,
2257+
order: _OrderACF = "C",
2258+
copy: builtins.bool | None = None,
2259+
) -> ndarray[tuple[int, int, int], _DType_co]: ...
2260+
@overload # (index, index, index, index)
2261+
def reshape(
2262+
self,
2263+
size1: SupportsIndex,
2264+
size2: SupportsIndex,
2265+
size3: SupportsIndex,
2266+
size4: SupportsIndex,
2267+
/,
2268+
*,
2269+
order: _OrderACF = "C",
2270+
copy: builtins.bool | None = None,
2271+
) -> ndarray[tuple[int, int, int, int], _DType_co]: ...
2272+
@overload # (int, *(index, ...))
2273+
def reshape(
2274+
self,
2275+
size0: SupportsIndex,
2276+
/,
22092277
*shape: SupportsIndex,
2210-
order: _OrderACF = ...,
2211-
copy: None | builtins.bool = ...,
2278+
order: _OrderACF = "C",
2279+
copy: builtins.bool | None = None,
2280+
) -> ndarray[_Shape, _DType_co]: ...
2281+
@overload # (sequence[index])
2282+
def reshape(
2283+
self,
2284+
shape: Sequence[SupportsIndex],
2285+
/,
2286+
*,
2287+
order: _OrderACF = "C",
2288+
copy: builtins.bool | None = None,
22122289
) -> ndarray[_Shape, _DType_co]: ...
22132290

22142291
@overload
@@ -3110,10 +3187,88 @@ class generic(_ArrayOrScalarCommon):
31103187
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], dtype[Self]]: ...
31113188
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], dtype[Self]]: ...
31123189

3113-
@overload
3114-
def reshape(self, shape: _ShapeLike, /, *, order: _OrderACF = ...) -> NDArray[Self]: ...
3115-
@overload
3116-
def reshape(self, *shape: SupportsIndex, order: _OrderACF = ...) -> NDArray[Self]: ...
3190+
@overload # (() | [])
3191+
def reshape(
3192+
self,
3193+
shape: tuple[()] | list[Never],
3194+
/,
3195+
*,
3196+
order: _OrderACF = "C",
3197+
copy: builtins.bool | None = None,
3198+
) -> Self: ...
3199+
@overload # ((1, *(1, ...))@_ShapeType)
3200+
def reshape(
3201+
self,
3202+
shape: _Shape1NType,
3203+
/,
3204+
*,
3205+
order: _OrderACF = "C",
3206+
copy: builtins.bool | None = None,
3207+
) -> ndarray[_Shape1NType, dtype[Self]]: ...
3208+
@overload # (Sequence[index, ...]) # not recommended
3209+
def reshape(
3210+
self,
3211+
shape: Sequence[SupportsIndex],
3212+
/,
3213+
*,
3214+
order: _OrderACF = "C",
3215+
copy: builtins.bool | None = None,
3216+
) -> Self | ndarray[tuple[L[1], ...], dtype[Self]]: ...
3217+
@overload # _(index)
3218+
def reshape(
3219+
self,
3220+
size1: SupportsIndex,
3221+
/,
3222+
*,
3223+
order: _OrderACF = "C",
3224+
copy: builtins.bool | None = None,
3225+
) -> ndarray[tuple[L[1]], dtype[Self]]: ...
3226+
@overload # _(index, index)
3227+
def reshape(
3228+
self,
3229+
size1: SupportsIndex,
3230+
size2: SupportsIndex,
3231+
/,
3232+
*,
3233+
order: _OrderACF = "C",
3234+
copy: builtins.bool | None = None,
3235+
) -> ndarray[tuple[L[1], L[1]], dtype[Self]]: ...
3236+
@overload # _(index, index, index)
3237+
def reshape(
3238+
self,
3239+
size1: SupportsIndex,
3240+
size2: SupportsIndex,
3241+
size3: SupportsIndex,
3242+
/,
3243+
*,
3244+
order: _OrderACF = "C",
3245+
copy: builtins.bool | None = None,
3246+
) -> ndarray[tuple[L[1], L[1], L[1]], dtype[Self]]: ...
3247+
@overload # _(index, index, index, index)
3248+
def reshape(
3249+
self,
3250+
size1: SupportsIndex,
3251+
size2: SupportsIndex,
3252+
size3: SupportsIndex,
3253+
size4: SupportsIndex,
3254+
/,
3255+
*,
3256+
order: _OrderACF = "C",
3257+
copy: builtins.bool | None = None,
3258+
) -> ndarray[tuple[L[1], L[1], L[1], L[1]], dtype[Self]]: ...
3259+
@overload # _(index, index, index, index, index, *index) # ndim >= 5
3260+
def reshape(
3261+
self,
3262+
size1: SupportsIndex,
3263+
size2: SupportsIndex,
3264+
size3: SupportsIndex,
3265+
size4: SupportsIndex,
3266+
size5: SupportsIndex,
3267+
/,
3268+
*sizes6_: SupportsIndex,
3269+
order: _OrderACF = "C",
3270+
copy: builtins.bool | None = None,
3271+
) -> ndarray[tuple[L[1], L[1], L[1], L[1], L[1], Unpack[tuple[L[1], ...]]], dtype[Self]]: ...
31173272

31183273
def squeeze(self, axis: None | L[0] | tuple[()] = ...) -> Self: ...
31193274
def transpose(self, axes: None | tuple[()] = ..., /) -> Self: ...

numpy/_core/fromnumeric.pyi

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ from typing import (
1010
overload,
1111
type_check_only,
1212
)
13-
from typing_extensions import Never
13+
from typing_extensions import Never, deprecated
1414

1515
import numpy as np
1616
from numpy import (
@@ -25,6 +25,7 @@ from numpy import (
2525
timedelta64,
2626
object_,
2727
generic,
28+
_AnyShapeType,
2829
_OrderKACF,
2930
_OrderACF,
3031
_ModeKind,
@@ -105,6 +106,7 @@ __all__ = [
105106
_SCT = TypeVar("_SCT", bound=generic)
106107
_SCT_uifcO = TypeVar("_SCT_uifcO", bound=number[Any] | object_)
107108
_ArrayType = TypeVar("_ArrayType", bound=np.ndarray[Any, Any])
109+
_SizeType = TypeVar("_SizeType", bound=int)
108110
_ShapeType = TypeVar("_ShapeType", bound=tuple[int, ...])
109111
_ShapeType_co = TypeVar("_ShapeType_co", bound=tuple[int, ...], covariant=True)
110112

@@ -162,24 +164,73 @@ def take(
162164
) -> _ArrayType: ...
163165

164166
@overload
167+
def reshape( # shape: index
168+
a: _ArrayLike[_SCT],
169+
/,
170+
shape: SupportsIndex,
171+
order: _OrderACF = "C",
172+
*,
173+
copy: bool | None = None,
174+
) -> np.ndarray[tuple[int], np.dtype[_SCT]]: ...
175+
@overload
176+
def reshape( # shape: (int, ...) @ _AnyShapeType
177+
a: _ArrayLike[_SCT],
178+
/,
179+
shape: _AnyShapeType,
180+
order: _OrderACF = "C",
181+
*,
182+
copy: bool | None = None,
183+
) -> np.ndarray[_AnyShapeType, np.dtype[_SCT]]: ...
184+
@overload # shape: Sequence[index]
165185
def reshape(
166186
a: _ArrayLike[_SCT],
167187
/,
168-
shape: _ShapeLike = ...,
169-
order: _OrderACF = ...,
188+
shape: Sequence[SupportsIndex],
189+
order: _OrderACF = "C",
170190
*,
171-
newshape: _ShapeLike = ...,
172-
copy: None | bool = ...,
191+
copy: bool | None = None,
173192
) -> NDArray[_SCT]: ...
193+
@overload # shape: index
194+
def reshape(
195+
a: ArrayLike,
196+
/,
197+
shape: SupportsIndex,
198+
order: _OrderACF = "C",
199+
*,
200+
copy: bool | None = None,
201+
) -> np.ndarray[tuple[int], np.dtype[Any]]: ...
174202
@overload
203+
def reshape( # shape: (int, ...) @ _AnyShapeType
204+
a: ArrayLike,
205+
/,
206+
shape: _AnyShapeType,
207+
order: _OrderACF = "C",
208+
*,
209+
copy: bool | None = None,
210+
) -> np.ndarray[_AnyShapeType, np.dtype[Any]]: ...
211+
@overload # shape: Sequence[index]
175212
def reshape(
176213
a: ArrayLike,
177214
/,
178-
shape: _ShapeLike = ...,
179-
order: _OrderACF = ...,
215+
shape: Sequence[SupportsIndex],
216+
order: _OrderACF = "C",
180217
*,
181-
newshape: _ShapeLike = ...,
182-
copy: None | bool = ...,
218+
copy: bool | None = None,
219+
) -> NDArray[Any]: ...
220+
@overload
221+
@deprecated(
222+
"`newshape` keyword argument is deprecated, "
223+
"use `shape=...` or pass shape positionally instead. "
224+
"(deprecated in NumPy 2.1)",
225+
)
226+
def reshape(
227+
a: ArrayLike,
228+
/,
229+
shape: None = None,
230+
order: _OrderACF = "C",
231+
*,
232+
newshape: _ShapeLike,
233+
copy: bool | None = None,
183234
) -> NDArray[Any]: ...
184235

185236
@overload
@@ -378,16 +429,23 @@ def searchsorted(
378429
sorter: None | _ArrayLikeInt_co = ..., # 1D int array
379430
) -> NDArray[intp]: ...
380431

432+
# unlike `reshape`, `resize` only accepts positive integers, so literal ints can be used
381433
@overload
382-
def resize(
383-
a: _ArrayLike[_SCT],
384-
new_shape: _ShapeLike,
385-
) -> NDArray[_SCT]: ...
434+
def resize(a: _ArrayLike[_SCT], new_shape: _SizeType) -> np.ndarray[tuple[_SizeType], np.dtype[_SCT]]: ...
386435
@overload
387-
def resize(
388-
a: ArrayLike,
389-
new_shape: _ShapeLike,
390-
) -> NDArray[Any]: ...
436+
def resize(a: _ArrayLike[_SCT], new_shape: SupportsIndex) -> np.ndarray[tuple[int], np.dtype[_SCT]]: ...
437+
@overload
438+
def resize(a: _ArrayLike[_SCT], new_shape: _ShapeType) -> np.ndarray[_ShapeType, np.dtype[_SCT]]: ...
439+
@overload
440+
def resize(a: _ArrayLike[_SCT], new_shape: Sequence[SupportsIndex]) -> NDArray[_SCT]: ...
441+
@overload
442+
def resize(a: ArrayLike, new_shape: _SizeType) -> np.ndarray[tuple[_SizeType], np.dtype[Any]]: ...
443+
@overload
444+
def resize(a: ArrayLike, new_shape: SupportsIndex) -> np.ndarray[tuple[int], np.dtype[Any]]: ...
445+
@overload
446+
def resize(a: ArrayLike, new_shape: _ShapeType) -> np.ndarray[_ShapeType, np.dtype[Any]]: ...
447+
@overload
448+
def resize(a: ArrayLike, new_shape: Sequence[SupportsIndex]) -> NDArray[Any]: ...
391449

392450
@overload
393451
def squeeze(

0 commit comments

Comments
 (0)