Skip to content

Commit 0440f65

Browse files
committed
🥅 np.linalg workarounds for Any shape overload bugs
1 parent 018d5e6 commit 0440f65

File tree

1 file changed

+51
-30
lines changed

1 file changed

+51
-30
lines changed

‎src/numpy-stubs/linalg/_linalg.pyi

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,10 @@ _Toinexact64l_1nd: TypeAlias = _nt._ToArray_1nd[_nt.inexact64l]
122122
_Toinexact64l_2ds: TypeAlias = _nt._ToArray_2ds[_nt.inexact64l]
123123
_Toinexact64l_3nd: TypeAlias = _nt._ToArray_3nd[_nt.inexact64l]
124124

125-
_ToUnsafe64_1nd: TypeAlias = _nt._ToArray2_1nd[
126-
_nt.inexact64 | _nt.co_integer | np.character[Any], complex | _nt._PyCharacter
127-
]
128-
_ToUnsafe64_2ds: TypeAlias = _nt._ToArray2_2ds[
129-
_nt.inexact64 | _nt.co_integer | np.character[Any], complex | _nt._PyCharacter
130-
]
131-
_ToUnsafe64_3nd: TypeAlias = _nt._ToArray2_3nd[
132-
_nt.inexact64 | _nt.co_integer | np.character[Any], complex | _nt._PyCharacter
133-
]
125+
_Unsafe64: TypeAlias = _nt.inexact64 | _nt.co_integer | np.character[Any]
126+
_ToUnsafe64_1nd: TypeAlias = _nt._ToArray2_1nd[_Unsafe64, complex | _nt._PyCharacter]
127+
_ToUnsafe64_2ds: TypeAlias = _nt._ToArray2_2ds[_Unsafe64, complex | _nt._PyCharacter]
128+
_ToUnsafe64_3nd: TypeAlias = _nt._ToArray2_3nd[_Unsafe64, complex | _nt._PyCharacter]
134129

135130
_Array2ND: TypeAlias = _nt.Array[_ScalarT, _nt.Shape2N]
136131

@@ -324,6 +319,10 @@ _PosInt: TypeAlias = L[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
324319
_NegInt: TypeAlias = L[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16]
325320

326321
#
322+
@overload # workaround for microsoft/pyright#10232
323+
def matrix_power(a: _nt.CastsArray[np.float64, _nt.NeitherShape], n: CanIndex) -> _Array2ND[np.float64]: ...
324+
@overload # workaround for microsoft/pyright#10232
325+
def matrix_power(a: _nt.CastsWithArray[np.float64, _NumberT, _nt.NeitherShape], n: CanIndex) -> _Array2ND[_NumberT]: ...
327326
@overload
328327
def matrix_power(a: _nt.CanLenArray[_NumberT, _Shape2NDT], n: _PosInt) -> _nt.Array[_NumberT, _Shape2NDT]: ...
329328
@overload
@@ -392,6 +391,8 @@ def outer(x1: _nt.CoInteger_1d, x2: _nt.ToInteger_1d, /) -> _nt.Array2D[np.integ
392391
def outer(x1: _nt.ToNumber_1d, x2: _nt.ToNumber_1d, /) -> _nt.Array2D[Any]: ...
393392

394393
#
394+
@overload # workaround for microsoft/pyright#10232
395+
def multi_dot(arrays: Iterable[_nt._ToArray_nnd[_nt.co_number]], *, out: None = None) -> Any: ...
395396
@overload
396397
def multi_dot(arrays: Iterable[_nt._ToArray_1ds[_AnyNumberT]], *, out: None = None) -> _AnyNumberT: ...
397398
@overload
@@ -454,6 +455,8 @@ def cross(x1: _nt.CoComplex_1nd, x2: _nt.ToComplex_1nd, /, *, axis: int = -1) ->
454455
def cross(x1: _nt.CoComplex_1nd, x2: _nt.CoComplex_1nd, /, *, axis: int = -1) -> _nt.Array[Any]: ...
455456

456457
# pyright false positive in case of typevar constraints
458+
@overload # workaround for microsoft/pyright#10232
459+
def matmul(x1: _nt._ToArray_nnd[_nt.co_number], x2: _nt._ToArray_nnd[_nt.co_number], /) -> Any: ...
457460
@overload
458461
def matmul(x1: _nt._ToArray_1ds[_AnyNumberT], x2: _nt._ToArray_1ds[_AnyNumberT], /) -> _AnyNumberT: ... # pyright: ignore[reportOverlappingOverload]
459462
@overload
@@ -640,6 +643,14 @@ def svdvals(x: _Toinexact32_1nd, /) -> _nt.Array[np.float32]: ...
640643
def svdvals(x: _nt.CoComplex128_1nd, /) -> _nt.Array[np.floating]: ...
641644

642645
#
646+
@overload # workaround for microsoft/pyright#10232
647+
def matrix_rank(
648+
A: _nt._ToArray_nnd[_nt.co_complex128],
649+
tol: _nt.ToFloating_nd | None = None,
650+
hermitian: bool = False,
651+
*,
652+
rtol: _nt.ToFloating_nd | None = None,
653+
) -> Any: ...
643654
@overload # <2d +complex128
644655
def matrix_rank(
645656
A: _nt.CoComplex128_0d | _nt.CoComplex128_1ds,
@@ -674,6 +685,8 @@ def matrix_rank(
674685
) -> Any: ...
675686

676687
#
688+
@overload # workaround for microsoft/pyright#10232
689+
def cond(x: _nt._ToArray_nnd[_nt.co_complex128], p: _Ord | None = None) -> Any: ...
677690
@overload # 2d float64 | complex128
678691
def cond(x: _Toinexact64_2ds, p: _Ord | None = None) -> np.float64: ...
679692
@overload # 2d float32 | complex64
@@ -690,6 +703,8 @@ def cond(x: _nt.CoComplex128_3nd, p: _Ord | None = None) -> _nt.Array[np.floatin
690703
def cond(x: _nt.CoComplex128_1nd, p: _Ord | None = None) -> Any: ...
691704

692705
# keep in sync with `det`
706+
@overload # # workaround for microsoft/pyright#10232
707+
def slogdet(a: _nt._ToArray_nnd[_nt.co_complex128]) -> SlogdetResult: ...
693708
@overload # 2d float64
694709
def slogdet(a: _ToFloat64_2ds) -> SlogdetResult[np.float64, np.float64]: ...
695710
@overload # 2d float32 + complex64
@@ -703,9 +718,11 @@ def slogdet(a: _nt._ToArray_3nd[_Inexact32T]) -> SlogdetResult[_nt.Array[np.floa
703718
@overload # >2d complex128
704719
def slogdet(a: _nt.ToComplex128_3nd) -> SlogdetResult[_nt.Array[np.float64], _nt.Array[np.complex128]]: ...
705720
@overload # +complex128
706-
def slogdet(a: _nt.CoComplex128_1nd) -> SlogdetResult[Any, Any]: ...
721+
def slogdet(a: _nt.CoComplex128_1nd) -> SlogdetResult: ...
707722

708723
#
724+
@overload # workaround for microsoft/pyright#10232
725+
def det(a: _nt._ToArray_nnd[_nt.co_complex128]) -> Any: ...
709726
@overload # 2d float64
710727
def det(a: _ToFloat64_2ds) -> np.float64: ...
711728
@overload # 2d float32 + complex64
@@ -763,9 +780,9 @@ def norm(
763780
x: _ToUnsafe64_1nd, ord: _Ord | None = None, axis: _Ax2 | None = None, *, keepdims: _True
764781
) -> _Array2ND[np.float64]: ...
765782
@overload # float64 | complex128 | character, axis=<given> (positional)
766-
def norm(x: _ToUnsafe64_1nd, ord: _Ord | None, axis: _Ax2, keepdims: bool = False) -> _nt.Array[np.float64]: ... # type: ignore[overload-overlap]
783+
def norm(x: _ToUnsafe64_1nd, ord: _Ord | None, axis: _Ax2, keepdims: bool = False) -> _nt.Array[np.float64]: ...
767784
@overload # float64 | complex128 | character, axis=<given> (keyword)
768-
def norm( # type: ignore[overload-overlap]
785+
def norm(
769786
x: _ToUnsafe64_1nd, ord: _Ord | None = None, *, axis: _Ax2, keepdims: bool = False
770787
) -> _nt.Array[np.float64]: ...
771788
@overload # float16, axis=None, keepdims=False
@@ -820,34 +837,38 @@ def norm(
820837
def norm(x: _nt.CoComplex_1nd, ord: _Ord | None = None, axis: _Ax2 | None = None, keepdims: bool = False) -> Any: ...
821838

822839
#
823-
@overload # 2d float64 | complex128 | character
824-
def matrix_norm(x: _ToUnsafe64_2ds, /, *, keepdims: bool = False, ord: _Ord = "fro") -> np.float64: ... # type: ignore[overload-overlap]
825-
@overload # nd float64 | complex128 | character, keepdims=True
826-
def matrix_norm(x: _ToUnsafe64_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.float64]: ...
827-
@overload # >2d float64 | complex128 | character
828-
def matrix_norm(x: _ToUnsafe64_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.float64]: ...
840+
@overload # workaround for microsoft/pyright#10232
841+
def matrix_norm(
842+
x: _nt._ToArray_nnd[_nt.co_number | np.character[Any]], /, *, keepdims: bool = False, ord: _Ord = "fro"
843+
) -> Any: ...
829844
@overload # 2d float16
830845
def matrix_norm(x: _nt.ToFloat16_2ds, /, *, keepdims: bool = False, ord: _Ord = "fro") -> np.float16: ... # type: ignore[overload-overlap]
831-
@overload # nd float16, keepdims=True
832-
def matrix_norm(x: _nt.ToFloat16_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.float16]: ...
833-
@overload # >2d float16
834-
def matrix_norm(x: _nt.ToFloat16_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.float16]: ...
835846
@overload # 2d float32 | complex64, keepdims=True
836847
def matrix_norm(x: _Toinexact32_2ds, /, *, keepdims: bool = False, ord: _Ord = "fro") -> np.float32: ... # type: ignore[overload-overlap]
837-
@overload # nd float32 | complex64, keepdims=True
838-
def matrix_norm(x: _Toinexact32_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.float32]: ...
839-
@overload # >2d float32 | complex64
840-
def matrix_norm(x: _Toinexact32_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.float32]: ...
848+
@overload # 2d float64 | complex128 | character
849+
def matrix_norm(x: _ToUnsafe64_2ds, /, *, keepdims: bool = False, ord: _Ord = "fro") -> np.float64: ... # type: ignore[overload-overlap]
841850
@overload # 2d longdouble | clongdouble
842851
def matrix_norm(x: _Toinexact64l_2ds, /, *, keepdims: bool = False, ord: _Ord = "fro") -> np.longdouble: ... # type: ignore[overload-overlap]
843-
@overload # nd longdouble | clongdouble, keepdims=True
844-
def matrix_norm(x: _Toinexact64l_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.longdouble]: ...
845-
@overload # >2d longdouble | clongdouble
846-
def matrix_norm(x: _Toinexact64l_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.longdouble]: ...
847852
@overload # 2d +number
848853
def matrix_norm(x: _nt.CoComplex_2ds, /, *, keepdims: bool = False, ord: _Ord = "fro") -> np.floating: ... # type: ignore[overload-overlap]
854+
@overload # nd float16, keepdims=True
855+
def matrix_norm(x: _nt.ToFloat16_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.float16]: ...
856+
@overload # nd float32 | complex64, keepdims=True
857+
def matrix_norm(x: _Toinexact32_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.float32]: ...
858+
@overload # nd float64 | complex128 | character, keepdims=True
859+
def matrix_norm(x: _ToUnsafe64_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.float64]: ...
860+
@overload # nd longdouble | clongdouble, keepdims=True
861+
def matrix_norm(x: _Toinexact64l_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.longdouble]: ...
849862
@overload # nd +number, keepdims=True
850863
def matrix_norm(x: _nt.CoComplex_1nd, /, *, keepdims: _True, ord: _Ord = "fro") -> _Array2ND[np.floating]: ...
864+
@overload # >2d float16
865+
def matrix_norm(x: _nt.ToFloat16_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.float16]: ...
866+
@overload # >2d float32 | complex64
867+
def matrix_norm(x: _Toinexact32_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.float32]: ...
868+
@overload # >2d float64 | complex128 | character
869+
def matrix_norm(x: _ToUnsafe64_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.float64]: ...
870+
@overload # >2d longdouble | clongdouble
871+
def matrix_norm(x: _Toinexact64l_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.longdouble]: ...
851872
@overload # >2d +number
852873
def matrix_norm(x: _nt.CoComplex_3nd, /, *, keepdims: bool = False, ord: _Ord = "fro") -> _nt.Array[np.floating]: ...
853874
@overload # nd +number

0 commit comments

Comments
 (0)