Skip to content

Commit 19caf5b

Browse files
authored
Merge pull request numpy#19887 from BvB93/linalg
ENH: Add annotations for `np.linalg`
2 parents cf09bbf + f8958a2 commit 19caf5b

File tree

4 files changed

+450
-21
lines changed

4 files changed

+450
-21
lines changed

numpy/linalg/__init__.pyi

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,32 @@
11
from typing import Any, List
22

3+
from numpy.linalg.linalg import (
4+
matrix_power as matrix_power,
5+
solve as solve,
6+
tensorsolve as tensorsolve,
7+
tensorinv as tensorinv,
8+
inv as inv,
9+
cholesky as cholesky,
10+
eigvals as eigvals,
11+
eigvalsh as eigvalsh,
12+
pinv as pinv,
13+
slogdet as slogdet,
14+
det as det,
15+
svd as svd,
16+
eig as eig,
17+
eigh as eigh,
18+
lstsq as lstsq,
19+
norm as norm,
20+
qr as qr,
21+
cond as cond,
22+
matrix_rank as matrix_rank,
23+
multi_dot as multi_dot,
24+
)
25+
326
from numpy._pytesttester import PytestTester
427

528
__all__: List[str]
629
__path__: List[str]
730
test: PytestTester
831

932
class LinAlgError(Exception): ...
10-
11-
def tensorsolve(a, b, axes=...): ...
12-
def solve(a, b): ...
13-
def tensorinv(a, ind=...): ...
14-
def inv(a): ...
15-
def matrix_power(a, n): ...
16-
def cholesky(a): ...
17-
def qr(a, mode=...): ...
18-
def eigvals(a): ...
19-
def eigvalsh(a, UPLO=...): ...
20-
def eig(a): ...
21-
def eigh(a, UPLO=...): ...
22-
def svd(a, full_matrices=..., compute_uv=..., hermitian=...): ...
23-
def cond(x, p=...): ...
24-
def matrix_rank(A, tol=..., hermitian=...): ...
25-
def pinv(a, rcond=..., hermitian=...): ...
26-
def slogdet(a): ...
27-
def det(a): ...
28-
def lstsq(a, b, rcond=...): ...
29-
def norm(x, ord=..., axis=..., keepdims=...): ...
30-
def multi_dot(arrays, *, out=...): ...

numpy/linalg/linalg.pyi

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
from typing import (
2+
Literal as L,
3+
List,
4+
Iterable,
5+
overload,
6+
TypeVar,
7+
Any,
8+
SupportsIndex,
9+
SupportsInt,
10+
Tuple,
11+
)
12+
13+
from numpy import (
14+
generic,
15+
floating,
16+
complexfloating,
17+
int32,
18+
float64,
19+
complex128,
20+
)
21+
22+
from numpy.typing import (
23+
NDArray,
24+
ArrayLike,
25+
_ArrayLikeInt_co,
26+
_ArrayLikeFloat_co,
27+
_ArrayLikeComplex_co,
28+
_ArrayLikeTD64_co,
29+
_ArrayLikeObject_co,
30+
)
31+
32+
_T = TypeVar("_T")
33+
_ArrayType = TypeVar("_ArrayType", bound=NDArray[Any])
34+
35+
_2Tuple = Tuple[_T, _T]
36+
_ModeKind = L["reduced", "complete", "r", "raw"]
37+
38+
__all__: List[str]
39+
40+
@overload
41+
def tensorsolve(
42+
a: _ArrayLikeInt_co,
43+
b: _ArrayLikeInt_co,
44+
axes: None | Iterable[int] =...,
45+
) -> NDArray[float64]: ...
46+
@overload
47+
def tensorsolve(
48+
a: _ArrayLikeFloat_co,
49+
b: _ArrayLikeFloat_co,
50+
axes: None | Iterable[int] =...,
51+
) -> NDArray[floating[Any]]: ...
52+
@overload
53+
def tensorsolve(
54+
a: _ArrayLikeComplex_co,
55+
b: _ArrayLikeComplex_co,
56+
axes: None | Iterable[int] =...,
57+
) -> NDArray[complexfloating[Any, Any]]: ...
58+
59+
@overload
60+
def solve(
61+
a: _ArrayLikeInt_co,
62+
b: _ArrayLikeInt_co,
63+
) -> NDArray[float64]: ...
64+
@overload
65+
def solve(
66+
a: _ArrayLikeFloat_co,
67+
b: _ArrayLikeFloat_co,
68+
) -> NDArray[floating[Any]]: ...
69+
@overload
70+
def solve(
71+
a: _ArrayLikeComplex_co,
72+
b: _ArrayLikeComplex_co,
73+
) -> NDArray[complexfloating[Any, Any]]: ...
74+
75+
@overload
76+
def tensorinv(
77+
a: _ArrayLikeInt_co,
78+
ind: int = ...,
79+
) -> NDArray[float64]: ...
80+
@overload
81+
def tensorinv(
82+
a: _ArrayLikeFloat_co,
83+
ind: int = ...,
84+
) -> NDArray[floating[Any]]: ...
85+
@overload
86+
def tensorinv(
87+
a: _ArrayLikeComplex_co,
88+
ind: int = ...,
89+
) -> NDArray[complexfloating[Any, Any]]: ...
90+
91+
@overload
92+
def inv(a: _ArrayLikeInt_co) -> NDArray[float64]: ...
93+
@overload
94+
def inv(a: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ...
95+
@overload
96+
def inv(a: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
97+
98+
# TODO: The supported input and output dtypes are dependant on the value of `n`.
99+
# For example: `n < 0` always casts integer types to float64
100+
def matrix_power(
101+
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
102+
n: SupportsIndex,
103+
) -> NDArray[Any]: ...
104+
105+
@overload
106+
def cholesky(a: _ArrayLikeInt_co) -> NDArray[float64]: ...
107+
@overload
108+
def cholesky(a: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ...
109+
@overload
110+
def cholesky(a: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
111+
112+
@overload
113+
def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[float64]]: ...
114+
@overload
115+
def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[floating[Any]]]: ...
116+
@overload
117+
def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[complexfloating[Any, Any]]]: ...
118+
119+
@overload
120+
def eigvals(a: _ArrayLikeInt_co) -> NDArray[float64] | NDArray[complex128]: ...
121+
@overload
122+
def eigvals(a: _ArrayLikeFloat_co) -> NDArray[floating[Any]] | NDArray[complexfloating[Any, Any]]: ...
123+
@overload
124+
def eigvals(a: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
125+
126+
@overload
127+
def eigvalsh(a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[float64]: ...
128+
@overload
129+
def eigvalsh(a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[floating[Any]]: ...
130+
131+
@overload
132+
def eig(a: _ArrayLikeInt_co) -> _2Tuple[NDArray[float64]] | _2Tuple[NDArray[complex128]]: ...
133+
@overload
134+
def eig(a: _ArrayLikeFloat_co) -> _2Tuple[NDArray[floating[Any]]] | _2Tuple[NDArray[complexfloating[Any, Any]]]: ...
135+
@overload
136+
def eig(a: _ArrayLikeComplex_co) -> _2Tuple[NDArray[complexfloating[Any, Any]]]: ...
137+
138+
@overload
139+
def eigh(
140+
a: _ArrayLikeInt_co,
141+
UPLO: L["L", "U", "l", "u"] = ...,
142+
) -> Tuple[NDArray[float64], NDArray[float64]]: ...
143+
@overload
144+
def eigh(
145+
a: _ArrayLikeFloat_co,
146+
UPLO: L["L", "U", "l", "u"] = ...,
147+
) -> Tuple[NDArray[floating[Any]], NDArray[floating[Any]]]: ...
148+
@overload
149+
def eigh(
150+
a: _ArrayLikeComplex_co,
151+
UPLO: L["L", "U", "l", "u"] = ...,
152+
) -> Tuple[NDArray[floating[Any]], NDArray[complexfloating[Any, Any]]]: ...
153+
154+
@overload
155+
def svd(
156+
a: _ArrayLikeInt_co,
157+
full_matrices: bool = ...,
158+
compute_uv: L[True] = ...,
159+
hermitian: bool = ...,
160+
) -> Tuple[
161+
NDArray[float64],
162+
NDArray[float64],
163+
NDArray[float64],
164+
]: ...
165+
@overload
166+
def svd(
167+
a: _ArrayLikeFloat_co,
168+
full_matrices: bool = ...,
169+
compute_uv: L[True] = ...,
170+
hermitian: bool = ...,
171+
) -> Tuple[
172+
NDArray[floating[Any]],
173+
NDArray[floating[Any]],
174+
NDArray[floating[Any]],
175+
]: ...
176+
@overload
177+
def svd(
178+
a: _ArrayLikeComplex_co,
179+
full_matrices: bool = ...,
180+
compute_uv: L[True] = ...,
181+
hermitian: bool = ...,
182+
) -> Tuple[
183+
NDArray[complexfloating[Any, Any]],
184+
NDArray[floating[Any]],
185+
NDArray[complexfloating[Any, Any]],
186+
]: ...
187+
@overload
188+
def svd(
189+
a: _ArrayLikeInt_co,
190+
full_matrices: bool = ...,
191+
compute_uv: L[False] = ...,
192+
hermitian: bool = ...,
193+
) -> NDArray[float64]: ...
194+
@overload
195+
def svd(
196+
a: _ArrayLikeComplex_co,
197+
full_matrices: bool = ...,
198+
compute_uv: L[False] = ...,
199+
hermitian: bool = ...,
200+
) -> NDArray[floating[Any]]: ...
201+
202+
# TODO: Returns a scalar for 2D arrays and
203+
# a `(x.ndim - 2)`` dimensionl array otherwise
204+
def cond(x: _ArrayLikeComplex_co, p: None | float | L["fro", "nuc"] = ...) -> Any: ...
205+
206+
# TODO: Returns `int` for <2D arrays and `intp` otherwise
207+
def matrix_rank(
208+
A: _ArrayLikeComplex_co,
209+
tol: None | _ArrayLikeFloat_co = ...,
210+
hermitian: bool = ...,
211+
) -> Any: ...
212+
213+
@overload
214+
def pinv(
215+
a: _ArrayLikeInt_co,
216+
rcond: _ArrayLikeFloat_co = ...,
217+
hermitian: bool = ...,
218+
) -> NDArray[float64]: ...
219+
@overload
220+
def pinv(
221+
a: _ArrayLikeFloat_co,
222+
rcond: _ArrayLikeFloat_co = ...,
223+
hermitian: bool = ...,
224+
) -> NDArray[floating[Any]]: ...
225+
@overload
226+
def pinv(
227+
a: _ArrayLikeComplex_co,
228+
rcond: _ArrayLikeFloat_co = ...,
229+
hermitian: bool = ...,
230+
) -> NDArray[complexfloating[Any, Any]]: ...
231+
232+
# TODO: Returns a 2-tuple of scalars for 2D arrays and
233+
# a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
234+
def slogdet(a: _ArrayLikeComplex_co) -> _2Tuple[Any]: ...
235+
236+
# TODO: Returns a 2-tuple of scalars for 2D arrays and
237+
# a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
238+
def det(a: _ArrayLikeComplex_co) -> Any: ...
239+
240+
@overload
241+
def lstsq(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co, rcond: None | float = ...) -> Tuple[
242+
NDArray[float64],
243+
NDArray[float64],
244+
int32,
245+
NDArray[float64],
246+
]: ...
247+
@overload
248+
def lstsq(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co, rcond: None | float = ...) -> Tuple[
249+
NDArray[floating[Any]],
250+
NDArray[floating[Any]],
251+
int32,
252+
NDArray[floating[Any]],
253+
]: ...
254+
@overload
255+
def lstsq(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co, rcond: None | float = ...) -> Tuple[
256+
NDArray[complexfloating[Any, Any]],
257+
NDArray[floating[Any]],
258+
int32,
259+
NDArray[floating[Any]],
260+
]: ...
261+
262+
@overload
263+
def norm(
264+
x: ArrayLike,
265+
ord: None | float | L["fro", "nuc"] = ...,
266+
axis: None = ...,
267+
keepdims: bool = ...,
268+
) -> floating[Any]: ...
269+
@overload
270+
def norm(
271+
x: ArrayLike,
272+
ord: None | float | L["fro", "nuc"] = ...,
273+
axis: SupportsInt | SupportsIndex | Tuple[int, ...] = ...,
274+
keepdims: bool = ...,
275+
) -> Any: ...
276+
277+
# TODO: Returns a scalar or array
278+
def multi_dot(
279+
arrays: Iterable[_ArrayLikeComplex_co | _ArrayLikeObject_co | _ArrayLikeTD64_co],
280+
*,
281+
out: None | NDArray[Any] = ...,
282+
) -> Any: ...
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import numpy.typing as npt
3+
4+
AR_f8: npt.NDArray[np.float64]
5+
AR_O: npt.NDArray[np.object_]
6+
AR_M: npt.NDArray[np.datetime64]
7+
8+
np.linalg.tensorsolve(AR_O, AR_O) # E: incompatible type
9+
10+
np.linalg.solve(AR_O, AR_O) # E: incompatible type
11+
12+
np.linalg.tensorinv(AR_O) # E: incompatible type
13+
14+
np.linalg.inv(AR_O) # E: incompatible type
15+
16+
np.linalg.matrix_power(AR_M, 5) # E: incompatible type
17+
18+
np.linalg.cholesky(AR_O) # E: incompatible type
19+
20+
np.linalg.qr(AR_O) # E: incompatible type
21+
np.linalg.qr(AR_f8, mode="bob") # E: No overload variant
22+
23+
np.linalg.eigvals(AR_O) # E: incompatible type
24+
25+
np.linalg.eigvalsh(AR_O) # E: incompatible type
26+
np.linalg.eigvalsh(AR_O, UPLO="bob") # E: No overload variant
27+
28+
np.linalg.eig(AR_O) # E: incompatible type
29+
30+
np.linalg.eigh(AR_O) # E: incompatible type
31+
np.linalg.eigh(AR_O, UPLO="bob") # E: No overload variant
32+
33+
np.linalg.svd(AR_O) # E: incompatible type
34+
35+
np.linalg.cond(AR_O) # E: incompatible type
36+
np.linalg.cond(AR_f8, p="bob") # E: incompatible type
37+
38+
np.linalg.matrix_rank(AR_O) # E: incompatible type
39+
40+
np.linalg.pinv(AR_O) # E: incompatible type
41+
42+
np.linalg.slogdet(AR_O) # E: incompatible type
43+
44+
np.linalg.det(AR_O) # E: incompatible type
45+
46+
np.linalg.norm(AR_f8, ord="bob") # E: No overload variant
47+
48+
np.linalg.multi_dot([AR_M]) # E: incompatible type

0 commit comments

Comments
 (0)