Skip to content

Commit 087a4d8

Browse files
committed
TYP: Map type-unions to abstract scalar types in numpy.dtype.__new__.
1 parent 1e2446c commit 087a4d8

File tree

2 files changed

+148
-62
lines changed

2 files changed

+148
-62
lines changed

numpy/__init__.pyi

Lines changed: 108 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,17 @@ from numpy._typing import (
124124
_VoidCodes,
125125
_ObjectCodes,
126126

127+
_UnsignedIntegerCodes,
128+
_SignedIntegerCodes,
129+
_IntegerCodes,
130+
_FloatingCodes,
131+
_ComplexFloatingCodes,
132+
_InexactCodes,
133+
_NumberCodes,
134+
_CharacterCodes,
135+
_FlexibleCodes,
136+
_GenericCodes,
137+
127138
# Ufuncs
128139
_UFunc_Nin1_Nout1,
129140
_UFunc_Nin2_Nout1,
@@ -751,6 +762,22 @@ _DTypeBuiltinKind: TypeAlias = L[
751762
2, # user-defined
752763
]
753764

765+
# NOTE: `type[S] | type[T]` is equivalent to `type[S | T]`
766+
_UnsignedIntegerCType: TypeAlias = type[
767+
ct.c_uint8 | ct.c_uint16 | ct.c_uint32 | ct.c_uint64
768+
| ct.c_ubyte | ct.c_ushort | ct.c_uint | ct.c_ulong | ct.c_ulonglong
769+
| ct.c_size_t | ct.c_void_p
770+
]
771+
_SignedIntegerCType: TypeAlias = type[
772+
ct.c_int8 | ct.c_int16 | ct.c_int32 | ct.c_int64
773+
| ct.c_byte | ct.c_short | ct.c_int | ct.c_long | ct.c_longlong
774+
| ct.c_ssize_t
775+
]
776+
_FloatingCType: TypeAlias = type[ct.c_float | ct.c_double | ct.c_longdouble]
777+
_IntegerCType: TypeAlias = _UnsignedIntegerCType | _SignedIntegerCType
778+
# All `number`-like ctypes: the integer and the floating ones (the
# `complexfloating` overloads only accept char-codes, not ctypes).
_NumberCType: TypeAlias = _IntegerCType | _FloatingCType
779+
_GenericCType: TypeAlias = _NumberCType | type[ct.c_bool | ct.c_char | ct.py_object[Any]]
780+
754781
# some commonly used builtin types that are known to result in a
755782
# `dtype[object_]`, when their *type* is passed to the `dtype` constructor
756783
# NOTE: `builtins.object` should not be included here
@@ -805,7 +832,6 @@ class dtype(Generic[_DTypeScalar_co]):
805832
# NOTE: This also accepts `dtype: type[complex | float | int | bool]`
806833
@overload
807834
def __new__(cls, dtype: type[complex], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[complex128 | float64 | int_ | np.bool]: ...
808-
809835
# TODO: This weird `memoryview` order is needed to work around a bug in
810836
# typeshed, which causes typecheckers to treat `memoryview` as a subtype
811837
# of `bytes`, even though there's no mention of that in the typing docs.
@@ -822,6 +848,9 @@ class dtype(Generic[_DTypeScalar_co]):
822848
def __new__(cls, dtype: type[builtins.str | bytes], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[character]: ...
823849
@overload
824850
def __new__(cls, dtype: type[memoryview | builtins.str | bytes], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[flexible]: ...
851+
# NOTE: `dtype: type[object]` also accepts e.g. `type[object | complex | ...]`
852+
@overload
853+
def __new__(cls, dtype: type[_BuiltinObjectLike], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[object_]: ...
825854

826855
# `unsignedinteger` string-based representations and ctypes
827856
@overload
@@ -846,8 +875,6 @@ class dtype(Generic[_DTypeScalar_co]):
846875
def __new__(cls, dtype: _ULongCodes | type[ct.c_ulong], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[ulong]: ...
847876
@overload
848877
def __new__(cls, dtype: _ULongLongCodes | type[ct.c_ulonglong], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[ulonglong]: ...
849-
# TODO: (dtype: Union[{all unsigned integer codes}]) -> dtype[unsignedinteger]
850-
# TODO: (dtype: type[Union[{all unsigned integer ctypes}]]) -> dtype[unsignedinteger]
851878

852879
# `signedinteger` string-based representations and ctypes
853880
@overload
@@ -870,10 +897,6 @@ class dtype(Generic[_DTypeScalar_co]):
870897
def __new__(cls, dtype: _LongCodes | type[ct.c_long], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[long]: ...
871898
@overload
872899
def __new__(cls, dtype: _LongLongCodes | type[ct.c_longlong], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[longlong]: ...
873-
# TODO: (dtype: Union[{all signed integer codes}]) -> dtype[signedinteger]
874-
# TODO: (dtype: type[Union[{all signed integer ctypes}]]) -> dtype[signedinteger]
875-
# TODO: (dtype: Union[{all integer codes}]) -> dtype[integer]
876-
# TODO: (dtype: type[Union[{all integer ctypes}]]) -> dtype[integer]
877900

878901
# `floating` string-based representations and ctypes
879902
@overload
@@ -890,8 +913,6 @@ class dtype(Generic[_DTypeScalar_co]):
890913
def __new__(cls, dtype: _DoubleCodes | type[ct.c_double], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[double]: ...
891914
@overload
892915
def __new__(cls, dtype: _LongDoubleCodes | type[ct.c_longdouble], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[longdouble]: ...
893-
# TODO: (dtype: Union[{all floating codes}]) -> dtype[floating]
894-
# TODO: (dtype: type[ct.c_float | ct.c_double | ct.c_longdouble]) -> dtype[floating]
895916

896917
# `complexfloating` string-based representations
897918
@overload
@@ -904,10 +925,6 @@ class dtype(Generic[_DTypeScalar_co]):
904925
def __new__(cls, dtype: _CDoubleCodes, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[cdouble]: ...
905926
@overload
906927
def __new__(cls, dtype: _CLongDoubleCodes, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[clongdouble]: ...
907-
# TODO: (dtype: Union[{all complex codes}]) -> dtype[complexfloating]
908-
# TODO: (dtype: Union[{all inexact codes}]) -> dtype[inexact]
909-
# TODO: (dtype: Union[{all number codes}]) -> dtype[number]
910-
# TODO: (dtype: type[Union[{all number ctypes}]]) -> dtype[number]
911928

912929
# Miscellaneous string-based representations and ctypes
913930
@overload
@@ -916,55 +933,111 @@ class dtype(Generic[_DTypeScalar_co]):
916933
def __new__(cls, dtype: _TD64Codes, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[timedelta64]: ...
917934
@overload
918935
def __new__(cls, dtype: _DT64Codes, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[datetime64]: ...
919-
920936
@overload
921937
def __new__(cls, dtype: _StrCodes, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[str_]: ...
922938
@overload
923939
def __new__(cls, dtype: _BytesCodes | type[ct.c_char], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[bytes_]: ...
924940
@overload
925-
def __new__(cls, dtype: _BytesCodes, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[bytes_]: ...
926-
927-
# TODO: (dtype: _StrCodes | _BytesCodes) -> dtype[character]
928-
@overload
929-
def __new__(cls, dtype: _VoidCodes, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[void]: ...
930-
# TODO: (dtype: _StrCodes | _BytesCodes | _VoidCodes) -> dtype[flexible]
941+
def __new__(cls, dtype: _VoidCodes | _VoidDTypeLike, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[void]: ...
931942
@overload
932943
def __new__(cls, dtype: _ObjectCodes | type[ct.py_object[Any]], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[object_]: ...
933-
# TODO: (dtype: Union[{all literal codes}]) -> dtype[generic]
934944

935-
# Structured (`void`-like) dtypes
945+
# Combined char-codes and ctypes, analogous to the scalar-type hierarchy
936946
@overload
937947
def __new__(
938948
cls,
939-
dtype: _VoidDTypeLike,
949+
dtype: _UnsignedIntegerCodes | _UnsignedIntegerCType,
940950
align: builtins.bool = ...,
941951
copy: builtins.bool = ...,
942952
metadata: dict[builtins.str, Any] = ...,
943-
) -> dtype[void]: ...
944-
945-
# Handle strings that can't be expressed as literals; i.e. S1, S2, ...
946-
# NOTE: This isn't limited to flexible types, because `dtype: str` also
947-
# accepts e.g. `str | Literal['f8']`
953+
) -> dtype[unsignedinteger[Any]]: ...
948954
@overload
949955
def __new__(
950956
cls,
951-
dtype: builtins.str,
957+
dtype: _SignedIntegerCodes | _SignedIntegerCType,
952958
align: builtins.bool = ...,
953959
copy: builtins.bool = ...,
954960
metadata: dict[builtins.str, Any] = ...,
955-
) -> dtype[Any]: ...
961+
) -> dtype[signedinteger[Any]]: ...
962+
@overload
963+
def __new__(
964+
cls,
965+
dtype: _IntegerCodes | _IntegerCType,
966+
align: builtins.bool = ...,
967+
copy: builtins.bool = ...,
968+
metadata: dict[builtins.str, Any] = ...,
969+
) -> dtype[integer[Any]]: ...
970+
@overload
971+
def __new__(
972+
cls,
973+
dtype: _FloatingCodes | _FloatingCType,
974+
align: builtins.bool = ...,
975+
copy: builtins.bool = ...,
976+
metadata: dict[builtins.str, Any] = ...,
977+
) -> dtype[floating[Any]]: ...
978+
@overload
979+
def __new__(
980+
cls,
981+
dtype: _ComplexFloatingCodes,
982+
align: builtins.bool = ...,
983+
copy: builtins.bool = ...,
984+
metadata: dict[builtins.str, Any] = ...,
985+
) -> dtype[complexfloating[Any, Any]]: ...
986+
@overload
987+
def __new__(
988+
cls,
989+
dtype: _InexactCodes | _FloatingCType,
990+
align: builtins.bool = ...,
991+
copy: builtins.bool = ...,
992+
metadata: dict[builtins.str, Any] = ...,
993+
) -> dtype[inexact[Any]]: ...
994+
@overload
995+
def __new__(
996+
cls,
997+
dtype: _NumberCodes | _NumberCType,
998+
align: builtins.bool = ...,
999+
copy: builtins.bool = ...,
1000+
metadata: dict[builtins.str, Any] = ...,
1001+
) -> dtype[number[Any]]: ...
1002+
@overload
1003+
def __new__(
1004+
cls,
1005+
dtype: _CharacterCodes | type[ct.c_char],
1006+
align: builtins.bool = ...,
1007+
copy: builtins.bool = ...,
1008+
metadata: dict[builtins.str, Any] = ...,
1009+
) -> dtype[character]: ...
1010+
@overload
1011+
def __new__(
1012+
cls,
1013+
dtype: _FlexibleCodes | type[ct.c_char],
1014+
align: builtins.bool = ...,
1015+
copy: builtins.bool = ...,
1016+
metadata: dict[builtins.str, Any] = ...,
1017+
) -> dtype[flexible]: ...
1018+
@overload
1019+
def __new__(
1020+
cls,
1021+
dtype: _GenericCodes | _GenericCType,
1022+
align: builtins.bool = ...,
1023+
copy: builtins.bool = ...,
1024+
metadata: dict[builtins.str, Any] = ...,
1025+
) -> dtype[generic]: ...
9561026

957-
# Catch-all overload for object-likes
958-
# NOTE: `dtype: type[object]` also accepts e.g. `type[object | complex | ...]`
1027+
# Handle strings that can't be expressed as literals; i.e. "S1", "S2", ...
9591028
@overload
9601029
def __new__(
9611030
cls,
962-
dtype: type[_BuiltinObjectLike],
1031+
dtype: builtins.str,
9631032
align: builtins.bool = ...,
9641033
copy: builtins.bool = ...,
9651034
metadata: dict[builtins.str, Any] = ...,
966-
) -> dtype[object_]: ...
967-
# NOTE: `object_ | Any` is *not* equivalent to `Any`, see:
1035+
) -> dtype[Any]: ...
1036+
1037+
# Catch-all overload for object-likes
1038+
# NOTE: `object_ | Any` is *not* equivalent to `Any` -- it describes some
1039+
# (static) type `T` s.t. `object_ <: T <: builtins.object` (`<:` denotes
1040+
# the subtyping relation, the (gradual) typing analogue of `issubclass()`).
9681041
# https://typing.readthedocs.io/en/latest/spec/concepts.html#union-types
9691042
@overload
9701043
def __new__(

numpy/typing/tests/data/reveal/dtype.pyi

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import ctypes as ct
33
import datetime as dt
44
from decimal import Decimal
55
from fractions import Fraction
6-
from typing import Any, TypeAlias
6+
from typing import Any, Literal, TypeAlias
77

88
import numpy as np
99

@@ -12,27 +12,31 @@ if sys.version_info >= (3, 11):
1212
else:
1313
from typing_extensions import assert_type
1414

15+
# a combination of likely `object` dtype-like candidates (no `_co`)
16+
_PyObjectLike: TypeAlias = Decimal | Fraction | dt.datetime | dt.timedelta
17+
1518
dtype_U: np.dtype[np.str_]
1619
dtype_V: np.dtype[np.void]
1720
dtype_i8: np.dtype[np.int64]
1821

19-
# equivalent to type[int]
20-
py_int_co: type[int] | type[bool]
21-
# equivalent to type[float] (type-check only)
22-
py_float_co: type[float] | type[int] | type[bool]
23-
# equivalent to type[complex] (type-check only)
24-
py_complex_co: type[complex] | type[float] | type[int] | type[bool]
25-
# equivalent to type[object]
26-
py_object_co: (
27-
type[object]
28-
| type[complex] | type[float] | type[int] | type[bool]
29-
| type[str] | type[bytes]
30-
# ...
31-
)
32-
py_character_co: type[str] | type[bytes]
33-
# TODO: also include type[bytes] here once mypy has been upgraded to 1.11,
34-
# which should resolve the `memoryview` typeshed issue.
35-
py_flexible_co: type[memoryview] | type[str]
22+
py_int_co: type[int | bool]
23+
py_float_co: type[float | int | bool]
24+
py_complex_co: type[complex | float | int | bool]
25+
py_object: type[_PyObjectLike]
26+
py_character: type[str | bytes]
27+
# TODO: also include `bytes` here once mypy has been upgraded to >=1.11
28+
py_flexible: type[memoryview] | type[str] # | type[bytes]
29+
30+
ct_floating: type[ct.c_float | ct.c_double | ct.c_longdouble]
31+
ct_number: type[ct.c_uint8 | ct.c_float]
32+
ct_generic: type[ct.c_bool | ct.c_char]
33+
34+
cs_integer: Literal['u1', '<i2', 'L']
35+
cs_number: Literal['=L' ,'i', 'c16']
36+
cs_flex: Literal['>V', 'S']
37+
cs_generic: Literal['H', 'U', 'h', '|M8[Y]', '?']
38+
39+
dt_inexact: np.dtype[np.inexact[Any]]
3640

3741

3842
assert_type(np.dtype(np.float64), np.dtype[np.float64])
@@ -50,27 +54,35 @@ assert_type(np.dtype("str"), np.dtype[np.str_])
5054

5155
# Python types
5256
assert_type(np.dtype(bool), np.dtype[np.bool])
53-
assert_type(np.dtype(int), np.dtype[np.int_ | np.bool])
5457
assert_type(np.dtype(py_int_co), np.dtype[np.int_ | np.bool])
55-
assert_type(np.dtype(float), np.dtype[np.float64 | np.int_ | np.bool])
58+
assert_type(np.dtype(int), np.dtype[np.int_ | np.bool])
5659
assert_type(np.dtype(py_float_co), np.dtype[np.float64 | np.int_ | np.bool])
57-
assert_type(np.dtype(complex), np.dtype[np.complex128 | np.float64 | np.int_ | np.bool])
60+
assert_type(np.dtype(float), np.dtype[np.float64 | np.int_ | np.bool])
5861
assert_type(np.dtype(py_complex_co), np.dtype[np.complex128 | np.float64 | np.int_ | np.bool])
59-
assert_type(np.dtype(object), np.dtype[np.object_ | Any])
60-
assert_type(np.dtype(py_object_co), np.dtype[np.object_ | Any])
61-
62+
assert_type(np.dtype(complex), np.dtype[np.complex128 | np.float64 | np.int_ | np.bool])
63+
assert_type(np.dtype(py_object), np.dtype[np.object_])
6264
assert_type(np.dtype(str), np.dtype[np.str_])
6365
assert_type(np.dtype(bytes), np.dtype[np.bytes_])
64-
assert_type(np.dtype(py_character_co), np.dtype[np.character])
66+
assert_type(np.dtype(py_character), np.dtype[np.character])
6567
assert_type(np.dtype(memoryview), np.dtype[np.void])
66-
assert_type(np.dtype(py_flexible_co), np.dtype[np.flexible])
68+
assert_type(np.dtype(py_flexible), np.dtype[np.flexible])
6769

6870
assert_type(np.dtype(list), np.dtype[np.object_])
6971
assert_type(np.dtype(dt.datetime), np.dtype[np.object_])
7072
assert_type(np.dtype(dt.timedelta), np.dtype[np.object_])
7173
assert_type(np.dtype(Decimal), np.dtype[np.object_])
7274
assert_type(np.dtype(Fraction), np.dtype[np.object_])
7375

76+
# char-codes
77+
assert_type(np.dtype('u1'), np.dtype[np.uint8])
78+
assert_type(np.dtype('l'), np.dtype[np.long])
79+
assert_type(np.dtype('longlong'), np.dtype[np.longlong])
80+
assert_type(np.dtype('>g'), np.dtype[np.longdouble])
81+
assert_type(np.dtype(cs_integer), np.dtype[np.integer[Any]])
82+
assert_type(np.dtype(cs_number), np.dtype[np.number[Any]])
83+
assert_type(np.dtype(cs_flex), np.dtype[np.flexible])
84+
assert_type(np.dtype(cs_generic), np.dtype[np.generic])
85+
7486
# ctypes
7587
assert_type(np.dtype(ct.c_double), np.dtype[np.double])
7688
assert_type(np.dtype(ct.c_longlong), np.dtype[np.longlong])
@@ -82,8 +94,9 @@ assert_type(np.dtype(ct.py_object), np.dtype[np.object_])
8294
# Special case for None
8395
assert_type(np.dtype(None), np.dtype[np.float64])
8496

85-
# Dtypes of dtypes
97+
# Dtypes of dtypes
8698
assert_type(np.dtype(np.dtype(np.float64)), np.dtype[np.float64])
99+
assert_type(np.dtype(dt_inexact), np.dtype[np.inexact[Any]])
87100

88101
# Parameterized dtypes
89102
assert_type(np.dtype("S8"), np.dtype[Any])

0 commit comments

Comments
 (0)