Skip to content

Commit 4c3ceaa

Browse files
committed
TYP: Concrete complex128 scalar type with builtins.complex as a base class
1 parent 1bcb9f3 commit 4c3ceaa

File tree

5 files changed

+126
-43
lines changed

5 files changed

+126
-43
lines changed

numpy/__init__.pyi

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2845,6 +2845,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
28452845
# operand. An exception to this rule are unsigned integers though, which
28462846
# also accepts a signed integer for the right operand as long it is a 0D
28472847
# object and its value is >= 0
2848+
# NOTE: Due to a mypy bug, overloading on e.g. `self: NDArray[SCT_floating]` won't
2849+
# work, as this will lead to `false negatives` when using these inplace ops.
28482850
@overload
28492851
def __iadd__(self: NDArray[_UnknownType], other: _ArrayLikeUnknown, /) -> NDArray[Any]: ...
28502852
@overload
@@ -2858,6 +2860,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
28582860
@overload
28592861
def __iadd__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co, /) -> NDArray[floating[_NBit1]]: ...
28602862
@overload
2863+
def __iadd__(self: NDArray[complex128], other: _ArrayLikeComplex_co, /) -> NDArray[complex128]: ...
2864+
@overload
28612865
def __iadd__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co, /) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
28622866
@overload
28632867
def __iadd__(self: NDArray[timedelta64], other: _ArrayLikeTD64_co, /) -> NDArray[timedelta64]: ...
@@ -2877,6 +2881,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
28772881
@overload
28782882
def __isub__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co, /) -> NDArray[floating[_NBit1]]: ...
28792883
@overload
2884+
def __isub__(self: NDArray[complex128], other: _ArrayLikeComplex_co, /) -> NDArray[complex128]: ...
2885+
@overload
28802886
def __isub__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co, /) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
28812887
@overload
28822888
def __isub__(self: NDArray[timedelta64], other: _ArrayLikeTD64_co, /) -> NDArray[timedelta64]: ...
@@ -2898,6 +2904,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
28982904
@overload
28992905
def __imul__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co, /) -> NDArray[floating[_NBit1]]: ...
29002906
@overload
2907+
def __imul__(self: NDArray[complex128], other: _ArrayLikeComplex_co, /) -> NDArray[complex128]: ...
2908+
@overload
29012909
def __imul__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co, /) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
29022910
@overload
29032911
def __imul__(self: NDArray[timedelta64], other: _ArrayLikeFloat_co, /) -> NDArray[timedelta64]: ...
@@ -2911,6 +2919,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
29112919
@overload
29122920
def __itruediv__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co, /) -> NDArray[floating[_NBit1]]: ...
29132921
@overload
2922+
def __itruediv__(self: NDArray[complex128], other: _ArrayLikeComplex_co, /) -> NDArray[complex128]: ...
2923+
@overload
29142924
def __itruediv__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co, /) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
29152925
@overload
29162926
def __itruediv__(self: NDArray[timedelta64], other: _ArrayLikeBool_co, /) -> NoReturn: ...
@@ -2930,6 +2940,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
29302940
@overload
29312941
def __ifloordiv__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co, /) -> NDArray[floating[_NBit1]]: ...
29322942
@overload
2943+
def __ifloordiv__(self: NDArray[complex128], other: _ArrayLikeComplex_co, /) -> NDArray[complex128]: ...
2944+
@overload
29332945
def __ifloordiv__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co, /) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
29342946
@overload
29352947
def __ifloordiv__(self: NDArray[timedelta64], other: _ArrayLikeBool_co, /) -> NoReturn: ...
@@ -2949,6 +2961,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
29492961
@overload
29502962
def __ipow__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co, /) -> NDArray[floating[_NBit1]]: ...
29512963
@overload
2964+
def __ipow__(self: NDArray[complex128], other: _ArrayLikeComplex_co, /) -> NDArray[complex128]: ...
2965+
@overload
29522966
def __ipow__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co, /) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
29532967
@overload
29542968
def __ipow__(self: NDArray[object_], other: Any, /) -> NDArray[object_]: ...
@@ -3032,6 +3046,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
30323046
@overload
30333047
def __imatmul__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co, /) -> NDArray[floating[_NBit1]]: ...
30343048
@overload
3049+
def __imatmul__(self: NDArray[complex128], other: _ArrayLikeComplex_co, /) -> NDArray[complex128]: ...
3050+
@overload
30353051
def __imatmul__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co, /) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
30363052
@overload
30373053
def __imatmul__(self: NDArray[object_], other: Any, /) -> NDArray[object_]: ...
@@ -3534,8 +3550,7 @@ uint: TypeAlias = uintp
35343550
ulong: TypeAlias = unsignedinteger[_NBitLong]
35353551
ulonglong: TypeAlias = unsignedinteger[_NBitLongLong]
35363552

3537-
class inexact(number[_NBit1]): # type: ignore
3538-
def __getnewargs__(self: inexact[_64Bit]) -> tuple[float, ...]: ...
3553+
class inexact(number[_NBit1]): ... # type: ignore[misc]
35393554

35403555
_IntType = TypeVar("_IntType", bound=integer[Any])
35413556

@@ -3571,6 +3586,7 @@ float32: TypeAlias = floating[_32Bit]
35713586

35723587
# NOTE: `_64Bit` is equivalent to `_64Bit | _32Bit | _16Bit | _8Bit`
35733588
_Float64_co: TypeAlias = float | floating[_64Bit] | integer[_64Bit] | np.bool
3589+
_Complex128_co: TypeAlias = complex | complexfloating[_64Bit, _64Bit] | _Float64_co
35743590

35753591
# either a C `double`, `float`, or `longdouble`
35763592
class float64(floating[_64Bit], float): # type: ignore[misc]
@@ -3588,77 +3604,101 @@ class float64(floating[_64Bit], float): # type: ignore[misc]
35883604
@overload
35893605
def __add__(self, other: _Float64_co, /) -> float64: ...
35903606
@overload
3607+
def __add__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3608+
@overload
35913609
def __add__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
35923610
@overload
35933611
def __add__(self, other: complex, /) -> float64 | complex128: ...
35943612
@overload
35953613
def __radd__(self, other: _Float64_co, /) -> float64: ...
35963614
@overload
3615+
def __radd__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3616+
@overload
35973617
def __radd__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
35983618
@overload
35993619
def __radd__(self, other: complex, /) -> float64 | complex128: ...
36003620

36013621
@overload
36023622
def __sub__(self, other: _Float64_co, /) -> float64: ...
36033623
@overload
3624+
def __sub__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3625+
@overload
36043626
def __sub__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36053627
@overload
36063628
def __sub__(self, other: complex, /) -> float64 | complex128: ...
36073629
@overload
36083630
def __rsub__(self, other: _Float64_co, /) -> float64: ...
36093631
@overload
3632+
def __rsub__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3633+
@overload
36103634
def __rsub__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36113635
@overload
36123636
def __rsub__(self, other: complex, /) -> float64 | complex128: ...
36133637

36143638
@overload
36153639
def __mul__(self, other: _Float64_co, /) -> float64: ...
36163640
@overload
3641+
def __mul__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3642+
@overload
36173643
def __mul__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36183644
@overload
36193645
def __mul__(self, other: complex, /) -> float64 | complex128: ...
36203646
@overload
36213647
def __rmul__(self, other: _Float64_co, /) -> float64: ...
36223648
@overload
3649+
def __rmul__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3650+
@overload
36233651
def __rmul__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36243652
@overload
36253653
def __rmul__(self, other: complex, /) -> float64 | complex128: ...
36263654

36273655
@overload
36283656
def __truediv__(self, other: _Float64_co, /) -> float64: ...
36293657
@overload
3658+
def __truediv__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3659+
@overload
36303660
def __truediv__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36313661
@overload
36323662
def __truediv__(self, other: complex, /) -> float64 | complex128: ...
36333663
@overload
36343664
def __rtruediv__(self, other: _Float64_co, /) -> float64: ...
36353665
@overload
3666+
def __rtruediv__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3667+
@overload
36363668
def __rtruediv__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36373669
@overload
36383670
def __rtruediv__(self, other: complex, /) -> float64 | complex128: ...
36393671

36403672
@overload
36413673
def __floordiv__(self, other: _Float64_co, /) -> float64: ...
36423674
@overload
3675+
def __floordiv__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3676+
@overload
36433677
def __floordiv__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36443678
@overload
36453679
def __floordiv__(self, other: complex, /) -> float64 | complex128: ...
36463680
@overload
36473681
def __rfloordiv__(self, other: _Float64_co, /) -> float64: ...
36483682
@overload
3683+
def __rfloordiv__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3684+
@overload
36493685
def __rfloordiv__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36503686
@overload
36513687
def __rfloordiv__(self, other: complex, /) -> float64 | complex128: ...
36523688

36533689
@overload
36543690
def __pow__(self, other: _Float64_co, /) -> float64: ...
36553691
@overload
3692+
def __pow__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3693+
@overload
36563694
def __pow__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36573695
@overload
36583696
def __pow__(self, other: complex, /) -> float64 | complex128: ...
36593697
@overload
36603698
def __rpow__(self, other: _Float64_co, /) -> float64: ...
36613699
@overload
3700+
def __rpow__(self, other: complexfloating[_64Bit, _64Bit], /) -> complex128: ...
3701+
@overload
36623702
def __rpow__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
36633703
@overload
36643704
def __rpow__(self, other: complex, /) -> float64 | complex128: ...
@@ -3681,16 +3721,13 @@ longdouble: TypeAlias = floating[_NBitLongDouble]
36813721

36823722
class complexfloating(inexact[_NBit1], Generic[_NBit1, _NBit2]):
36833723
def __init__(self, value: _ComplexValue = ..., /) -> None: ...
3684-
def item(
3685-
self, args: L[0] | tuple[()] | tuple[L[0]] = ..., /,
3686-
) -> complex: ...
3724+
def item(self, args: L[0] | tuple[()] | tuple[L[0]] = ..., /) -> complex: ...
36873725
def tolist(self) -> complex: ...
36883726
@property
36893727
def real(self) -> floating[_NBit1]: ... # type: ignore[override]
36903728
@property
36913729
def imag(self) -> floating[_NBit2]: ... # type: ignore[override]
36923730
def __abs__(self) -> floating[_NBit1]: ... # type: ignore[override]
3693-
def __getnewargs__(self: complex128) -> tuple[float, float]: ...
36943731
# NOTE: Deprecated
36953732
# def __round__(self, ndigits=...): ...
36963733
__add__: _ComplexOp[_NBit1]
@@ -3705,7 +3742,49 @@ class complexfloating(inexact[_NBit1], Generic[_NBit1, _NBit2]):
37053742
__rpow__: _ComplexOp[_NBit1]
37063743

37073744
complex64: TypeAlias = complexfloating[_32Bit, _32Bit]
3708-
complex128: TypeAlias = complexfloating[_64Bit, _64Bit]
3745+
3746+
class complex128(complexfloating[_64Bit, _64Bit], complex):
3747+
def __getnewargs__(self, /) -> tuple[float, float]: ...
3748+
3749+
# overrides for `floating` and `builtins.float` compatibility
3750+
@property
3751+
def real(self) -> float64: ...
3752+
@property
3753+
def imag(self) -> float64: ...
3754+
def __abs__(self) -> float64: ...
3755+
def conjugate(self) -> Self: ...
3756+
3757+
# complex128-specific operator overrides
3758+
@overload
3759+
def __add__(self, other: _Complex128_co, /) -> complex128: ...
3760+
@overload
3761+
def __add__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
3762+
def __radd__(self, other: _Complex128_co, /) -> complex128: ...
3763+
3764+
@overload
3765+
def __sub__(self, other: _Complex128_co, /) -> complex128: ...
3766+
@overload
3767+
def __sub__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
3768+
def __rsub__(self, other: _Complex128_co, /) -> complex128: ...
3769+
3770+
@overload
3771+
def __mul__(self, other: _Complex128_co, /) -> complex128: ...
3772+
@overload
3773+
def __mul__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
3774+
def __rmul__(self, other: _Complex128_co, /) -> complex128: ...
3775+
3776+
@overload
3777+
def __truediv__(self, other: _Complex128_co, /) -> complex128: ...
3778+
@overload
3779+
def __truediv__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
3780+
def __rtruediv__(self, other: _Complex128_co, /) -> complex128: ...
3781+
3782+
@overload
3783+
def __pow__(self, other: _Complex128_co, /) -> complex128: ...
3784+
@overload
3785+
def __pow__(self, other: complexfloating[_NBit1, _NBit2], /) -> complexfloating[_NBit1 | _64Bit, _NBit2 | _64Bit]: ...
3786+
def __rpow__(self, other: _Complex128_co, /) -> complex128: ...
3787+
37093788

37103789
csingle: TypeAlias = complexfloating[_NBitSingle, _NBitSingle]
37113790
cdouble: TypeAlias = complexfloating[_NBitDouble, _NBitDouble]

numpy/typing/tests/data/pass/arithmetic.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any
44
import numpy as np
5+
import numpy.typing as npt
56
import pytest
67

78
c16 = np.complex128(1)
@@ -57,14 +58,14 @@ def __rpow__(self, value: Any) -> Object:
5758
return self
5859

5960

60-
AR_b: np.ndarray[Any, np.dtype[np.bool]] = np.array([True])
61-
AR_u: np.ndarray[Any, np.dtype[np.uint32]] = np.array([1], dtype=np.uint32)
62-
AR_i: np.ndarray[Any, np.dtype[np.int64]] = np.array([1])
63-
AR_f: np.ndarray[Any, np.dtype[np.float64]] = np.array([1.0])
64-
AR_c: np.ndarray[Any, np.dtype[np.complex128]] = np.array([1j])
65-
AR_m: np.ndarray[Any, np.dtype[np.timedelta64]] = np.array([np.timedelta64(1, "D")])
66-
AR_M: np.ndarray[Any, np.dtype[np.datetime64]] = np.array([np.datetime64(1, "D")])
67-
AR_O: np.ndarray[Any, np.dtype[np.object_]] = np.array([Object()])
61+
AR_b: npt.NDArray[np.bool] = np.array([True])
62+
AR_u: npt.NDArray[np.uint32] = np.array([1], dtype=np.uint32)
63+
AR_i: npt.NDArray[np.int64] = np.array([1])
64+
AR_f: npt.NDArray[np.float64] = np.array([1.0])
65+
AR_c: npt.NDArray[np.complex128] = np.array([1j])
66+
AR_m: npt.NDArray[np.timedelta64] = np.array([np.timedelta64(1, "D")])
67+
AR_M: npt.NDArray[np.datetime64] = np.array([np.datetime64(1, "D")])
68+
AR_O: npt.NDArray[np.object_] = np.array([Object()])
6869

6970
AR_LIKE_b = [True]
7071
AR_LIKE_u = [np.uint32(1)]

0 commit comments

Comments
 (0)