Skip to content

Commit 547deac

Browse files
committed
TYP: 1-d shape-typing for ndarray.flatten and ravel
1 parent 7bfdc8c commit 547deac

File tree

6 files changed

+60
-53
lines changed

6 files changed

+60
-53
lines changed

numpy/__init__.pyi

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,17 +2190,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
21902190
axis: None | SupportsIndex = ...,
21912191
) -> ndarray[_Shape, _DType_co]: ...
21922192

2193-
# TODO: use `tuple[int]` as shape type once covariant (#26081)
2194-
def flatten(
2195-
self,
2196-
order: _OrderKACF = ...,
2197-
) -> ndarray[_Shape, _DType_co]: ...
2198-
2199-
# TODO: use `tuple[int]` as shape type once covariant (#26081)
2200-
def ravel(
2201-
self,
2202-
order: _OrderKACF = ...,
2203-
) -> ndarray[_Shape, _DType_co]: ...
2193+
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DType_co]: ...
2194+
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DType_co]: ...
22042195

22052196
@overload
22062197
def reshape(
@@ -3100,11 +3091,10 @@ _NBit_fc = TypeVar("_NBit_fc", _NBitHalf, _NBitSingle, _NBitDouble, _NBitLongDou
31003091
class generic(_ArrayOrScalarCommon):
31013092
@abstractmethod
31023093
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
3103-
# TODO: use `tuple[()]` as shape type once covariant (#26081)
31043094
@overload
3105-
def __array__(self, dtype: None = ..., /) -> NDArray[Self]: ...
3095+
def __array__(self, dtype: None = None, /) -> ndarray[tuple[()], dtype[Self]]: ...
31063096
@overload
3107-
def __array__(self, dtype: _DType, /) -> ndarray[_Shape, _DType]: ...
3097+
def __array__(self, dtype: _DType, /) -> ndarray[tuple[()], _DType]: ...
31083098
def __hash__(self) -> int: ...
31093099
@property
31103100
def base(self) -> None: ...
@@ -3118,7 +3108,7 @@ class generic(_ArrayOrScalarCommon):
31183108
def strides(self) -> tuple[()]: ...
31193109
def byteswap(self, inplace: L[False] = ...) -> Self: ...
31203110
@property
3121-
def flat(self) -> flatiter[NDArray[Self]]: ...
3111+
def flat(self) -> flatiter[ndarray[tuple[int], dtype[Self]]]: ...
31223112

31233113
if sys.version_info >= (3, 12):
31243114
def __buffer__(self, flags: int, /) -> memoryview: ...
@@ -3202,8 +3192,8 @@ class generic(_ArrayOrScalarCommon):
32023192
) -> _NdArraySubClass: ...
32033193

32043194
def repeat(self, repeats: _ArrayLikeInt_co, axis: None | SupportsIndex = ...) -> NDArray[Self]: ...
3205-
def flatten(self, order: _OrderKACF = ...) -> NDArray[Self]: ...
3206-
def ravel(self, order: _OrderKACF = ...) -> NDArray[Self]: ...
3195+
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], dtype[Self]]: ...
3196+
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], dtype[Self]]: ...
32073197

32083198
@overload
32093199
def reshape(self, shape: _ShapeLike, /, *, order: _OrderACF = ...) -> NDArray[Self]: ...
@@ -4492,13 +4482,12 @@ class poly1d:
44924482
@coefficients.setter
44934483
def coefficients(self, value: NDArray[Any]) -> None: ...
44944484

4495-
__hash__: ClassVar[None] # type: ignore
4485+
__hash__: ClassVar[None] # type: ignore[assignment] # pyright: ignore[reportIncompatibleMethodOverride]
44964486

4497-
# TODO: use `tuple[int]` as shape type once covariant (#26081)
44984487
@overload
4499-
def __array__(self, t: None = ..., copy: None | bool = ...) -> NDArray[Any]: ...
4488+
def __array__(self, /, t: None = None, copy: builtins.bool | None = None) -> ndarray[tuple[int], dtype[Any]]: ...
45004489
@overload
4501-
def __array__(self, t: _DType, copy: None | bool = ...) -> ndarray[_Shape, _DType]: ...
4490+
def __array__(self, /, t: _DType, copy: builtins.bool | None = None) -> ndarray[tuple[int], _DType]: ...
45024491

45034492
@overload
45044493
def __call__(self, val: _ScalarLike_co) -> Any: ...
@@ -4668,8 +4657,8 @@ class matrix(ndarray[_Shape2DType_co, _DType_co]):
46684657

46694658
def squeeze(self, axis: None | _ShapeLike = ...) -> matrix[_Shape2D, _DType_co]: ...
46704659
def tolist(self: _SupportsItem[_T]) -> list[list[_T]]: ...
4671-
def ravel(self, order: _OrderKACF = ...) -> matrix[_Shape2D, _DType_co]: ...
4672-
def flatten(self, order: _OrderKACF = ...) -> matrix[_Shape2D, _DType_co]: ...
4660+
def ravel(self, /, order: _OrderKACF = "C") -> matrix[tuple[L[1], int], _DType_co]: ... # pyright: ignore[reportIncompatibleMethodOverride]
4661+
def flatten(self, /, order: _OrderKACF = "C") -> matrix[tuple[L[1], int], _DType_co]: ... # pyright: ignore[reportIncompatibleMethodOverride]
46734662

46744663
@property
46754664
def T(self) -> matrix[_Shape2D, _DType_co]: ...

numpy/_core/fromnumeric.pyi

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ from numpy._typing import (
3939
ArrayLike,
4040
_ArrayLike,
4141
NDArray,
42+
_NestedSequence,
4243
_ShapeLike,
4344
_ArrayLikeBool_co,
4445
_ArrayLikeUInt_co,
@@ -438,10 +439,27 @@ def trace(
438439
out: _ArrayType = ...,
439440
) -> _ArrayType: ...
440441

442+
_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_SCT]]
443+
444+
@overload
445+
def ravel(a: _ArrayLike[_SCT], order: _OrderKACF = "C") -> _Array1D[_SCT]: ...
446+
@overload
447+
def ravel(a: bytes | _NestedSequence[bytes], order: _OrderKACF = "C") -> _Array1D[np.bytes_]: ...
448+
@overload
449+
def ravel(a: str | _NestedSequence[str], order: _OrderKACF = "C") -> _Array1D[np.str_]: ...
450+
@overload
451+
def ravel(a: bool | _NestedSequence[bool], order: _OrderKACF = "C") -> _Array1D[np.bool]: ...
452+
@overload
453+
def ravel(a: int | _NestedSequence[int], order: _OrderKACF = "C") -> _Array1D[np.int_ | np.bool]: ...
454+
@overload
455+
def ravel(a: float | _NestedSequence[float], order: _OrderKACF = "C") -> _Array1D[np.float64 | np.int_ | np.bool]: ...
441456
@overload
442-
def ravel(a: _ArrayLike[_SCT], order: _OrderKACF = ...) -> NDArray[_SCT]: ...
457+
def ravel(
458+
a: complex | _NestedSequence[complex],
459+
order: _OrderKACF = "C",
460+
) -> _Array1D[np.complex128 | np.float64 | np.int_ | np.bool]: ...
443461
@overload
444-
def ravel(a: ArrayLike, order: _OrderKACF = ...) -> NDArray[Any]: ...
462+
def ravel(a: ArrayLike, order: _OrderKACF = "C") -> np.ndarray[tuple[int], np.dtype[Any]]: ...
445463

446464
@overload
447465
def nonzero(a: np.generic | np.ndarray[tuple[()], Any]) -> NoReturn: ...

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ assert_type(np.trace(AR_b), Any)
121121
assert_type(np.trace(AR_f4), Any)
122122
assert_type(np.trace(AR_f4, out=AR_subclass), NDArraySubclass)
123123

124-
assert_type(np.ravel(b), npt.NDArray[np.bool])
125-
assert_type(np.ravel(f4), npt.NDArray[np.float32])
126-
assert_type(np.ravel(f), npt.NDArray[Any])
127-
assert_type(np.ravel(AR_b), npt.NDArray[np.bool])
128-
assert_type(np.ravel(AR_f4), npt.NDArray[np.float32])
124+
assert_type(np.ravel(b), np.ndarray[tuple[int], np.dtype[np.bool]])
125+
assert_type(np.ravel(f4), np.ndarray[tuple[int], np.dtype[np.float32]])
126+
assert_type(np.ravel(f), np.ndarray[tuple[int], np.dtype[np.float64 | np.int_ | np.bool]])
127+
assert_type(np.ravel(AR_b), np.ndarray[tuple[int], np.dtype[np.bool]])
128+
assert_type(np.ravel(AR_f4), np.ndarray[tuple[int], np.dtype[np.float32]])
129129

130130
assert_type(np.nonzero(b), NoReturn)
131131
assert_type(np.nonzero(f4), NoReturn)

numpy/typing/tests/data/reveal/ndarray_misc.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,11 @@ assert_type(AR_f8.trace(out=B), SubClass)
173173
assert_type(AR_f8.item(), float)
174174
assert_type(AR_U.item(), str)
175175

176-
assert_type(AR_f8.ravel(), npt.NDArray[np.float64])
177-
assert_type(AR_U.ravel(), npt.NDArray[np.str_])
176+
assert_type(AR_f8.ravel(), np.ndarray[tuple[int], np.dtype[np.float64]])
177+
assert_type(AR_U.ravel(), np.ndarray[tuple[int], np.dtype[np.str_]])
178178

179-
assert_type(AR_f8.flatten(), npt.NDArray[np.float64])
180-
assert_type(AR_U.flatten(), npt.NDArray[np.str_])
179+
assert_type(AR_f8.flatten(), np.ndarray[tuple[int], np.dtype[np.float64]])
180+
assert_type(AR_U.flatten(), np.ndarray[tuple[int], np.dtype[np.str_]])
181181

182182
assert_type(AR_f8.reshape(1), npt.NDArray[np.float64])
183183
assert_type(AR_U.reshape(1), npt.NDArray[np.str_])

numpy/typing/tests/data/reveal/ndarray_shape_manipulation.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ assert_type(nd.transpose((1, 0)), npt.NDArray[np.int64])
2525
assert_type(nd.swapaxes(0, 1), npt.NDArray[np.int64])
2626

2727
# flatten
28-
assert_type(nd.flatten(), npt.NDArray[np.int64])
29-
assert_type(nd.flatten("C"), npt.NDArray[np.int64])
28+
assert_type(nd.flatten(), np.ndarray[tuple[int], np.dtype[np.int64]])
29+
assert_type(nd.flatten("C"), np.ndarray[tuple[int], np.dtype[np.int64]])
3030

3131
# ravel
32-
assert_type(nd.ravel(), npt.NDArray[np.int64])
33-
assert_type(nd.ravel("C"), npt.NDArray[np.int64])
32+
assert_type(nd.ravel(), np.ndarray[tuple[int], np.dtype[np.int64]])
33+
assert_type(nd.ravel("C"), np.ndarray[tuple[int], np.dtype[np.int64]])
3434

3535
# squeeze
3636
assert_type(nd.squeeze(), npt.NDArray[np.int64])

numpy/typing/tests/data/reveal/scalars.pyi

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -93,21 +93,21 @@ assert_type(c16.tolist(), complex)
9393
assert_type(U.tolist(), str)
9494
assert_type(S.tolist(), bytes)
9595

96-
assert_type(b.ravel(), npt.NDArray[np.bool])
97-
assert_type(i8.ravel(), npt.NDArray[np.int64])
98-
assert_type(u8.ravel(), npt.NDArray[np.uint64])
99-
assert_type(f8.ravel(), npt.NDArray[np.float64])
100-
assert_type(c16.ravel(), npt.NDArray[np.complex128])
101-
assert_type(U.ravel(), npt.NDArray[np.str_])
102-
assert_type(S.ravel(), npt.NDArray[np.bytes_])
103-
104-
assert_type(b.flatten(), npt.NDArray[np.bool])
105-
assert_type(i8.flatten(), npt.NDArray[np.int64])
106-
assert_type(u8.flatten(), npt.NDArray[np.uint64])
107-
assert_type(f8.flatten(), npt.NDArray[np.float64])
108-
assert_type(c16.flatten(), npt.NDArray[np.complex128])
109-
assert_type(U.flatten(), npt.NDArray[np.str_])
110-
assert_type(S.flatten(), npt.NDArray[np.bytes_])
96+
assert_type(b.ravel(), np.ndarray[tuple[int], np.dtype[np.bool]])
97+
assert_type(i8.ravel(), np.ndarray[tuple[int], np.dtype[np.int64]])
98+
assert_type(u8.ravel(), np.ndarray[tuple[int], np.dtype[np.uint64]])
99+
assert_type(f8.ravel(), np.ndarray[tuple[int], np.dtype[np.float64]])
100+
assert_type(c16.ravel(), np.ndarray[tuple[int], np.dtype[np.complex128]])
101+
assert_type(U.ravel(), np.ndarray[tuple[int], np.dtype[np.str_]])
102+
assert_type(S.ravel(), np.ndarray[tuple[int], np.dtype[np.bytes_]])
103+
104+
assert_type(b.flatten(), np.ndarray[tuple[int], np.dtype[np.bool]])
105+
assert_type(i8.flatten(), np.ndarray[tuple[int], np.dtype[np.int64]])
106+
assert_type(u8.flatten(), np.ndarray[tuple[int], np.dtype[np.uint64]])
107+
assert_type(f8.flatten(), np.ndarray[tuple[int], np.dtype[np.float64]])
108+
assert_type(c16.flatten(), np.ndarray[tuple[int], np.dtype[np.complex128]])
109+
assert_type(U.flatten(), np.ndarray[tuple[int], np.dtype[np.str_]])
110+
assert_type(S.flatten(), np.ndarray[tuple[int], np.dtype[np.bytes_]])
111111

112112
assert_type(b.reshape(1), npt.NDArray[np.bool])
113113
assert_type(i8.reshape(1), npt.NDArray[np.int64])

0 commit comments

Comments
 (0)