Skip to content

Commit dfd1180

Browse files
authored
Merge pull request numpy#27659 from jorenham/typing/transparent-ndarray-ops
TYP: Transparent ``ndarray`` unary operator method signatures
2 parents ac39902 + a8525db commit dfd1180

File tree

2 files changed

+109
-53
lines changed

2 files changed

+109
-53
lines changed

numpy/__init__.pyi

Lines changed: 30 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,11 +1760,15 @@ _DType = TypeVar("_DType", bound=dtype[Any])
17601760
_DType_co = TypeVar("_DType_co", covariant=True, bound=dtype[Any])
17611761
_FlexDType = TypeVar("_FlexDType", bound=dtype[flexible])
17621762

1763+
_IntegralArrayT = TypeVar("_IntegralArrayT", bound=NDArray[integer[Any] | np.bool | object_])
1764+
_RealArrayT = TypeVar("_RealArrayT", bound=NDArray[floating[Any] | integer[Any] | timedelta64 | np.bool | object_])
1765+
_NumericArrayT = TypeVar("_NumericArrayT", bound=NDArray[number[Any] | timedelta64 | object_])
1766+
17631767
_Shape1D: TypeAlias = tuple[int]
17641768
_Shape2D: TypeAlias = tuple[int, int]
17651769

1770+
_ShapeType = TypeVar("_ShapeType", bound=_Shape)
17661771
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=_Shape)
1767-
_ShapeType2 = TypeVar("_ShapeType2", bound=_Shape)
17681772
_Shape2DType_co = TypeVar("_Shape2DType_co", covariant=True, bound=_Shape2D)
17691773
_NumberType = TypeVar("_NumberType", bound=number[Any])
17701774

@@ -1881,11 +1885,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
18811885

18821886
def __array_wrap__(
18831887
self,
1884-
array: ndarray[_ShapeType2, _DType],
1888+
array: ndarray[_ShapeType, _DType],
18851889
context: None | tuple[ufunc, tuple[Any, ...], int] = ...,
18861890
return_scalar: builtins.bool = ...,
18871891
/,
1888-
) -> ndarray[_ShapeType2, _DType]: ...
1892+
) -> ndarray[_ShapeType, _DType]: ...
18891893

18901894
@overload
18911895
def __getitem__(self, key: (
@@ -2237,22 +2241,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
22372241
offset: SupportsIndex = ...
22382242
) -> NDArray[Any]: ...
22392243

2240-
# Dispatch to the underlying `generic` via protocols
2241-
def __int__(
2242-
self: NDArray[SupportsInt], # type: ignore[type-var]
2243-
) -> int: ...
2244-
2245-
def __float__(
2246-
self: NDArray[SupportsFloat], # type: ignore[type-var]
2247-
) -> float: ...
2248-
2249-
def __complex__(
2250-
self: NDArray[SupportsComplex], # type: ignore[type-var]
2251-
) -> complex: ...
2252-
2253-
def __index__(
2254-
self: NDArray[SupportsIndex], # type: ignore[type-var]
2255-
) -> int: ...
2244+
def __index__(self: NDArray[np.integer[Any]], /) -> int: ...
2245+
def __int__(self: NDArray[number[Any] | np.bool | object_], /) -> int: ...
2246+
def __float__(self: NDArray[number[Any] | np.bool | object_], /) -> float: ...
2247+
def __complex__(self: NDArray[number[Any] | np.bool | object_], /) -> complex: ...
22562248

22572249
def __len__(self) -> int: ...
22582250
def __setitem__(self, key, value): ...
@@ -2310,41 +2302,25 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
23102302
def __ge__(self: NDArray[Any], other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
23112303

23122304
# Unary ops
2313-
@overload
2314-
def __abs__(self: NDArray[_UnknownType]) -> NDArray[Any]: ...
2315-
@overload
2316-
def __abs__(self: NDArray[np.bool]) -> NDArray[np.bool]: ...
2317-
@overload
2318-
def __abs__(self: NDArray[complexfloating[_NBit1, _NBit1]]) -> NDArray[floating[_NBit1]]: ...
2319-
@overload
2320-
def __abs__(self: NDArray[_NumberType]) -> NDArray[_NumberType]: ...
2321-
@overload
2322-
def __abs__(self: NDArray[timedelta64]) -> NDArray[timedelta64]: ...
2323-
@overload
2324-
def __abs__(self: NDArray[object_]) -> Any: ...
23252305

2326-
@overload
2327-
def __invert__(self: NDArray[_UnknownType]) -> NDArray[Any]: ...
2328-
@overload
2329-
def __invert__(self: NDArray[np.bool]) -> NDArray[np.bool]: ...
2330-
@overload
2331-
def __invert__(self: NDArray[_IntType]) -> NDArray[_IntType]: ...
2332-
@overload
2333-
def __invert__(self: NDArray[object_]) -> Any: ...
2334-
2335-
@overload
2336-
def __pos__(self: NDArray[_NumberType]) -> NDArray[_NumberType]: ...
2337-
@overload
2338-
def __pos__(self: NDArray[timedelta64]) -> NDArray[timedelta64]: ...
2339-
@overload
2340-
def __pos__(self: NDArray[object_]) -> Any: ...
2341-
2342-
@overload
2343-
def __neg__(self: NDArray[_NumberType]) -> NDArray[_NumberType]: ...
2344-
@overload
2345-
def __neg__(self: NDArray[timedelta64]) -> NDArray[timedelta64]: ...
2346-
@overload
2347-
def __neg__(self: NDArray[object_]) -> Any: ...
2306+
# TODO: Uncomment once https://github.com/python/mypy/issues/14070 is fixed
2307+
# @overload
2308+
# def __abs__(self: ndarray[_ShapeType, dtypes.Complex64DType], /) -> ndarray[_ShapeType, dtypes.Float32DType]: ...
2309+
# @overload
2310+
# def __abs__(self: ndarray[_ShapeType, dtypes.Complex128DType], /) -> ndarray[_ShapeType, dtypes.Float64DType]: ...
2311+
# @overload
2312+
# def __abs__(self: ndarray[_ShapeType, dtypes.CLongDoubleDType], /) -> ndarray[_ShapeType, dtypes.LongDoubleDType]: ...
2313+
# @overload
2314+
# def __abs__(self: ndarray[_ShapeType, dtype[complex128]], /) -> ndarray[_ShapeType, dtype[float64]]: ...
2315+
@overload
2316+
def __abs__(
2317+
self: ndarray[_ShapeType, dtype[complexfloating[_NBit_fc]]], /
2318+
) -> ndarray[_ShapeType, dtype[floating[_NBit_fc]]]: ...
2319+
@overload
2320+
def __abs__(self: _RealArrayT, /) -> _RealArrayT: ...
2321+
def __invert__(self: _IntegralArrayT, /) -> _IntegralArrayT: ... # noqa: PYI019
2322+
def __neg__(self: _NumericArrayT, /) -> _NumericArrayT: ... # noqa: PYI019
2323+
def __pos__(self: _NumericArrayT, /) -> _NumericArrayT: ... # noqa: PYI019
23482324

23492325
# Binary ops
23502326
@overload
@@ -3094,6 +3070,7 @@ _ScalarType = TypeVar("_ScalarType", bound=generic)
30943070
_NBit = TypeVar("_NBit", bound=NBitBase)
30953071
_NBit1 = TypeVar("_NBit1", bound=NBitBase)
30963072
_NBit2 = TypeVar("_NBit2", bound=NBitBase, default=_NBit1)
3073+
_NBit_fc = TypeVar("_NBit_fc", _NBitHalf, _NBitSingle, _NBitDouble, _NBitLongDouble)
30973074

30983075
class generic(_ArrayOrScalarCommon):
30993076
@abstractmethod
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import Protocol, TypeAlias, TypeVar
2+
from typing_extensions import assert_type
3+
import numpy as np
4+
5+
from numpy._typing import _64Bit
6+
7+
8+
_T = TypeVar("_T")
9+
_T_co = TypeVar("_T_co", covariant=True)
10+
11+
class CanAbs(Protocol[_T_co]):
12+
def __abs__(self, /) -> _T_co: ...
13+
14+
class CanInvert(Protocol[_T_co]):
15+
def __invert__(self, /) -> _T_co: ...
16+
17+
class CanNeg(Protocol[_T_co]):
18+
def __neg__(self, /) -> _T_co: ...
19+
20+
class CanPos(Protocol[_T_co]):
21+
def __pos__(self, /) -> _T_co: ...
22+
23+
def do_abs(x: CanAbs[_T]) -> _T: ...
24+
def do_invert(x: CanInvert[_T]) -> _T: ...
25+
def do_neg(x: CanNeg[_T]) -> _T: ...
26+
def do_pos(x: CanPos[_T]) -> _T: ...
27+
28+
_Bool_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.bool]]
29+
_UInt8_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.uint8]]
30+
_Int16_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.int16]]
31+
_LongLong_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.longlong]]
32+
_Float32_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float32]]
33+
_Float64_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float64]]
34+
_LongDouble_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.longdouble]]
35+
_Complex64_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex64]]
36+
_Complex128_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex128]]
37+
_CLongDouble_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.clongdouble]]
38+
39+
b1_1d: _Bool_1d
40+
u1_1d: _UInt8_1d
41+
i2_1d: _Int16_1d
42+
q_1d: _LongLong_1d
43+
f4_1d: _Float32_1d
44+
f8_1d: _Float64_1d
45+
g_1d: _LongDouble_1d
46+
c8_1d: _Complex64_1d
47+
c16_1d: _Complex128_1d
48+
G_1d: _CLongDouble_1d
49+
50+
assert_type(do_abs(b1_1d), _Bool_1d)
51+
assert_type(do_abs(u1_1d), _UInt8_1d)
52+
assert_type(do_abs(i2_1d), _Int16_1d)
53+
assert_type(do_abs(q_1d), _LongLong_1d)
54+
assert_type(do_abs(f4_1d), _Float32_1d)
55+
assert_type(do_abs(f8_1d), _Float64_1d)
56+
assert_type(do_abs(g_1d), _LongDouble_1d)
57+
58+
assert_type(do_abs(c8_1d), _Float32_1d)
59+
# NOTE: Unfortunately it's not possible to have this return a `float64` sctype, see
60+
# https://github.com/python/mypy/issues/14070
61+
assert_type(do_abs(c16_1d), np.ndarray[tuple[int], np.dtype[np.floating[_64Bit]]])
62+
assert_type(do_abs(G_1d), _LongDouble_1d)
63+
64+
assert_type(do_invert(b1_1d), _Bool_1d)
65+
assert_type(do_invert(u1_1d), _UInt8_1d)
66+
assert_type(do_invert(i2_1d), _Int16_1d)
67+
assert_type(do_invert(q_1d), _LongLong_1d)
68+
69+
assert_type(do_neg(u1_1d), _UInt8_1d)
70+
assert_type(do_neg(i2_1d), _Int16_1d)
71+
assert_type(do_neg(q_1d), _LongLong_1d)
72+
assert_type(do_neg(f4_1d), _Float32_1d)
73+
assert_type(do_neg(c16_1d), _Complex128_1d)
74+
75+
assert_type(do_pos(u1_1d), _UInt8_1d)
76+
assert_type(do_pos(i2_1d), _Int16_1d)
77+
assert_type(do_pos(q_1d), _LongLong_1d)
78+
assert_type(do_pos(f4_1d), _Float32_1d)
79+
assert_type(do_pos(c16_1d), _Complex128_1d)

0 commit comments

Comments
 (0)