Skip to content

Commit a8525db

Browse files
committed
TYP: Workaround a nypy bug in the ndarray builtin type conversion ops
1 parent 3a8e7c9 commit a8525db

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

numpy/__init__.pyi

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,12 +2302,22 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
23022302
def __ge__(self: NDArray[Any], other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
23032303

23042304
# Unary ops
2305+
2306+
# TODO: Uncomment once https://github.com/python/mypy/issues/14070 is fixed
2307+
# @overload
2308+
# def __abs__(self: ndarray[_ShapeType, dtypes.Complex64DType], /) -> ndarray[_ShapeType, dtypes.Float32DType]: ...
2309+
# @overload
2310+
# def __abs__(self: ndarray[_ShapeType, dtypes.Complex128DType], /) -> ndarray[_ShapeType, dtypes.Float64DType]: ...
2311+
# @overload
2312+
# def __abs__(self: ndarray[_ShapeType, dtypes.CLongDoubleDType], /) -> ndarray[_ShapeType, dtypes.LongDoubleDType]: ...
2313+
# @overload
2314+
# def __abs__(self: ndarray[_ShapeType, dtype[complex128]], /) -> ndarray[_ShapeType, dtype[float64]]: ...
23052315
@overload
2306-
def __abs__(self: _RealArrayT, /) -> _RealArrayT: ...
2307-
@overload
2308-
def __abs__(self: ndarray[_ShapeType, dtype[complex128]], /) -> ndarray[_ShapeType, dtype[float64]]: ...
2316+
def __abs__(
2317+
self: ndarray[_ShapeType, dtype[complexfloating[_NBit_fc]]], /
2318+
) -> ndarray[_ShapeType, dtype[floating[_NBit_fc]]]: ...
23092319
@overload
2310-
def __abs__(self: ndarray[_ShapeType, dtype[complexfloating[_NBit1]]], /) -> ndarray[_ShapeType, dtype[floating[_NBit1]]]: ...
2320+
def __abs__(self: _RealArrayT, /) -> _RealArrayT: ...
23112321
def __invert__(self: _IntegralArrayT, /) -> _IntegralArrayT: ... # noqa: PYI019
23122322
def __neg__(self: _NumericArrayT, /) -> _NumericArrayT: ... # noqa: PYI019
23132323
def __pos__(self: _NumericArrayT, /) -> _NumericArrayT: ... # noqa: PYI019
@@ -3060,6 +3070,7 @@ _ScalarType = TypeVar("_ScalarType", bound=generic)
30603070
_NBit = TypeVar("_NBit", bound=NBitBase)
30613071
_NBit1 = TypeVar("_NBit1", bound=NBitBase)
30623072
_NBit2 = TypeVar("_NBit2", bound=NBitBase, default=_NBit1)
3073+
_NBit_fc = TypeVar("_NBit_fc", _NBitHalf, _NBitSingle, _NBitDouble, _NBitLongDouble)
30633074

30643075
class generic(_ArrayOrScalarCommon):
30653076
@abstractmethod

0 commit comments

Comments
 (0)