Skip to content

Commit 4bd4494

Browse files
authored
Merge pull request numpy#19444 from BvB93/shape_base
ENH: Add annotations for `np.lib.shape_base`
2 parents 45cfd11 + a8699e2 commit 4bd4494

File tree

5 files changed

+296
-27
lines changed

5 files changed

+296
-27
lines changed

numpy/__init__.pyi

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,10 +1219,9 @@ class _ArrayOrScalarCommon:
12191219
@property
12201220
def __array_interface__(self): ...
12211221
@property
1222-
def __array_priority__(self): ...
1222+
def __array_priority__(self) -> float: ...
12231223
@property
12241224
def __array_struct__(self): ...
1225-
def __array_wrap__(array, context=...): ...
12261225
def __setstate__(self, __state): ...
12271226
# a `bool_` is returned when `keepdims=True` and `self` is a 0d array
12281227

@@ -1599,6 +1598,7 @@ _FlexDType = TypeVar("_FlexDType", bound=dtype[flexible])
15991598
# TODO: Set the `bound` to something more suitable once we
16001599
# have proper shape support
16011600
_ShapeType = TypeVar("_ShapeType", bound=Any)
1601+
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
16021602
_NumberType = TypeVar("_NumberType", bound=number[Any])
16031603

16041604
# There is currently no exhaustive way to type the buffer protocol,
@@ -1674,6 +1674,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
16741674
def __array__(self, __dtype: None = ...) -> ndarray[Any, _DType_co]: ...
16751675
@overload
16761676
def __array__(self, __dtype: _DType) -> ndarray[Any, _DType]: ...
1677+
1678+
def __array_wrap__(
1679+
self,
1680+
__array: ndarray[_ShapeType2, _DType],
1681+
__context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
1682+
) -> ndarray[_ShapeType2, _DType]: ...
1683+
1684+
def __array_prepare__(
1685+
self,
1686+
__array: ndarray[_ShapeType2, _DType],
1687+
__context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
1688+
) -> ndarray[_ShapeType2, _DType]: ...
1689+
16771690
@property
16781691
def ctypes(self) -> _ctypes[int]: ...
16791692
@property

numpy/core/_add_newdocs.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,7 +1585,7 @@
15851585
For integer arguments the function is equivalent to the Python built-in
15861586
`range` function, but returns an ndarray rather than a list.
15871587
1588-
When using a non-integer step, such as 0.1, it is often better to use
1588+
When using a non-integer step, such as 0.1, it is often better to use
15891589
`numpy.linspace`. See the warnings section below for more information.
15901590
15911591
Parameters
@@ -2771,13 +2771,17 @@
27712771

27722772

27732773
add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_prepare__',
2774-
"""a.__array_prepare__(obj) -> Object of same type as ndarray object obj.
2774+
"""a.__array_prepare__(array[, context], /)
2775+
2776+
Returns a view of `array` with the same type as self.
27752777
27762778
"""))
27772779

27782780

27792781
add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_wrap__',
2780-
"""a.__array_wrap__(obj) -> Object of same type as ndarray object a.
2782+
"""a.__array_wrap__(array[, context], /)
2783+
2784+
Returns a view of `array` with the same type as self.
27812785
27822786
"""))
27832787

numpy/lib/shape_base.pyi

Lines changed: 208 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,215 @@
1-
from typing import List
1+
from typing import List, TypeVar, Callable, Sequence, Any, overload, Tuple
2+
from typing_extensions import SupportsIndex, Protocol
3+
4+
from numpy import (
5+
generic,
6+
integer,
7+
dtype,
8+
ufunc,
9+
bool_,
10+
unsignedinteger,
11+
signedinteger,
12+
floating,
13+
complexfloating,
14+
object_,
15+
)
16+
17+
from numpy.typing import (
18+
ArrayLike,
19+
NDArray,
20+
_ShapeLike,
21+
_NestedSequence,
22+
_SupportsDType,
23+
_ArrayLikeBool_co,
24+
_ArrayLikeUInt_co,
25+
_ArrayLikeInt_co,
26+
_ArrayLikeFloat_co,
27+
_ArrayLikeComplex_co,
28+
_ArrayLikeObject_co,
29+
)
230

331
from numpy.core.shape_base import vstack
432

33+
_SCT = TypeVar("_SCT", bound=generic)
34+
35+
_ArrayLike = _NestedSequence[_SupportsDType[dtype[_SCT]]]
36+
37+
# The signatures of `__array_wrap__` and `__array_prepare__` are the same;
38+
# give them unique names for the sake of clarity
39+
class _ArrayWrap(Protocol):
40+
def __call__(
41+
self,
42+
__array: NDArray[Any],
43+
__context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
44+
) -> Any: ...
45+
46+
class _ArrayPrepare(Protocol):
47+
def __call__(
48+
self,
49+
__array: NDArray[Any],
50+
__context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
51+
) -> Any: ...
52+
53+
class _SupportsArrayWrap(Protocol):
54+
@property
55+
def __array_wrap__(self) -> _ArrayWrap: ...
56+
57+
class _SupportsArrayPrepare(Protocol):
58+
@property
59+
def __array_prepare__(self) -> _ArrayPrepare: ...
60+
561
__all__: List[str]
662

763
row_stack = vstack
864

9-
def take_along_axis(arr, indices, axis): ...
10-
def put_along_axis(arr, indices, values, axis): ...
11-
def apply_along_axis(func1d, axis, arr, *args, **kwargs): ...
12-
def apply_over_axes(func, a, axes): ...
13-
def expand_dims(a, axis): ...
14-
def column_stack(tup): ...
15-
def dstack(tup): ...
16-
def array_split(ary, indices_or_sections, axis=...): ...
17-
def split(ary, indices_or_sections, axis=...): ...
18-
def hsplit(ary, indices_or_sections): ...
19-
def vsplit(ary, indices_or_sections): ...
20-
def dsplit(ary, indices_or_sections): ...
21-
def get_array_prepare(*args): ...
22-
def get_array_wrap(*args): ...
23-
def kron(a, b): ...
24-
def tile(A, reps): ...
65+
def take_along_axis(
66+
arr: _SCT | NDArray[_SCT],
67+
indices: NDArray[integer[Any]],
68+
axis: None | int,
69+
) -> NDArray[_SCT]: ...
70+
71+
def put_along_axis(
72+
arr: NDArray[_SCT],
73+
indices: NDArray[integer[Any]],
74+
values: ArrayLike,
75+
axis: None | int,
76+
) -> None: ...
77+
78+
@overload
79+
def apply_along_axis(
80+
func1d: Callable[..., _ArrayLike[_SCT]],
81+
axis: SupportsIndex,
82+
arr: ArrayLike,
83+
*args: Any,
84+
**kwargs: Any,
85+
) -> NDArray[_SCT]: ...
86+
@overload
87+
def apply_along_axis(
88+
func1d: Callable[..., ArrayLike],
89+
axis: SupportsIndex,
90+
arr: ArrayLike,
91+
*args: Any,
92+
**kwargs: Any,
93+
) -> NDArray[Any]: ...
94+
95+
def apply_over_axes(
96+
func: Callable[[NDArray[Any], int], NDArray[_SCT]],
97+
a: ArrayLike,
98+
axes: int | Sequence[int],
99+
) -> NDArray[_SCT]: ...
100+
101+
@overload
102+
def expand_dims(
103+
a: _ArrayLike[_SCT],
104+
axis: _ShapeLike,
105+
) -> NDArray[_SCT]: ...
106+
@overload
107+
def expand_dims(
108+
a: ArrayLike,
109+
axis: _ShapeLike,
110+
) -> NDArray[Any]: ...
111+
112+
@overload
113+
def column_stack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
114+
@overload
115+
def column_stack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
116+
117+
@overload
118+
def dstack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
119+
@overload
120+
def dstack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
121+
122+
@overload
123+
def array_split(
124+
ary: _ArrayLike[_SCT],
125+
indices_or_sections: _ShapeLike,
126+
axis: SupportsIndex = ...,
127+
) -> List[NDArray[_SCT]]: ...
128+
@overload
129+
def array_split(
130+
ary: ArrayLike,
131+
indices_or_sections: _ShapeLike,
132+
axis: SupportsIndex = ...,
133+
) -> List[NDArray[Any]]: ...
134+
135+
@overload
136+
def split(
137+
ary: _ArrayLike[_SCT],
138+
indices_or_sections: _ShapeLike,
139+
axis: SupportsIndex = ...,
140+
) -> List[NDArray[_SCT]]: ...
141+
@overload
142+
def split(
143+
ary: ArrayLike,
144+
indices_or_sections: _ShapeLike,
145+
axis: SupportsIndex = ...,
146+
) -> List[NDArray[Any]]: ...
147+
148+
@overload
149+
def hsplit(
150+
ary: _ArrayLike[_SCT],
151+
indices_or_sections: _ShapeLike,
152+
) -> List[NDArray[_SCT]]: ...
153+
@overload
154+
def hsplit(
155+
ary: ArrayLike,
156+
indices_or_sections: _ShapeLike,
157+
) -> List[NDArray[Any]]: ...
158+
159+
@overload
160+
def vsplit(
161+
ary: _ArrayLike[_SCT],
162+
indices_or_sections: _ShapeLike,
163+
) -> List[NDArray[_SCT]]: ...
164+
@overload
165+
def vsplit(
166+
ary: ArrayLike,
167+
indices_or_sections: _ShapeLike,
168+
) -> List[NDArray[Any]]: ...
169+
170+
@overload
171+
def dsplit(
172+
ary: _ArrayLike[_SCT],
173+
indices_or_sections: _ShapeLike,
174+
) -> List[NDArray[_SCT]]: ...
175+
@overload
176+
def dsplit(
177+
ary: ArrayLike,
178+
indices_or_sections: _ShapeLike,
179+
) -> List[NDArray[Any]]: ...
180+
181+
@overload
182+
def get_array_prepare(*args: _SupportsArrayPrepare) -> _ArrayPrepare: ...
183+
@overload
184+
def get_array_prepare(*args: object) -> None | _ArrayPrepare: ...
185+
186+
@overload
187+
def get_array_wrap(*args: _SupportsArrayWrap) -> _ArrayWrap: ...
188+
@overload
189+
def get_array_wrap(*args: object) -> None | _ArrayWrap: ...
190+
191+
@overload
192+
def kron(a: _ArrayLikeBool_co, b: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
193+
@overload
194+
def kron(a: _ArrayLikeUInt_co, b: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[Any]]: ... # type: ignore[misc]
195+
@overload
196+
def kron(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co) -> NDArray[signedinteger[Any]]: ... # type: ignore[misc]
197+
@overload
198+
def kron(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... # type: ignore[misc]
199+
@overload
200+
def kron(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
201+
@overload
202+
def kron(a: _ArrayLikeObject_co, b: Any) -> NDArray[object_]: ...
203+
@overload
204+
def kron(a: Any, b: _ArrayLikeObject_co) -> NDArray[object_]: ...
205+
206+
@overload
207+
def tile(
208+
A: _ArrayLike[_SCT],
209+
reps: int | Sequence[int],
210+
) -> NDArray[_SCT]: ...
211+
@overload
212+
def tile(
213+
A: ArrayLike,
214+
reps: int | Sequence[int],
215+
) -> NDArray[Any]: ...

numpy/typing/tests/data/reveal/ndarray_misc.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
from typing import Any
1212

1313
import numpy as np
14+
from numpy.typing import NDArray
1415

15-
class SubClass(np.ndarray): ...
16+
class SubClass(NDArray[np.object_]): ...
1617

1718
f8: np.float64
1819
B: SubClass
19-
AR_f8: np.ndarray[Any, np.dtype[np.float64]]
20-
AR_i8: np.ndarray[Any, np.dtype[np.int64]]
21-
AR_U: np.ndarray[Any, np.dtype[np.str_]]
20+
AR_f8: NDArray[np.float64]
21+
AR_i8: NDArray[np.int64]
22+
AR_U: NDArray[np.str_]
2223

2324
ctypes_obj = AR_f8.ctypes
2425

@@ -126,7 +127,7 @@ class SubClass(np.ndarray): ...
126127

127128
reveal_type(f8.repeat(1)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
128129
reveal_type(AR_f8.repeat(1)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
129-
reveal_type(B.repeat(1)) # E: numpy.ndarray[Any, Any]
130+
reveal_type(B.repeat(1)) # E: numpy.ndarray[Any, numpy.dtype[numpy.object_]]
130131

131132
reveal_type(f8.std()) # E: Any
132133
reveal_type(AR_f8.std()) # E: Any
@@ -189,3 +190,6 @@ class SubClass(np.ndarray): ...
189190
reveal_type(complex(AR_f8)) # E: complex
190191

191192
reveal_type(operator.index(AR_i8)) # E: int
193+
194+
reveal_type(AR_f8.__array_prepare__(B)) # E: numpy.ndarray[Any, numpy.dtype[numpy.object_]]
195+
reveal_type(AR_f8.__array_wrap__(B)) # E: numpy.ndarray[Any, numpy.dtype[numpy.object_]]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import numpy as np
2+
from numpy.typing import NDArray
3+
from typing import Any, List
4+
5+
i8: np.int64
6+
f8: np.float64
7+
8+
AR_b: NDArray[np.bool_]
9+
AR_i8: NDArray[np.int64]
10+
AR_f8: NDArray[np.float64]
11+
12+
AR_LIKE_f8: List[float]
13+
14+
reveal_type(np.take_along_axis(AR_f8, AR_i8, axis=1)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
15+
reveal_type(np.take_along_axis(f8, AR_i8, axis=None)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
16+
17+
reveal_type(np.put_along_axis(AR_f8, AR_i8, "1.0", axis=1)) # E: None
18+
19+
reveal_type(np.expand_dims(AR_i8, 2)) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
20+
reveal_type(np.expand_dims(AR_LIKE_f8, 2)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
21+
22+
reveal_type(np.column_stack([AR_i8])) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
23+
reveal_type(np.column_stack([AR_LIKE_f8])) # E: numpy.ndarray[Any, numpy.dtype[Any]]
24+
25+
reveal_type(np.dstack([AR_i8])) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
26+
reveal_type(np.dstack([AR_LIKE_f8])) # E: numpy.ndarray[Any, numpy.dtype[Any]]
27+
28+
reveal_type(np.row_stack([AR_i8])) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
29+
reveal_type(np.row_stack([AR_LIKE_f8])) # E: numpy.ndarray[Any, numpy.dtype[Any]]
30+
31+
reveal_type(np.array_split(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
32+
reveal_type(np.array_split(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
33+
34+
reveal_type(np.split(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
35+
reveal_type(np.split(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
36+
37+
reveal_type(np.hsplit(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
38+
reveal_type(np.hsplit(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
39+
40+
reveal_type(np.vsplit(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
41+
reveal_type(np.vsplit(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
42+
43+
reveal_type(np.dsplit(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
44+
reveal_type(np.dsplit(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
45+
46+
reveal_type(np.lib.shape_base.get_array_prepare(AR_i8)) # E: numpy.lib.shape_base._ArrayPrepare
47+
reveal_type(np.lib.shape_base.get_array_prepare(AR_i8, 1)) # E: Union[None, numpy.lib.shape_base._ArrayPrepare]
48+
49+
reveal_type(np.get_array_wrap(AR_i8)) # E: numpy.lib.shape_base._ArrayWrap
50+
reveal_type(np.get_array_wrap(AR_i8, 1)) # E: Union[None, numpy.lib.shape_base._ArrayWrap]
51+
52+
reveal_type(np.kron(AR_b, AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
53+
reveal_type(np.kron(AR_b, AR_i8)) # E: numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
54+
reveal_type(np.kron(AR_f8, AR_f8)) # E: numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]]
55+
56+
reveal_type(np.tile(AR_i8, 5)) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
57+
reveal_type(np.tile(AR_LIKE_f8, [2, 2])) # E: numpy.ndarray[Any, numpy.dtype[Any]]

0 commit comments

Comments
 (0)