Skip to content

Commit af75b6f

Browse files
authored
Merge pull request numpy#27139 from jorenham/typing/ufunc-constructor-specialization
TYP: Fixed & improved ``numpy.dtype.__new__``
2 parents 1b4bd2a + 9f03471 commit af75b6f

File tree

5 files changed

+393
-38
lines changed

5 files changed

+393
-38
lines changed

numpy/__init__.pyi

Lines changed: 246 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import enum
99
from abc import abstractmethod
1010
from types import EllipsisType, TracebackType, MappingProxyType, GenericAlias
1111
from contextlib import contextmanager
12+
from decimal import Decimal
13+
from fractions import Fraction
14+
from uuid import UUID
1215

1316
import numpy as np
1417
from numpy._pytesttester import PytestTester
@@ -120,6 +123,18 @@ from numpy._typing import (
120123
_BytesCodes,
121124
_VoidCodes,
122125
_ObjectCodes,
126+
_StringCodes,
127+
128+
_UnsignedIntegerCodes,
129+
_SignedIntegerCodes,
130+
_IntegerCodes,
131+
_FloatingCodes,
132+
_ComplexFloatingCodes,
133+
_InexactCodes,
134+
_NumberCodes,
135+
_CharacterCodes,
136+
_FlexibleCodes,
137+
_GenericCodes,
123138

124139
# Ufuncs
125140
_UFunc_Nin1_Nout1,
@@ -747,40 +762,170 @@ _DTypeBuiltinKind: TypeAlias = L[
747762
2, # user-defined
748763
]
749764

765+
# NOTE: `type[S] | type[T]` is equivalent to `type[S | T]`
766+
_UnsignedIntegerCType: TypeAlias = type[
767+
ct.c_uint8 | ct.c_uint16 | ct.c_uint32 | ct.c_uint64
768+
| ct.c_ubyte | ct.c_ushort | ct.c_uint | ct.c_ulong | ct.c_ulonglong
769+
| ct.c_size_t | ct.c_void_p
770+
]
771+
_SignedIntegerCType: TypeAlias = type[
772+
ct.c_int8 | ct.c_int16 | ct.c_int32 | ct.c_int64
773+
| ct.c_byte | ct.c_short | ct.c_int | ct.c_long | ct.c_longlong
774+
| ct.c_ssize_t
775+
]
776+
_FloatingCType: TypeAlias = type[ct.c_float | ct.c_double | ct.c_longdouble]
777+
_IntegerCType: TypeAlias = _UnsignedIntegerCType | _SignedIntegerCType
778+
_NumberCType: TypeAlias = _IntegerCType | _IntegerCType
779+
_GenericCType: TypeAlias = _NumberCType | type[ct.c_bool | ct.c_char | ct.py_object[Any]]
780+
781+
# some commonly used builtin types that are known to result in a
782+
# `dtype[object_]`, when their *type* is passed to the `dtype` constructor
783+
# NOTE: `builtins.object` should not be included here
784+
_BuiltinObjectLike: TypeAlias = (
785+
slice | Decimal | Fraction | UUID
786+
| dt.date | dt.time | dt.timedelta | dt.tzinfo
787+
| tuple[Any, ...] | list[Any] | set[Any] | frozenset[Any] | dict[Any, Any]
788+
) # fmt: skip
789+
750790
@final
751791
class dtype(Generic[_DTypeScalar_co]):
752792
names: None | tuple[builtins.str, ...]
753793
def __hash__(self) -> int: ...
754-
# Overload for subclass of generic
794+
795+
# `None` results in the default dtype
796+
@overload
797+
def __new__(
798+
cls,
799+
dtype: None | type[float64],
800+
align: builtins.bool = ...,
801+
copy: builtins.bool = ...,
802+
metadata: dict[builtins.str, Any] = ...
803+
) -> dtype[float64]: ...
804+
805+
# Overload for `dtype` instances, scalar types, and instances that have a
806+
# `dtype: dtype[_SCT]` attribute
755807
@overload
756808
def __new__(
757809
cls,
758-
dtype: type[_DTypeScalar_co],
810+
dtype: _DTypeLike[_SCT],
759811
align: builtins.bool = ...,
760812
copy: builtins.bool = ...,
761813
metadata: dict[builtins.str, Any] = ...,
762-
) -> dtype[_DTypeScalar_co]: ...
763-
# Overloads for string aliases, Python types, and some assorted
764-
# other special cases. Order is sometimes important because of the
765-
# subtype relationships
814+
) -> dtype[_SCT]: ...
815+
816+
# Builtin types
766817
#
767-
# builtins.bool < int < float < complex < object
818+
# NOTE: Typecheckers act as if `bool <: int <: float <: complex <: object`,
819+
# even though at runtime `int`, `float`, and `complex` aren't subtypes..
820+
# This makes it impossible to express e.g. "a float that isn't an int",
821+
# since type checkers treat `_: float` like `_: float | int`.
768822
#
769-
# so we have to make sure the overloads for the narrowest type is
770-
# first.
771-
# Builtin types
823+
# For more details, see:
824+
# - https://github.com/numpy/numpy/issues/27032#issuecomment-2278958251
825+
# - https://typing.readthedocs.io/en/latest/spec/special-types.html#special-cases-for-float-and-complex
772826
@overload
773-
def __new__(cls, dtype: type[builtins.bool], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[np.bool]: ...
827+
def __new__(
828+
cls,
829+
dtype: type[builtins.bool | np.bool],
830+
align: builtins.bool = ...,
831+
copy: builtins.bool = ...,
832+
metadata: dict[str, Any] = ...,
833+
) -> dtype[np.bool]: ...
834+
# NOTE: `_: type[int]` also accepts `type[int | bool]`
774835
@overload
775-
def __new__(cls, dtype: type[int], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[int_]: ...
836+
def __new__(
837+
cls,
838+
dtype: type[int | int_ | np.bool],
839+
align: builtins.bool = ...,
840+
copy: builtins.bool = ...,
841+
metadata: dict[str, Any] = ...,
842+
) -> dtype[int_ | np.bool]: ...
843+
# NOTE: `_: type[float]` also accepts `type[float | int | bool]`
844+
# NOTE: `float64` inheritcs from `float` at runtime; but this isn't
845+
# reflected in these stubs. So an explicit `float64` is required here.
776846
@overload
777-
def __new__(cls, dtype: None | type[float], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[float64]: ...
847+
def __new__(
848+
cls,
849+
dtype: None | type[float | float64 | int_ | np.bool],
850+
align: builtins.bool = ...,
851+
copy: builtins.bool = ...,
852+
metadata: dict[str, Any] = ...,
853+
) -> dtype[float64 | int_ | np.bool]: ...
854+
# NOTE: `_: type[complex]` also accepts `type[complex | float | int | bool]`
778855
@overload
779-
def __new__(cls, dtype: type[complex], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[complex128]: ...
856+
def __new__(
857+
cls,
858+
dtype: type[complex | complex128 | float64 | int_ | np.bool],
859+
align: builtins.bool = ...,
860+
copy: builtins.bool = ...,
861+
metadata: dict[str, Any] = ...,
862+
) -> dtype[complex128 | float64 | int_ | np.bool]: ...
780863
@overload
781-
def __new__(cls, dtype: type[builtins.str], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[str_]: ...
864+
def __new__(
865+
cls,
866+
dtype: type[bytes], # also includes `type[bytes_]`
867+
align: builtins.bool = ...,
868+
copy: builtins.bool = ...,
869+
metadata: dict[str, Any] = ...,
870+
) -> dtype[bytes_]: ...
871+
@overload
872+
def __new__(
873+
cls,
874+
dtype: type[str], # also includes `type[str_]`
875+
align: builtins.bool = ...,
876+
copy: builtins.bool = ...,
877+
metadata: dict[str, Any] = ...,
878+
) -> dtype[str_]: ...
879+
# NOTE: These `memoryview` overloads assume PEP 688, which requires mypy to
880+
# be run with the (undocumented) `--disable-memoryview-promotion` flag,
881+
# This will be the default in a future mypy release, see:
882+
# https://github.com/python/mypy/issues/15313
883+
# Pyright / Pylance requires setting `disableBytesTypePromotions=true`,
884+
# which is the default in strict mode
885+
@overload
886+
def __new__(
887+
cls,
888+
dtype: type[memoryview | void],
889+
align: builtins.bool = ...,
890+
copy: builtins.bool = ...,
891+
metadata: dict[str, Any] = ...,
892+
) -> dtype[void]: ...
893+
# NOTE: `_: type[object]` would also accept e.g. `type[object | complex]`,
894+
# and is therefore not included here
895+
@overload
896+
def __new__(
897+
cls,
898+
dtype: type[_BuiltinObjectLike | object_],
899+
align: builtins.bool = ...,
900+
copy: builtins.bool = ...,
901+
metadata: dict[str, Any] = ...,
902+
) -> dtype[object_]: ...
903+
904+
# Unions of builtins.
905+
@overload
906+
def __new__(
907+
cls,
908+
dtype: type[bytes | str],
909+
align: builtins.bool = ...,
910+
copy: builtins.bool = ...,
911+
metadata: dict[str, Any] = ...,
912+
) -> dtype[character]: ...
913+
@overload
914+
def __new__(
915+
cls,
916+
dtype: type[bytes | str | memoryview],
917+
align: builtins.bool = ...,
918+
copy: builtins.bool = ...,
919+
metadata: dict[str, Any] = ...,
920+
) -> dtype[flexible]: ...
782921
@overload
783-
def __new__(cls, dtype: type[bytes], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[bytes_]: ...
922+
def __new__(
923+
cls,
924+
dtype: type[complex | bytes | str | memoryview | _BuiltinObjectLike],
925+
align: builtins.bool = ...,
926+
copy: builtins.bool = ...,
927+
metadata: dict[str, Any] = ...,
928+
) -> dtype[np.bool | int_ | float64 | complex128 | flexible | object_]: ...
784929

785930
# `unsignedinteger` string-based representations and ctypes
786931
@overload
@@ -797,7 +942,6 @@ class dtype(Generic[_DTypeScalar_co]):
797942
def __new__(cls, dtype: _UShortCodes | type[ct.c_ushort], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[ushort]: ...
798943
@overload
799944
def __new__(cls, dtype: _UIntCCodes | type[ct.c_uint], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[uintc]: ...
800-
801945
# NOTE: We're assuming here that `uint_ptr_t == size_t`,
802946
# an assumption that does not hold in rare cases (same for `ssize_t`)
803947
@overload
@@ -869,54 +1013,125 @@ class dtype(Generic[_DTypeScalar_co]):
8691013
@overload
8701014
def __new__(cls, dtype: _BytesCodes | type[ct.c_char], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[bytes_]: ...
8711015
@overload
872-
def __new__(cls, dtype: _VoidCodes, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[void]: ...
1016+
def __new__(cls, dtype: _VoidCodes | _VoidDTypeLike, align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[void]: ...
8731017
@overload
8741018
def __new__(cls, dtype: _ObjectCodes | type[ct.py_object[Any]], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[object_]: ...
8751019

876-
# dtype of a dtype is the same dtype
1020+
# `StringDType` requires special treatment because it has no scalar type
1021+
@overload
1022+
def __new__(
1023+
cls,
1024+
dtype: dtypes.StringDType | _StringCodes,
1025+
align: builtins.bool = ...,
1026+
copy: builtins.bool = ...,
1027+
metadata: dict[builtins.str, Any] = ...
1028+
) -> dtypes.StringDType: ...
1029+
1030+
# Combined char-codes and ctypes, analogous to the scalar-type hierarchy
8771031
@overload
8781032
def __new__(
8791033
cls,
880-
dtype: dtype[_DTypeScalar_co],
1034+
dtype: _UnsignedIntegerCodes | _UnsignedIntegerCType,
8811035
align: builtins.bool = ...,
8821036
copy: builtins.bool = ...,
8831037
metadata: dict[builtins.str, Any] = ...,
884-
) -> dtype[_DTypeScalar_co]: ...
1038+
) -> dtype[unsignedinteger[Any]]: ...
8851039
@overload
8861040
def __new__(
8871041
cls,
888-
dtype: _SupportsDType[dtype[_DTypeScalar_co]],
1042+
dtype: _SignedIntegerCodes | _SignedIntegerCType,
8891043
align: builtins.bool = ...,
8901044
copy: builtins.bool = ...,
8911045
metadata: dict[builtins.str, Any] = ...,
892-
) -> dtype[_DTypeScalar_co]: ...
893-
# Handle strings that can't be expressed as literals; i.e. s1, s2, ...
1046+
) -> dtype[signedinteger[Any]]: ...
8941047
@overload
8951048
def __new__(
8961049
cls,
897-
dtype: builtins.str,
1050+
dtype: _IntegerCodes | _IntegerCType,
8981051
align: builtins.bool = ...,
8991052
copy: builtins.bool = ...,
9001053
metadata: dict[builtins.str, Any] = ...,
901-
) -> dtype[Any]: ...
902-
# Catchall overload for void-likes
1054+
) -> dtype[integer[Any]]: ...
9031055
@overload
9041056
def __new__(
9051057
cls,
906-
dtype: _VoidDTypeLike,
1058+
dtype: _FloatingCodes | _FloatingCType,
9071059
align: builtins.bool = ...,
9081060
copy: builtins.bool = ...,
9091061
metadata: dict[builtins.str, Any] = ...,
910-
) -> dtype[void]: ...
911-
# Catchall overload for object-likes
1062+
) -> dtype[floating[Any]]: ...
1063+
@overload
1064+
def __new__(
1065+
cls,
1066+
dtype: _ComplexFloatingCodes,
1067+
align: builtins.bool = ...,
1068+
copy: builtins.bool = ...,
1069+
metadata: dict[builtins.str, Any] = ...,
1070+
) -> dtype[complexfloating[Any, Any]]: ...
1071+
@overload
1072+
def __new__(
1073+
cls,
1074+
dtype: _InexactCodes | _FloatingCType,
1075+
align: builtins.bool = ...,
1076+
copy: builtins.bool = ...,
1077+
metadata: dict[builtins.str, Any] = ...,
1078+
) -> dtype[inexact[Any]]: ...
1079+
@overload
1080+
def __new__(
1081+
cls,
1082+
dtype: _NumberCodes | _NumberCType,
1083+
align: builtins.bool = ...,
1084+
copy: builtins.bool = ...,
1085+
metadata: dict[builtins.str, Any] = ...,
1086+
) -> dtype[number[Any]]: ...
1087+
@overload
1088+
def __new__(
1089+
cls,
1090+
dtype: _CharacterCodes | type[ct.c_char],
1091+
align: builtins.bool = ...,
1092+
copy: builtins.bool = ...,
1093+
metadata: dict[builtins.str, Any] = ...,
1094+
) -> dtype[character]: ...
1095+
@overload
1096+
def __new__(
1097+
cls,
1098+
dtype: _FlexibleCodes | type[ct.c_char],
1099+
align: builtins.bool = ...,
1100+
copy: builtins.bool = ...,
1101+
metadata: dict[builtins.str, Any] = ...,
1102+
) -> dtype[flexible]: ...
1103+
@overload
1104+
def __new__(
1105+
cls,
1106+
dtype: _GenericCodes | _GenericCType,
1107+
align: builtins.bool = ...,
1108+
copy: builtins.bool = ...,
1109+
metadata: dict[builtins.str, Any] = ...,
1110+
) -> dtype[generic]: ...
1111+
1112+
# Handle strings that can't be expressed as literals; i.e. "S1", "S2", ...
1113+
@overload
1114+
def __new__(
1115+
cls,
1116+
dtype: builtins.str,
1117+
align: builtins.bool = ...,
1118+
copy: builtins.bool = ...,
1119+
metadata: dict[builtins.str, Any] = ...,
1120+
) -> dtype[Any]: ...
1121+
1122+
# Catch-all overload for object-likes
1123+
# NOTE: `object_ | Any` is *not* equivalent to `Any` -- it describes some
1124+
# (static) type `T` s.t. `object_ <: T <: builtins.object` (`<:` denotes
1125+
# the subtyping relation, the (gradual) typing analogue of `issubclass()`).
1126+
# https://typing.readthedocs.io/en/latest/spec/concepts.html#union-types
9121127
@overload
9131128
def __new__(
9141129
cls,
9151130
dtype: type[object],
9161131
align: builtins.bool = ...,
9171132
copy: builtins.bool = ...,
9181133
metadata: dict[builtins.str, Any] = ...,
919-
) -> dtype[object_]: ...
1134+
) -> dtype[object_ | Any]: ...
9201135

9211136
def __class_getitem__(cls, item: Any, /) -> GenericAlias: ...
9221137

numpy/_typing/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,17 @@
7171
_BytesCodes as _BytesCodes,
7272
_VoidCodes as _VoidCodes,
7373
_ObjectCodes as _ObjectCodes,
74+
_StringCodes as _StringCodes,
75+
_UnsignedIntegerCodes as _UnsignedIntegerCodes,
76+
_SignedIntegerCodes as _SignedIntegerCodes,
77+
_IntegerCodes as _IntegerCodes,
78+
_FloatingCodes as _FloatingCodes,
79+
_ComplexFloatingCodes as _ComplexFloatingCodes,
80+
_InexactCodes as _InexactCodes,
81+
_NumberCodes as _NumberCodes,
82+
_CharacterCodes as _CharacterCodes,
83+
_FlexibleCodes as _FlexibleCodes,
84+
_GenericCodes as _GenericCodes,
7485
)
7586
from ._scalars import (
7687
_CharLike_co as _CharLike_co,

0 commit comments

Comments
 (0)