|
| 1 | +from typing import Protocol, TypeAlias, TypeVar |
| 2 | +from typing_extensions import assert_type |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from numpy._typing import _64Bit |
| 6 | + |
| 7 | + |
| 8 | +_T = TypeVar("_T") |
| 9 | +_T_co = TypeVar("_T_co", covariant=True) |
| 10 | + |
| 11 | +class CanAbs(Protocol[_T_co]): |
| 12 | + def __abs__(self, /) -> _T_co: ... |
| 13 | + |
| 14 | +class CanInvert(Protocol[_T_co]): |
| 15 | + def __invert__(self, /) -> _T_co: ... |
| 16 | + |
| 17 | +class CanNeg(Protocol[_T_co]): |
| 18 | + def __neg__(self, /) -> _T_co: ... |
| 19 | + |
| 20 | +class CanPos(Protocol[_T_co]): |
| 21 | + def __pos__(self, /) -> _T_co: ... |
| 22 | + |
| 23 | +def do_abs(x: CanAbs[_T]) -> _T: ... |
| 24 | +def do_invert(x: CanInvert[_T]) -> _T: ... |
| 25 | +def do_neg(x: CanNeg[_T]) -> _T: ... |
| 26 | +def do_pos(x: CanPos[_T]) -> _T: ... |
| 27 | + |
| 28 | +_Bool_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.bool]] |
| 29 | +_UInt8_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.uint8]] |
| 30 | +_Int16_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.int16]] |
| 31 | +_LongLong_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.longlong]] |
| 32 | +_Float32_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float32]] |
| 33 | +_Float64_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float64]] |
| 34 | +_LongDouble_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.longdouble]] |
| 35 | +_Complex64_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex64]] |
| 36 | +_Complex128_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex128]] |
| 37 | +_CLongDouble_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.clongdouble]] |
| 38 | + |
| 39 | +b1_1d: _Bool_1d |
| 40 | +u1_1d: _UInt8_1d |
| 41 | +i2_1d: _Int16_1d |
| 42 | +q_1d: _LongLong_1d |
| 43 | +f4_1d: _Float32_1d |
| 44 | +f8_1d: _Float64_1d |
| 45 | +g_1d: _LongDouble_1d |
| 46 | +c8_1d: _Complex64_1d |
| 47 | +c16_1d: _Complex128_1d |
| 48 | +G_1d: _CLongDouble_1d |
| 49 | + |
| 50 | +assert_type(do_abs(b1_1d), _Bool_1d) |
| 51 | +assert_type(do_abs(u1_1d), _UInt8_1d) |
| 52 | +assert_type(do_abs(i2_1d), _Int16_1d) |
| 53 | +assert_type(do_abs(q_1d), _LongLong_1d) |
| 54 | +assert_type(do_abs(f4_1d), _Float32_1d) |
| 55 | +assert_type(do_abs(f8_1d), _Float64_1d) |
| 56 | +assert_type(do_abs(g_1d), _LongDouble_1d) |
| 57 | + |
| 58 | +assert_type(do_abs(c8_1d), _Float32_1d) |
| 59 | +# NOTE: Unfortunately it's not possible to have this return a `float64` sctype, see |
| 60 | +# https://github.com/python/mypy/issues/14070 |
| 61 | +assert_type(do_abs(c16_1d), np.ndarray[tuple[int], np.dtype[np.floating[_64Bit]]]) |
| 62 | +assert_type(do_abs(G_1d), _LongDouble_1d) |
| 63 | + |
| 64 | +assert_type(do_invert(b1_1d), _Bool_1d) |
| 65 | +assert_type(do_invert(u1_1d), _UInt8_1d) |
| 66 | +assert_type(do_invert(i2_1d), _Int16_1d) |
| 67 | +assert_type(do_invert(q_1d), _LongLong_1d) |
| 68 | + |
| 69 | +assert_type(do_neg(u1_1d), _UInt8_1d) |
| 70 | +assert_type(do_neg(i2_1d), _Int16_1d) |
| 71 | +assert_type(do_neg(q_1d), _LongLong_1d) |
| 72 | +assert_type(do_neg(f4_1d), _Float32_1d) |
| 73 | +assert_type(do_neg(c16_1d), _Complex128_1d) |
| 74 | + |
| 75 | +assert_type(do_pos(u1_1d), _UInt8_1d) |
| 76 | +assert_type(do_pos(i2_1d), _Int16_1d) |
| 77 | +assert_type(do_pos(q_1d), _LongLong_1d) |
| 78 | +assert_type(do_pos(f4_1d), _Float32_1d) |
| 79 | +assert_type(do_pos(c16_1d), _Complex128_1d) |
0 commit comments