Skip to content

Commit 3a8e7c9

Browse files
committed
TYP: Type-test the ndarray builtin type conversion operators
1 parent 4b7ae75 commit 3a8e7c9

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import Protocol, TypeAlias, TypeVar
2+
from typing_extensions import assert_type
3+
import numpy as np
4+
5+
from numpy._typing import _64Bit
6+
7+
8+
_T = TypeVar("_T")
9+
_T_co = TypeVar("_T_co", covariant=True)
10+
11+
class CanAbs(Protocol[_T_co]):
12+
def __abs__(self, /) -> _T_co: ...
13+
14+
class CanInvert(Protocol[_T_co]):
15+
def __invert__(self, /) -> _T_co: ...
16+
17+
class CanNeg(Protocol[_T_co]):
18+
def __neg__(self, /) -> _T_co: ...
19+
20+
class CanPos(Protocol[_T_co]):
21+
def __pos__(self, /) -> _T_co: ...
22+
23+
def do_abs(x: CanAbs[_T]) -> _T: ...
24+
def do_invert(x: CanInvert[_T]) -> _T: ...
25+
def do_neg(x: CanNeg[_T]) -> _T: ...
26+
def do_pos(x: CanPos[_T]) -> _T: ...
27+
28+
_Bool_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.bool]]
29+
_UInt8_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.uint8]]
30+
_Int16_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.int16]]
31+
_LongLong_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.longlong]]
32+
_Float32_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float32]]
33+
_Float64_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float64]]
34+
_LongDouble_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.longdouble]]
35+
_Complex64_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex64]]
36+
_Complex128_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex128]]
37+
_CLongDouble_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.clongdouble]]
38+
39+
b1_1d: _Bool_1d
40+
u1_1d: _UInt8_1d
41+
i2_1d: _Int16_1d
42+
q_1d: _LongLong_1d
43+
f4_1d: _Float32_1d
44+
f8_1d: _Float64_1d
45+
g_1d: _LongDouble_1d
46+
c8_1d: _Complex64_1d
47+
c16_1d: _Complex128_1d
48+
G_1d: _CLongDouble_1d
49+
50+
assert_type(do_abs(b1_1d), _Bool_1d)
51+
assert_type(do_abs(u1_1d), _UInt8_1d)
52+
assert_type(do_abs(i2_1d), _Int16_1d)
53+
assert_type(do_abs(q_1d), _LongLong_1d)
54+
assert_type(do_abs(f4_1d), _Float32_1d)
55+
assert_type(do_abs(f8_1d), _Float64_1d)
56+
assert_type(do_abs(g_1d), _LongDouble_1d)
57+
58+
assert_type(do_abs(c8_1d), _Float32_1d)
59+
# NOTE: Unfortunately it's not possible to have this return a `float64` sctype, see
60+
# https://github.com/python/mypy/issues/14070
61+
assert_type(do_abs(c16_1d), np.ndarray[tuple[int], np.dtype[np.floating[_64Bit]]])
62+
assert_type(do_abs(G_1d), _LongDouble_1d)
63+
64+
assert_type(do_invert(b1_1d), _Bool_1d)
65+
assert_type(do_invert(u1_1d), _UInt8_1d)
66+
assert_type(do_invert(i2_1d), _Int16_1d)
67+
assert_type(do_invert(q_1d), _LongLong_1d)
68+
69+
assert_type(do_neg(u1_1d), _UInt8_1d)
70+
assert_type(do_neg(i2_1d), _Int16_1d)
71+
assert_type(do_neg(q_1d), _LongLong_1d)
72+
assert_type(do_neg(f4_1d), _Float32_1d)
73+
assert_type(do_neg(c16_1d), _Complex128_1d)
74+
75+
assert_type(do_pos(u1_1d), _UInt8_1d)
76+
assert_type(do_pos(i2_1d), _Int16_1d)
77+
assert_type(do_pos(q_1d), _LongLong_1d)
78+
assert_type(do_pos(f4_1d), _Float32_1d)
79+
assert_type(do_pos(c16_1d), _Complex128_1d)

0 commit comments

Comments
 (0)