Skip to content

Commit 7bfdc8c

Browse files
authored
Merge pull request numpy#27750 from jorenham/typing/ndarray.item
TYP: Fix ``ndarray.item()`` and improve ``ndarray.tolist()``
2 parents da32320 + 56ca6cb commit 7bfdc8c

File tree

3 files changed

+91
-39
lines changed

3 files changed

+91
-39
lines changed

numpy/__init__.pyi

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,7 +1799,28 @@ _ArrayAPIVersion: TypeAlias = L["2021.12", "2022.12", "2023.12"]
17991799

18001800
@type_check_only
18011801
class _SupportsItem(Protocol[_T_co]):
1802-
def item(self, args: Any, /) -> _T_co: ...
1802+
def item(self, /) -> _T_co: ...
1803+
1804+
@type_check_only
1805+
class _HasShapeAndSupportsItem(_SupportsItem[_T_co], Protocol[_ShapeType_co, _T_co]):
1806+
@property
1807+
def shape(self, /) -> _ShapeType_co: ...
1808+
1809+
# matches any `x` on `x.type.item() -> _T_co`, e.g. `dtype[np.int8]` gives `_T_co: int`
1810+
@type_check_only
1811+
class _HashTypeWithItem(Protocol[_T_co]):
1812+
@property
1813+
def type(self, /) -> type[_SupportsItem[_T_co]]: ...
1814+
1815+
# matches any `x` on `x.shape: _ShapeType_co` and `x.dtype.type.item() -> _T_co`,
1816+
# useful for capturing the item-type (`_T_co`) of the scalar-type of an array with
1817+
# specific shape (`_ShapeType_co`).
1818+
@type_check_only
1819+
class _HasShapeAndDTypeWithItem(Protocol[_ShapeType_co, _T_co]):
1820+
@property
1821+
def shape(self, /) -> _ShapeType_co: ...
1822+
@property
1823+
def dtype(self, /) -> _HashTypeWithItem[_T_co]: ...
18031824

18041825
@type_check_only
18051826
class _SupportsReal(Protocol[_T_co]):
@@ -1921,18 +1942,29 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
19211942
@property
19221943
def flat(self) -> flatiter[Self]: ...
19231944

1924-
# Use the same output type as that of the underlying `generic`
1945+
@overload # special casing for `StringDType`, which has no scalar type
1946+
def item(self: ndarray[Any, dtypes.StringDType], /) -> str: ...
19251947
@overload
1926-
def item(
1927-
self: ndarray[Any, _dtype[_SupportsItem[_T]]], # type: ignore[type-var]
1928-
*args: SupportsIndex,
1929-
) -> _T: ...
1948+
def item(self: ndarray[Any, dtypes.StringDType], arg0: SupportsIndex | tuple[SupportsIndex, ...] = ..., /) -> str: ...
19301949
@overload
1931-
def item(
1932-
self: ndarray[Any, _dtype[_SupportsItem[_T]]], # type: ignore[type-var]
1933-
args: tuple[SupportsIndex, ...],
1934-
/,
1935-
) -> _T: ...
1950+
def item(self: ndarray[Any, dtypes.StringDType], /, *args: SupportsIndex) -> str: ...
1951+
@overload # use the same output type as that of the underlying `generic`
1952+
def item(self: _HasShapeAndDTypeWithItem[Any, _T], /) -> _T: ...
1953+
@overload
1954+
def item(self: _HasShapeAndDTypeWithItem[Any, _T], arg0: SupportsIndex | tuple[SupportsIndex, ...] = ..., /) -> _T: ...
1955+
@overload
1956+
def item(self: _HasShapeAndDTypeWithItem[Any, _T], /, *args: SupportsIndex) -> _T: ...
1957+
1958+
@overload
1959+
def tolist(self: _HasShapeAndSupportsItem[tuple[()], _T], /) -> _T: ...
1960+
@overload
1961+
def tolist(self: _HasShapeAndSupportsItem[tuple[int], _T], /) -> list[_T]: ...
1962+
@overload
1963+
def tolist(self: _HasShapeAndSupportsItem[tuple[int, int], _T], /) -> list[list[_T]]: ...
1964+
@overload
1965+
def tolist(self: _HasShapeAndSupportsItem[tuple[int, int, int], _T], /) -> list[list[list[_T]]]: ...
1966+
@overload
1967+
def tolist(self: _HasShapeAndSupportsItem[Any, _T], /) -> _T | list[_T] | list[list[_T]] | list[list[list[Any]]]: ...
19361968

19371969
@overload
19381970
def resize(self, new_shape: _ShapeLike, /, *, refcheck: builtins.bool = ...) -> None: ...
@@ -4635,7 +4667,7 @@ class matrix(ndarray[_Shape2DType_co, _DType_co]):
46354667
def ptp(self, axis: None | _ShapeLike = ..., out: _NdArraySubClass = ...) -> _NdArraySubClass: ...
46364668

46374669
def squeeze(self, axis: None | _ShapeLike = ...) -> matrix[_Shape2D, _DType_co]: ...
4638-
def tolist(self: matrix[_Shape2D, dtype[_SupportsItem[_T]]]) -> list[list[_T]]: ... # type: ignore[typevar]
4670+
def tolist(self: _SupportsItem[_T]) -> list[list[_T]]: ...
46394671
def ravel(self, order: _OrderKACF = ...) -> matrix[_Shape2D, _DType_co]: ...
46404672
def flatten(self, order: _OrderKACF = ...) -> matrix[_Shape2D, _DType_co]: ...
46414673

numpy/typing/tests/data/pass/numeric.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
from __future__ import annotations
9+
from typing import cast
910

1011
import numpy as np
1112
import numpy.typing as npt
@@ -15,7 +16,10 @@ class SubClass(npt.NDArray[np.float64]):
1516

1617
i8 = np.int64(1)
1718

18-
A = np.arange(27).reshape(3, 3, 3)
19+
A = cast(
20+
np.ndarray[tuple[int, int, int], np.dtype[np.intp]],
21+
np.arange(27).reshape(3, 3, 3),
22+
)
1923
B: list[list[list[int]]] = A.tolist()
2024
C = np.empty((27, 27)).view(SubClass)
2125

numpy/typing/tests/data/reveal/ndarray_conversion.pyi

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,32 @@ import numpy.typing as npt
55

66
from typing_extensions import assert_type
77

8-
nd: npt.NDArray[np.int_]
8+
b1_0d: np.ndarray[tuple[()], np.dtype[np.bool]]
9+
u2_1d: np.ndarray[tuple[int], np.dtype[np.uint16]]
10+
i4_2d: np.ndarray[tuple[int, int], np.dtype[np.int32]]
11+
f8_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float64]]
12+
cG_4d: np.ndarray[tuple[int, int, int, int], np.dtype[np.clongdouble]]
13+
i0_nd: npt.NDArray[np.int_]
914

1015
# item
11-
assert_type(nd.item(), int)
12-
assert_type(nd.item(1), int)
13-
assert_type(nd.item(0, 1), int)
14-
assert_type(nd.item((0, 1)), int)
16+
assert_type(i0_nd.item(), int)
17+
assert_type(i0_nd.item(1), int)
18+
assert_type(i0_nd.item(0, 1), int)
19+
assert_type(i0_nd.item((0, 1)), int)
20+
21+
assert_type(b1_0d.item(()), bool)
22+
assert_type(u2_1d.item((0,)), int)
23+
assert_type(i4_2d.item(-1, 2), int)
24+
assert_type(f8_3d.item(2, 1, -1), float)
25+
assert_type(cG_4d.item(-0xEd_fed_Deb_a_dead_bee), complex) # c'mon Ed, we talked about this...
1526

1627
# tolist
17-
assert_type(nd.tolist(), Any)
28+
assert_type(b1_0d.tolist(), bool)
29+
assert_type(u2_1d.tolist(), list[int])
30+
assert_type(i4_2d.tolist(), list[list[int]])
31+
assert_type(f8_3d.tolist(), list[list[list[float]]])
32+
assert_type(cG_4d.tolist(), complex | list[complex] | list[list[complex]] | list[list[list[Any]]])
33+
assert_type(i0_nd.tolist(), int | list[int] | list[list[int]] | list[list[list[Any]]])
1834

1935
# itemset does not return a value
2036
# tostring is pretty simple
@@ -24,34 +40,34 @@ assert_type(nd.tolist(), Any)
2440
# dumps is pretty simple
2541

2642
# astype
27-
assert_type(nd.astype("float"), npt.NDArray[Any])
28-
assert_type(nd.astype(float), npt.NDArray[Any])
29-
assert_type(nd.astype(np.float64), npt.NDArray[np.float64])
30-
assert_type(nd.astype(np.float64, "K"), npt.NDArray[np.float64])
31-
assert_type(nd.astype(np.float64, "K", "unsafe"), npt.NDArray[np.float64])
32-
assert_type(nd.astype(np.float64, "K", "unsafe", True), npt.NDArray[np.float64])
33-
assert_type(nd.astype(np.float64, "K", "unsafe", True, True), npt.NDArray[np.float64])
43+
assert_type(i0_nd.astype("float"), npt.NDArray[Any])
44+
assert_type(i0_nd.astype(float), npt.NDArray[Any])
45+
assert_type(i0_nd.astype(np.float64), npt.NDArray[np.float64])
46+
assert_type(i0_nd.astype(np.float64, "K"), npt.NDArray[np.float64])
47+
assert_type(i0_nd.astype(np.float64, "K", "unsafe"), npt.NDArray[np.float64])
48+
assert_type(i0_nd.astype(np.float64, "K", "unsafe", True), npt.NDArray[np.float64])
49+
assert_type(i0_nd.astype(np.float64, "K", "unsafe", True, True), npt.NDArray[np.float64])
3450

35-
assert_type(np.astype(nd, np.float64), npt.NDArray[np.float64])
51+
assert_type(np.astype(i0_nd, np.float64), npt.NDArray[np.float64])
3652

3753
# byteswap
38-
assert_type(nd.byteswap(), npt.NDArray[np.int_])
39-
assert_type(nd.byteswap(True), npt.NDArray[np.int_])
54+
assert_type(i0_nd.byteswap(), npt.NDArray[np.int_])
55+
assert_type(i0_nd.byteswap(True), npt.NDArray[np.int_])
4056

4157
# copy
42-
assert_type(nd.copy(), npt.NDArray[np.int_])
43-
assert_type(nd.copy("C"), npt.NDArray[np.int_])
58+
assert_type(i0_nd.copy(), npt.NDArray[np.int_])
59+
assert_type(i0_nd.copy("C"), npt.NDArray[np.int_])
4460

45-
assert_type(nd.view(), npt.NDArray[np.int_])
46-
assert_type(nd.view(np.float64), npt.NDArray[np.float64])
47-
assert_type(nd.view(float), npt.NDArray[Any])
48-
assert_type(nd.view(np.float64, np.matrix), np.matrix[tuple[int, int], Any])
61+
assert_type(i0_nd.view(), npt.NDArray[np.int_])
62+
assert_type(i0_nd.view(np.float64), npt.NDArray[np.float64])
63+
assert_type(i0_nd.view(float), npt.NDArray[Any])
64+
assert_type(i0_nd.view(np.float64, np.matrix), np.matrix[tuple[int, int], Any])
4965

5066
# getfield
51-
assert_type(nd.getfield("float"), npt.NDArray[Any])
52-
assert_type(nd.getfield(float), npt.NDArray[Any])
53-
assert_type(nd.getfield(np.float64), npt.NDArray[np.float64])
54-
assert_type(nd.getfield(np.float64, 8), npt.NDArray[np.float64])
67+
assert_type(i0_nd.getfield("float"), npt.NDArray[Any])
68+
assert_type(i0_nd.getfield(float), npt.NDArray[Any])
69+
assert_type(i0_nd.getfield(np.float64), npt.NDArray[np.float64])
70+
assert_type(i0_nd.getfield(np.float64, 8), npt.NDArray[np.float64])
5571

5672
# setflags does not return a value
5773
# fill does not return a value

0 commit comments

Comments
 (0)