Skip to content

Commit 19c0728

Browse files
authored
TYP: Transparent __array__ shape-type (numpy#26927)
This changes the ndarray.__array__ and flatiter.__array__ methods to return a ndarray with the same shape type. Due to technical limitations, flatiter will it only return the shape type of its underlying ndarray if it's 1-d (like tuple[int]).
1 parent d19263a commit 19c0728

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

numpy/__init__.pyi

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ _ArrayLikeInt: TypeAlias = (
901901
)
902902

903903
_FlatIterSelf = TypeVar("_FlatIterSelf", bound=flatiter[Any])
904+
_FlatShapeType = TypeVar("_FlatShapeType", bound=tuple[int])
904905

905906
@final
906907
class flatiter(Generic[_NdArraySubClass]):
@@ -935,6 +936,10 @@ class flatiter(Generic[_NdArraySubClass]):
935936
value: Any,
936937
) -> None: ...
937938
@overload
939+
def __array__(self: flatiter[ndarray[_FlatShapeType, _DType]], dtype: None = ..., /) -> ndarray[_FlatShapeType, _DType]: ...
940+
@overload
941+
def __array__(self: flatiter[ndarray[_FlatShapeType, Any]], dtype: _DType, /) -> ndarray[_FlatShapeType, _DType]: ...
942+
@overload
938943
def __array__(self: flatiter[ndarray[Any, _DType]], dtype: None = ..., /) -> ndarray[Any, _DType]: ...
939944
@overload
940945
def __array__(self, dtype: _DType, /) -> ndarray[Any, _DType]: ...
@@ -1469,11 +1474,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
14691474
@overload
14701475
def __array__(
14711476
self, dtype: None = ..., /, *, copy: None | bool = ...
1472-
) -> ndarray[Any, _DType_co]: ...
1477+
) -> ndarray[_ShapeType, _DType_co]: ...
14731478
@overload
14741479
def __array__(
14751480
self, dtype: _DType, /, *, copy: None | bool = ...
1476-
) -> ndarray[Any, _DType]: ...
1481+
) -> ndarray[_ShapeType, _DType]: ...
14771482

14781483
def __array_ufunc__(
14791484
self,
@@ -1704,11 +1709,13 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
17041709
axis: None | SupportsIndex = ...,
17051710
) -> ndarray[Any, _DType_co]: ...
17061711

1712+
# TODO: use `tuple[int]` as shape type once covariant (#26081)
17071713
def flatten(
17081714
self,
17091715
order: _OrderKACF = ...,
17101716
) -> ndarray[Any, _DType_co]: ...
17111717

1718+
# TODO: use `tuple[int]` as shape type once covariant (#26081)
17121719
def ravel(
17131720
self,
17141721
order: _OrderKACF = ...,
@@ -2613,6 +2620,7 @@ _NBit2 = TypeVar("_NBit2", bound=NBitBase)
26132620
class generic(_ArrayOrScalarCommon):
26142621
@abstractmethod
26152622
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
2623+
# TODO: use `tuple[()]` as shape type once covariant (#26081)
26162624
@overload
26172625
def __array__(self: _ScalarType, dtype: None = ..., /) -> NDArray[_ScalarType]: ...
26182626
@overload
@@ -3740,6 +3748,7 @@ class poly1d:
37403748

37413749
__hash__: ClassVar[None] # type: ignore
37423750

3751+
# TODO: use `tuple[int]` as shape type once covariant (#26081)
37433752
@overload
37443753
def __array__(self, t: None = ..., copy: None | bool = ...) -> NDArray[Any]: ...
37453754
@overload

numpy/typing/tests/data/reveal/flatiter.pyi

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import Any
2+
from typing import Any, Literal, TypeAlias
33

44
import numpy as np
55
import numpy.typing as npt
@@ -10,6 +10,10 @@ else:
1010
from typing_extensions import assert_type
1111

1212
a: np.flatiter[npt.NDArray[np.str_]]
13+
a_1d: np.flatiter[np.ndarray[tuple[int], np.dtype[np.bytes_]]]
14+
15+
Size: TypeAlias = Literal[42]
16+
a_1d_fixed: np.flatiter[np.ndarray[tuple[Size], np.dtype[np.object_]]]
1317

1418
assert_type(a.base, npt.NDArray[np.str_])
1519
assert_type(a.copy(), npt.NDArray[np.str_])
@@ -23,8 +27,26 @@ assert_type(a[...], npt.NDArray[np.str_])
2327
assert_type(a[:], npt.NDArray[np.str_])
2428
assert_type(a[(...,)], npt.NDArray[np.str_])
2529
assert_type(a[(0,)], np.str_)
30+
2631
assert_type(a.__array__(), npt.NDArray[np.str_])
2732
assert_type(a.__array__(np.dtype(np.float64)), npt.NDArray[np.float64])
33+
assert_type(
34+
a_1d.__array__(),
35+
np.ndarray[tuple[int], np.dtype[np.bytes_]],
36+
)
37+
assert_type(
38+
a_1d.__array__(np.dtype(np.float64)),
39+
np.ndarray[tuple[int], np.dtype[np.float64]],
40+
)
41+
assert_type(
42+
a_1d_fixed.__array__(),
43+
np.ndarray[tuple[Size], np.dtype[np.object_]],
44+
)
45+
assert_type(
46+
a_1d_fixed.__array__(np.dtype(np.float64)),
47+
np.ndarray[tuple[Size], np.dtype[np.float64]],
48+
)
49+
2850
a[0] = "a"
2951
a[:5] = "a"
3052
a[...] = "a"

0 commit comments

Comments
 (0)