Skip to content

Commit 1d82bcb

Browse files
author
Bas van Beek
committed
ENH: Add annotations for np.lib.shape_base
1 parent b32b72e commit 1d82bcb

File tree

1 file changed

+208
-17
lines changed

1 file changed

+208
-17
lines changed

numpy/lib/shape_base.pyi

Lines changed: 208 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,215 @@
1-
from typing import List
1+
from typing import List, TypeVar, Callable, Sequence, Any, overload, Tuple
2+
from typing_extensions import SupportsIndex, Protocol
3+
4+
from numpy import (
5+
generic,
6+
integer,
7+
dtype,
8+
ufunc,
9+
bool_,
10+
unsignedinteger,
11+
signedinteger,
12+
floating,
13+
complexfloating,
14+
object_,
15+
)
16+
17+
from numpy.typing import (
18+
ArrayLike,
19+
NDArray,
20+
_ShapeLike,
21+
_NestedSequence,
22+
_SupportsDType,
23+
_ArrayLikeBool_co,
24+
_ArrayLikeUInt_co,
25+
_ArrayLikeInt_co,
26+
_ArrayLikeFloat_co,
27+
_ArrayLikeComplex_co,
28+
_ArrayLikeObject_co,
29+
)
230

331
from numpy.core.shape_base import vstack
432

33+
_SCT = TypeVar("_SCT", bound=generic)
34+
35+
_ArrayLike = _NestedSequence[_SupportsDType[dtype[_SCT]]]
36+
37+
# The signatures of `__array_wrap__` and `__array_prepare__` are the same;
38+
# give them unique names for the sake of clarity
39+
class _ArrayWrap(Protocol):
40+
def __call__(
41+
self,
42+
__array: NDArray[Any],
43+
__context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
44+
) -> Any: ...
45+
46+
class _ArrayPrepare(Protocol):
47+
def __call__(
48+
self,
49+
__array: NDArray[Any],
50+
__context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
51+
) -> Any: ...
52+
53+
class _SupportsArrayWrap(Protocol):
54+
@property
55+
def __array_wrap__(self) -> _ArrayWrap: ...
56+
57+
class _SupportsArrayPrepare(Protocol):
58+
@property
59+
def __array_prepare__(self) -> _ArrayPrepare: ...
60+
561
__all__: List[str]
662

763
row_stack = vstack
864

9-
def take_along_axis(arr, indices, axis): ...
10-
def put_along_axis(arr, indices, values, axis): ...
11-
def apply_along_axis(func1d, axis, arr, *args, **kwargs): ...
12-
def apply_over_axes(func, a, axes): ...
13-
def expand_dims(a, axis): ...
14-
def column_stack(tup): ...
15-
def dstack(tup): ...
16-
def array_split(ary, indices_or_sections, axis=...): ...
17-
def split(ary, indices_or_sections, axis=...): ...
18-
def hsplit(ary, indices_or_sections): ...
19-
def vsplit(ary, indices_or_sections): ...
20-
def dsplit(ary, indices_or_sections): ...
21-
def get_array_prepare(*args): ...
22-
def get_array_wrap(*args): ...
23-
def kron(a, b): ...
24-
def tile(A, reps): ...
65+
def take_along_axis(
66+
arr: _SCT | NDArray[_SCT],
67+
indices: NDArray[integer[Any]],
68+
axis: None | int,
69+
) -> NDArray[_SCT]: ...
70+
71+
def put_along_axis(
72+
arr: NDArray[_SCT],
73+
indices: NDArray[integer[Any]],
74+
values: ArrayLike,
75+
axis: None | int,
76+
) -> None: ...
77+
78+
@overload
79+
def apply_along_axis(
80+
func1d: Callable[..., _ArrayLike[_SCT]],
81+
axis: SupportsIndex,
82+
arr: ArrayLike,
83+
*args: Any,
84+
**kwargs: Any,
85+
) -> NDArray[_SCT]: ...
86+
@overload
87+
def apply_along_axis(
88+
func1d: Callable[..., ArrayLike],
89+
axis: SupportsIndex,
90+
arr: ArrayLike,
91+
*args: Any,
92+
**kwargs: Any,
93+
) -> NDArray[Any]: ...
94+
95+
def apply_over_axes(
96+
func: Callable[[NDArray[Any], int], NDArray[_SCT]],
97+
a: ArrayLike,
98+
axes: int | Sequence[int],
99+
) -> NDArray[_SCT]: ...
100+
101+
@overload
102+
def expand_dims(
103+
a: _ArrayLike[_SCT],
104+
axis: _ShapeLike,
105+
) -> NDArray[_SCT]: ...
106+
@overload
107+
def expand_dims(
108+
a: ArrayLike,
109+
axis: _ShapeLike,
110+
) -> NDArray[Any]: ...
111+
112+
@overload
113+
def column_stack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
114+
@overload
115+
def column_stack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
116+
117+
@overload
118+
def dstack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
119+
@overload
120+
def dstack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
121+
122+
@overload
123+
def array_split(
124+
ary: _ArrayLike[_SCT],
125+
indices_or_sections: _ShapeLike,
126+
axis: SupportsIndex = ...,
127+
) -> List[NDArray[_SCT]]: ...
128+
@overload
129+
def array_split(
130+
ary: ArrayLike,
131+
indices_or_sections: _ShapeLike,
132+
axis: SupportsIndex = ...,
133+
) -> List[NDArray[Any]]: ...
134+
135+
@overload
136+
def split(
137+
ary: _ArrayLike[_SCT],
138+
indices_or_sections: _ShapeLike,
139+
axis: SupportsIndex = ...,
140+
) -> List[NDArray[_SCT]]: ...
141+
@overload
142+
def split(
143+
ary: ArrayLike,
144+
indices_or_sections: _ShapeLike,
145+
axis: SupportsIndex = ...,
146+
) -> List[NDArray[Any]]: ...
147+
148+
@overload
149+
def hsplit(
150+
ary: _ArrayLike[_SCT],
151+
indices_or_sections: _ShapeLike,
152+
) -> List[NDArray[_SCT]]: ...
153+
@overload
154+
def hsplit(
155+
ary: ArrayLike,
156+
indices_or_sections: _ShapeLike,
157+
) -> List[NDArray[Any]]: ...
158+
159+
@overload
160+
def vsplit(
161+
ary: _ArrayLike[_SCT],
162+
indices_or_sections: _ShapeLike,
163+
) -> List[NDArray[_SCT]]: ...
164+
@overload
165+
def vsplit(
166+
ary: ArrayLike,
167+
indices_or_sections: _ShapeLike,
168+
) -> List[NDArray[Any]]: ...
169+
170+
@overload
171+
def dsplit(
172+
ary: _ArrayLike[_SCT],
173+
indices_or_sections: _ShapeLike,
174+
) -> List[NDArray[_SCT]]: ...
175+
@overload
176+
def dsplit(
177+
ary: ArrayLike,
178+
indices_or_sections: _ShapeLike,
179+
) -> List[NDArray[Any]]: ...
180+
181+
@overload
182+
def get_array_prepare(*args: _SupportsArrayPrepare) -> _ArrayPrepare: ...
183+
@overload
184+
def get_array_prepare(*args: object) -> None | _ArrayPrepare: ...
185+
186+
@overload
187+
def get_array_wrap(*args: _SupportsArrayWrap) -> _ArrayWrap: ...
188+
@overload
189+
def get_array_wrap(*args: object) -> None | _ArrayWrap: ...
190+
191+
@overload
192+
def kron(a: _ArrayLikeBool_co, b: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
193+
@overload
194+
def kron(a: _ArrayLikeUInt_co, b: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[Any]]: ... # type: ignore[misc]
195+
@overload
196+
def kron(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co) -> NDArray[signedinteger[Any]]: ... # type: ignore[misc]
197+
@overload
198+
def kron(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... # type: ignore[misc]
199+
@overload
200+
def kron(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
201+
@overload
202+
def kron(a: _ArrayLikeObject_co, b: Any) -> NDArray[object_]: ...
203+
@overload
204+
def kron(a: Any, b: _ArrayLikeObject_co) -> NDArray[object_]: ...
205+
206+
@overload
207+
def tile(
208+
A: _ArrayLike[_SCT],
209+
reps: int | Sequence[int],
210+
) -> NDArray[_SCT]: ...
211+
@overload
212+
def tile(
213+
A: ArrayLike,
214+
reps: int | Sequence[int],
215+
) -> NDArray[Any]: ...

0 commit comments

Comments
 (0)