Skip to content

Commit 9dda370

Browse files
committed
🏷️ fix stubtest errors in numpy.linalg._linalg
1 parent 0eba466 commit 9dda370

File tree

2 files changed

+52
-16
lines changed

2 files changed

+52
-16
lines changed

.mypyignore-todo

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,6 @@ numpy.lib.format.read_array(_header_(1|2)_0)?
4444
numpy.lib.mixins.NDArrayOperatorsMixin.__array_ufunc__
4545
numpy.lib.recfunctions.unstructured_to_structured
4646

47-
numpy.linalg(._linalg)?.cholesky
48-
numpy.linalg(._linalg)?.pinv
49-
numpy.linalg(._linalg)?.tensordot
5047
numpy.linalg.lapack_lite
5148
numpy.linalg.linalg
5249

src/numpy-stubs/linalg/_linalg.pyi

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ from _numtype import (
3232
CoInteger_1d,
3333
CoInteger_1ds,
3434
CoInteger_1nd,
35+
CoSInteger_1nd,
3536
CoTimeDelta_1d,
37+
CoTimeDelta_1nd,
38+
CoUInteger_1nd,
3639
Is,
3740
Sequence_2d,
3841
Sequence_2nd,
@@ -84,6 +87,7 @@ from _numtype import (
8487
ToObject_1d,
8588
ToObject_1nd,
8689
ToObject_2nd,
90+
ToSInteger_1nd,
8791
ToTimeDelta_1d,
8892
ToTimeDelta_1nd,
8993
ToUInteger_1nd,
@@ -101,7 +105,8 @@ from _numtype import (
101105
_ToArray_3nd,
102106
)
103107
from numpy._core.fromnumeric import matrix_transpose
104-
from numpy._core.numeric import tensordot, vecdot
108+
from numpy._core.numeric import vecdot
109+
from numpy._globals import _NoValueType
105110
from numpy._typing import DTypeLike, _32Bit, _64Bit, _DTypeLike as _ToDType
106111

107112
__all__ = [
@@ -245,6 +250,40 @@ class SlogdetResult(NamedTuple, Generic[_FloatingNDT_co, _InexactNDT_co]):
245250

246251
class LinAlgError(ValueError): ...
247252

253+
# keep in sync with `numpy._core.numeric.tensordot`
254+
@overload
255+
def tensordot(x1: ToBool_1nd, x2: ToBool_1nd, /, *, axes: _Ax2 = 2) -> Array[np.bool]: ...
256+
@overload
257+
def tensordot(x1: ToUInteger_1nd, x2: CoUInteger_1nd, /, *, axes: _Ax2 = 2) -> Array[np.unsignedinteger]: ...
258+
@overload
259+
def tensordot(x1: CoUInteger_1nd, x2: ToUInteger_1nd, /, *, axes: _Ax2 = 2) -> Array[np.unsignedinteger]: ...
260+
@overload
261+
def tensordot(x1: ToSInteger_1nd, x2: CoSInteger_1nd, /, *, axes: _Ax2 = 2) -> Array[np.signedinteger]: ...
262+
@overload
263+
def tensordot(x1: CoSInteger_1nd, x2: ToSInteger_1nd, /, *, axes: _Ax2 = 2) -> Array[np.signedinteger]: ...
264+
@overload
265+
def tensordot(x1: ToFloating_1nd, x2: CoFloating_1nd, /, *, axes: _Ax2 = 2) -> Array[np.floating]: ...
266+
@overload
267+
def tensordot(x1: CoFloating_1nd, x2: ToFloating_1nd, /, *, axes: _Ax2 = 2) -> Array[np.floating]: ...
268+
@overload
269+
def tensordot(x1: ToComplex_1nd, x2: CoComplex_1nd, /, *, axes: _Ax2 = 2) -> Array[np.complexfloating]: ...
270+
@overload
271+
def tensordot(x1: CoComplex_1nd, x2: ToComplex_1nd, /, *, axes: _Ax2 = 2) -> Array[np.complexfloating]: ...
272+
@overload
273+
def tensordot(x1: ToTimeDelta_1nd, x2: CoTimeDelta_1nd, /, *, axes: _Ax2 = 2) -> Array[np.timedelta64]: ...
274+
@overload
275+
def tensordot(x1: CoTimeDelta_1nd, x2: ToTimeDelta_1nd, /, *, axes: _Ax2 = 2) -> Array[np.timedelta64]: ...
276+
@overload
277+
def tensordot(x1: ToObject_1nd, x2: ToObject_1nd, /, *, axes: _Ax2 = 2) -> Array[np.object_]: ...
278+
@overload
279+
def tensordot(
280+
x1: CoComplex_1nd | CoTimeDelta_1nd | ToObject_1nd,
281+
x2: CoComplex_1nd | CoTimeDelta_1nd | ToObject_1nd,
282+
/,
283+
*,
284+
axes: _Ax2 = 2,
285+
) -> Array[Any]: ...
286+
248287
# keep in sync with `solve`
249288
@overload
250289
def tensorsolve(a: _ToFloat64_1nd, b: CoFloat64_1nd, axes: _Axes | None = None) -> Array[np.float64]: ...
@@ -316,47 +355,47 @@ def pinv(
316355
rcond: ToFloating_nd | None = None,
317356
hermitian: bool = False,
318357
*,
319-
rtol: ToFloating_nd | None = None,
358+
rtol: ToFloating_nd | _NoValueType = ...,
320359
) -> _Array_2nd[np.float64]: ...
321360
@overload
322361
def pinv(
323362
a: ToComplex128_1nd,
324363
rcond: ToFloating_nd | None = None,
325364
hermitian: bool = False,
326365
*,
327-
rtol: ToFloating_nd | None = None,
366+
rtol: ToFloating_nd | _NoValueType = ...,
328367
) -> _Array_2nd[np.complex128]: ...
329368
@overload
330369
def pinv(
331370
a: ToFloat32_1nd,
332371
rcond: ToFloating_nd | None = None,
333372
hermitian: bool = False,
334373
*,
335-
rtol: ToFloating_nd | None = None,
374+
rtol: ToFloating_nd | _NoValueType = ...,
336375
) -> _Array_2nd[np.float32]: ...
337376
@overload
338377
def pinv(
339378
a: ToComplex64_1nd,
340379
rcond: ToFloating_nd | None = None,
341380
hermitian: bool = False,
342381
*,
343-
rtol: ToFloating_nd | None = None,
382+
rtol: ToFloating_nd | _NoValueType = ...,
344383
) -> _Array_2nd[np.complex64]: ...
345384
@overload
346385
def pinv(
347386
a: CoFloat64_1nd,
348387
rcond: ToFloating_nd | None = None,
349388
hermitian: bool = False,
350389
*,
351-
rtol: ToFloating_nd | None = None,
390+
rtol: ToFloating_nd | _NoValueType = ...,
352391
) -> _Array_2nd[np.floating]: ...
353392
@overload
354393
def pinv(
355394
a: CoComplex128_1nd,
356395
rcond: ToFloating_nd | None = None,
357396
hermitian: bool = False,
358397
*,
359-
rtol: ToFloating_nd | None = None,
398+
rtol: ToFloating_nd | _NoValueType = ...,
360399
) -> _Array_2nd[np.inexact]: ...
361400

362401
_PosInt: TypeAlias = L[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
@@ -392,17 +431,17 @@ def matrix_power(a: CoComplex_1nd | ToObject_1nd, n: CanIndex) -> _Array_2nd[Any
392431

393432
#
394433
@overload
395-
def cholesky(a: _ToFloat64_1nd) -> _Array_2nd[np.float64]: ...
434+
def cholesky(a: _ToFloat64_1nd, /, *, upper: bool = False) -> _Array_2nd[np.float64]: ...
396435
@overload
397-
def cholesky(a: ToComplex128_1nd) -> _Array_2nd[np.complex128]: ...
436+
def cholesky(a: ToComplex128_1nd, /, *, upper: bool = False) -> _Array_2nd[np.complex128]: ...
398437
@overload
399-
def cholesky(a: ToFloat32_1nd) -> _Array_2nd[np.float32]: ...
438+
def cholesky(a: ToFloat32_1nd, /, *, upper: bool = False) -> _Array_2nd[np.float32]: ...
400439
@overload
401-
def cholesky(a: ToComplex64_1nd) -> _Array_2nd[np.complex64]: ...
440+
def cholesky(a: ToComplex64_1nd, /, *, upper: bool = False) -> _Array_2nd[np.complex64]: ...
402441
@overload
403-
def cholesky(a: CoFloat64_1nd) -> _Array_2nd[np.floating]: ...
442+
def cholesky(a: CoFloat64_1nd, /, *, upper: bool = False) -> _Array_2nd[np.floating]: ...
404443
@overload
405-
def cholesky(a: CoComplex128_1nd) -> _Array_2nd[np.inexact]: ...
444+
def cholesky(a: CoComplex128_1nd, /, *, upper: bool = False) -> _Array_2nd[np.inexact]: ...
406445

407446
#
408447
@overload

0 commit comments

Comments
 (0)