Skip to content

Commit cb6f5fe

Browse files
committed
TYP: Shape-typed ndarray.reshape method
1 parent 780d4d8 commit cb6f5fe

File tree

3 files changed

+103
-20
lines changed

3 files changed

+103
-20
lines changed

numpy/__init__.pyi

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,12 +1756,25 @@ _IntegralArrayT = TypeVar("_IntegralArrayT", bound=NDArray[integer[Any] | np.boo
17561756
_RealArrayT = TypeVar("_RealArrayT", bound=NDArray[floating[Any] | integer[Any] | timedelta64 | np.bool | object_])
17571757
_NumericArrayT = TypeVar("_NumericArrayT", bound=NDArray[number[Any] | timedelta64 | object_])
17581758

1759-
_Shape2D: TypeAlias = tuple[int, int]
1760-
1759+
_AnyShapeType = TypeVar(
1760+
"_AnyShapeType",
1761+
tuple[()], # 0-d
1762+
tuple[int], # 1-d
1763+
tuple[int, int], # 2-d
1764+
tuple[int, int, int], # 3-d
1765+
tuple[int, int, int, int], # 4-d
1766+
tuple[int, int, int, int, int], # 5-d
1767+
tuple[int, int, int, int, int, int], # 6-d
1768+
tuple[int, int, int, int, int, int, int], # 7-d
1769+
tuple[int, int, int, int, int, int, int, int], # 8-d
1770+
tuple[int, ...], # N-d
1771+
)
17611772
_ShapeType = TypeVar("_ShapeType", bound=_Shape)
17621773
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=_Shape)
1763-
_Shape1NType = TypeVar("_Shape1NType", bound=tuple[L[1], Unpack[tuple[L[1], ...]]]) # (1,) | (1, 1) | (1, 1, 1) | ...
1774+
_Shape2D: TypeAlias = tuple[int, int]
17641775
_Shape2DType_co = TypeVar("_Shape2DType_co", covariant=True, bound=_Shape2D)
1776+
_Shape1NType = TypeVar("_Shape1NType", bound=tuple[L[1], Unpack[tuple[L[1], ...]]]) # (1,) | (1, 1) | (1, 1, 1) | ...
1777+
17651778
_NumberType = TypeVar("_NumberType", bound=number[Any])
17661779

17671780

@@ -2204,21 +2217,86 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
22042217
order: _OrderKACF = ...,
22052218
) -> ndarray[_Shape, _DType_co]: ...
22062219

2207-
@overload
2220+
# NOTE: reshape also accepts negative integers, so we can't use integer literals
2221+
@overload # (None)
2222+
def reshape(self, shape: None, /, *, order: _OrderACF = "C", copy: builtins.bool | None = None) -> Self: ...
2223+
@overload # (empty_sequence)
2224+
def reshape( # type: ignore[overload-overlap] # mypy false positive
2225+
self,
2226+
shape: Sequence[Never],
2227+
/,
2228+
*,
2229+
order: _OrderACF = "C",
2230+
copy: builtins.bool | None = None,
2231+
) -> ndarray[tuple[()], _DType_co]: ...
2232+
@overload # (() | (int) | (int, int) | ....) # up to 8-d
22082233
def reshape(
22092234
self,
2210-
shape: _ShapeLike,
2235+
shape: _AnyShapeType,
22112236
/,
22122237
*,
2213-
order: _OrderACF = ...,
2214-
copy: None | builtins.bool = ...,
2215-
) -> ndarray[_Shape, _DType_co]: ...
2216-
@overload
2238+
order: _OrderACF = "C",
2239+
copy: builtins.bool | None = None,
2240+
) -> ndarray[_AnyShapeType, _DType_co]: ...
2241+
@overload # (index)
2242+
def reshape(
2243+
self,
2244+
size1: SupportsIndex,
2245+
/,
2246+
*,
2247+
order: _OrderACF = "C",
2248+
copy: builtins.bool | None = None,
2249+
) -> ndarray[tuple[int], _DType_co]: ...
2250+
@overload # (index, index)
22172251
def reshape(
22182252
self,
2253+
size1: SupportsIndex,
2254+
size2: SupportsIndex,
2255+
/,
2256+
*,
2257+
order: _OrderACF = "C",
2258+
copy: builtins.bool | None = None,
2259+
) -> ndarray[tuple[int, int], _DType_co]: ...
2260+
@overload # (index, index, index)
2261+
def reshape(
2262+
self,
2263+
size1: SupportsIndex,
2264+
size2: SupportsIndex,
2265+
size3: SupportsIndex,
2266+
/,
2267+
*,
2268+
order: _OrderACF = "C",
2269+
copy: builtins.bool | None = None,
2270+
) -> ndarray[tuple[int, int, int], _DType_co]: ...
2271+
@overload # (index, index, index, index)
2272+
def reshape(
2273+
self,
2274+
size1: SupportsIndex,
2275+
size2: SupportsIndex,
2276+
size3: SupportsIndex,
2277+
size4: SupportsIndex,
2278+
/,
2279+
*,
2280+
order: _OrderACF = "C",
2281+
copy: builtins.bool | None = None,
2282+
) -> ndarray[tuple[int, int, int, int], _DType_co]: ...
2283+
@overload # (int, *(index, ...))
2284+
def reshape(
2285+
self,
2286+
size0: SupportsIndex,
2287+
/,
22192288
*shape: SupportsIndex,
2220-
order: _OrderACF = ...,
2221-
copy: None | builtins.bool = ...,
2289+
order: _OrderACF = "C",
2290+
copy: builtins.bool | None = None,
2291+
) -> ndarray[_Shape, _DType_co]: ...
2292+
@overload # (sequence[index])
2293+
def reshape(
2294+
self,
2295+
shape: Sequence[SupportsIndex],
2296+
/,
2297+
*,
2298+
order: _OrderACF = "C",
2299+
copy: builtins.bool | None = None,
22222300
) -> ndarray[_Shape, _DType_co]: ...
22232301

22242302
@overload

numpy/typing/tests/data/reveal/ndarray_misc.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,12 @@ assert_type(AR_U.ravel(), npt.NDArray[np.str_])
179179
assert_type(AR_f8.flatten(), npt.NDArray[np.float64])
180180
assert_type(AR_U.flatten(), npt.NDArray[np.str_])
181181

182-
assert_type(AR_f8.reshape(1), npt.NDArray[np.float64])
183-
assert_type(AR_U.reshape(1), npt.NDArray[np.str_])
182+
assert_type(AR_i8.reshape(None), npt.NDArray[np.int64])
183+
assert_type(AR_f8.reshape(-1), np.ndarray[tuple[int], np.dtype[np.float64]])
184+
assert_type(AR_c8.reshape(2, 3, 4, 5), np.ndarray[tuple[int, int, int, int], np.dtype[np.complex64]])
185+
assert_type(AR_m.reshape(()), np.ndarray[tuple[()], np.dtype[np.timedelta64]])
186+
assert_type(AR_U.reshape([]), np.ndarray[tuple[()], np.dtype[np.str_]])
187+
assert_type(AR_V.reshape((480, 720, 4)), np.ndarray[tuple[int, int, int], np.dtype[np.void]])
184188

185189
assert_type(int(AR_f8), int)
186190
assert_type(int(AR_U), int)

numpy/typing/tests/data/reveal/ndarray_shape_manipulation.pyi

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ from typing_extensions import assert_type
66
nd: npt.NDArray[np.int64]
77

88
# reshape
9-
assert_type(nd.reshape(), npt.NDArray[np.int64])
10-
assert_type(nd.reshape(4), npt.NDArray[np.int64])
11-
assert_type(nd.reshape(2, 2), npt.NDArray[np.int64])
12-
assert_type(nd.reshape((2, 2)), npt.NDArray[np.int64])
13-
14-
assert_type(nd.reshape((2, 2), order="C"), npt.NDArray[np.int64])
15-
assert_type(nd.reshape(4, order="C"), npt.NDArray[np.int64])
9+
assert_type(nd.reshape(None), npt.NDArray[np.int64])
10+
assert_type(nd.reshape(4), np.ndarray[tuple[int], np.dtype[np.int64]])
11+
assert_type(nd.reshape((4,)), np.ndarray[tuple[int], np.dtype[np.int64]])
12+
assert_type(nd.reshape(2, 2), np.ndarray[tuple[int, int], np.dtype[np.int64]])
13+
assert_type(nd.reshape((2, 2)), np.ndarray[tuple[int, int], np.dtype[np.int64]])
14+
15+
assert_type(nd.reshape((2, 2), order="C"), np.ndarray[tuple[int, int], np.dtype[np.int64]])
16+
assert_type(nd.reshape(4, order="C"), np.ndarray[tuple[int], np.dtype[np.int64]])
1617

1718
# resize does not return a value
1819

0 commit comments

Comments
 (0)