Skip to content

Commit 86bae1a

Browse files
authored
Merge pull request numpy#19543 from BvB93/numerictypes
MAINT: Improve the `np.core.numerictypes` stubs
2 parents 8365e3a + 39b3f3f commit 86bae1a

File tree

3 files changed

+84
-43
lines changed

3 files changed

+84
-43
lines changed

numpy/core/numerictypes.pyi

Lines changed: 66 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import sys
2+
import types
23
from typing import (
3-
TypeVar,
4-
Optional,
54
Type,
65
Union,
76
Tuple,
8-
Sequence,
97
overload,
108
Any,
119
TypeVar,
1210
Dict,
1311
List,
12+
Iterable,
1413
)
1514

1615
from numpy import (
@@ -48,58 +47,94 @@ from numpy.core._type_aliases import (
4847
sctypes as sctypes,
4948
)
5049

51-
from numpy.typing import DTypeLike, ArrayLike
50+
from numpy.typing import DTypeLike, ArrayLike, _SupportsDType
5251

5352
if sys.version_info >= (3, 8):
54-
from typing import Literal, Protocol, TypedDict
53+
from typing import Literal as L, Protocol, TypedDict
5554
else:
56-
from typing_extensions import Literal, Protocol, TypedDict
55+
from typing_extensions import Literal as L, Protocol, TypedDict
5756

5857
_T = TypeVar("_T")
59-
_ScalarType = TypeVar("_ScalarType", bound=generic)
58+
_SCT = TypeVar("_SCT", bound=generic)
59+
60+
# A paramtrizable subset of `npt.DTypeLike`
61+
_DTypeLike = Union[
62+
Type[_SCT],
63+
dtype[_SCT],
64+
_SupportsDType[dtype[_SCT]],
65+
]
6066

6167
class _CastFunc(Protocol):
6268
def __call__(
6369
self, x: ArrayLike, k: DTypeLike = ...
6470
) -> ndarray[Any, dtype[Any]]: ...
6571

6672
class _TypeCodes(TypedDict):
67-
Character: Literal['c']
68-
Integer: Literal['bhilqp']
69-
UnsignedInteger: Literal['BHILQP']
70-
Float: Literal['efdg']
71-
Complex: Literal['FDG']
72-
AllInteger: Literal['bBhHiIlLqQpP']
73-
AllFloat: Literal['efdgFDG']
74-
Datetime: Literal['Mm']
75-
All: Literal['?bhilqpBHILQPefdgFDGSUVOMm']
73+
Character: L['c']
74+
Integer: L['bhilqp']
75+
UnsignedInteger: L['BHILQP']
76+
Float: L['efdg']
77+
Complex: L['FDG']
78+
AllInteger: L['bBhHiIlLqQpP']
79+
AllFloat: L['efdgFDG']
80+
Datetime: L['Mm']
81+
All: L['?bhilqpBHILQPefdgFDGSUVOMm']
7682

7783
class _typedict(Dict[Type[generic], _T]):
7884
def __getitem__(self, key: DTypeLike) -> _T: ...
7985

86+
if sys.version_info >= (3, 10):
87+
_TypeTuple = Union[
88+
Type[Any],
89+
types.Union,
90+
Tuple[Union[Type[Any], types.Union, Tuple[Any, ...]], ...],
91+
]
92+
else:
93+
_TypeTuple = Union[
94+
Type[Any],
95+
Tuple[Union[Type[Any], Tuple[Any, ...]], ...],
96+
]
97+
8098
__all__: List[str]
8199

82-
# TODO: Clean up the annotations for the 7 functions below
100+
@overload
101+
def maximum_sctype(t: _DTypeLike[_SCT]) -> Type[_SCT]: ...
102+
@overload
103+
def maximum_sctype(t: DTypeLike) -> Type[Any]: ...
104+
105+
@overload
106+
def issctype(rep: dtype[Any] | Type[Any]) -> bool: ...
107+
@overload
108+
def issctype(rep: object) -> L[False]: ...
83109

84-
def maximum_sctype(t: DTypeLike) -> dtype: ...
85-
def issctype(rep: object) -> bool: ...
86110
@overload
87-
def obj2sctype(rep: object) -> Optional[generic]: ...
111+
def obj2sctype(rep: _DTypeLike[_SCT], default: None = ...) -> None | Type[_SCT]: ...
88112
@overload
89-
def obj2sctype(rep: object, default: None) -> Optional[generic]: ...
113+
def obj2sctype(rep: _DTypeLike[_SCT], default: _T) -> _T | Type[_SCT]: ...
90114
@overload
91-
def obj2sctype(
92-
rep: object, default: Type[_T]
93-
) -> Union[generic, Type[_T]]: ...
94-
def issubclass_(arg1: object, arg2: Union[object, Tuple[object, ...]]) -> bool: ...
95-
def issubsctype(
96-
arg1: Union[ndarray, DTypeLike], arg2: Union[ndarray, DTypeLike]
97-
) -> bool: ...
115+
def obj2sctype(rep: DTypeLike, default: None = ...) -> None | Type[Any]: ...
116+
@overload
117+
def obj2sctype(rep: DTypeLike, default: _T) -> _T | Type[Any]: ...
118+
@overload
119+
def obj2sctype(rep: object, default: None = ...) -> None: ...
120+
@overload
121+
def obj2sctype(rep: object, default: _T) -> _T: ...
122+
123+
@overload
124+
def issubclass_(arg1: Type[Any], arg2: _TypeTuple) -> bool: ...
125+
@overload
126+
def issubclass_(arg1: object, arg2: object) -> L[False]: ...
127+
128+
def issubsctype(arg1: DTypeLike, arg2: DTypeLike) -> bool: ...
129+
98130
def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: ...
99-
def sctype2char(sctype: object) -> str: ...
131+
132+
def sctype2char(sctype: DTypeLike) -> str: ...
133+
100134
def find_common_type(
101-
array_types: Sequence[DTypeLike], scalar_types: Sequence[DTypeLike]
102-
) -> dtype: ...
135+
array_types: Iterable[DTypeLike],
136+
scalar_types: Iterable[DTypeLike],
137+
) -> dtype[Any]: ...
103138

104139
cast: _typedict[_CastFunc]
105140
nbytes: _typedict[int]

numpy/typing/tests/data/fail/numerictypes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
#
55
# https://github.com/numpy/numpy/issues/16366
66
#
7-
np.maximum_sctype(1) # E: incompatible type "int"
7+
np.maximum_sctype(1) # E: No overload variant
88

9-
np.issubsctype(1, np.int64) # E: incompatible type "int"
9+
np.issubsctype(1, np.int64) # E: incompatible type
1010

11-
np.issubdtype(1, np.int64) # E: incompatible type "int"
11+
np.issubdtype(1, np.int64) # E: incompatible type
1212

13-
np.find_common_type(np.int64, np.int64) # E: incompatible type "Type[signedinteger[Any]]"
13+
np.find_common_type(np.int64, np.int64) # E: incompatible type

numpy/typing/tests/data/reveal/numerictypes.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
import numpy as np
22

3-
reveal_type(np.issctype(np.generic)) # E: bool
4-
reveal_type(np.issctype("foo")) # E: bool
3+
reveal_type(np.maximum_sctype(np.float64)) # E: Type[{float64}]
4+
reveal_type(np.maximum_sctype("f8")) # E: Type[Any]
55

6-
reveal_type(np.obj2sctype("S8")) # E: Union[numpy.generic, None]
7-
reveal_type(np.obj2sctype("S8", default=None)) # E: Union[numpy.generic, None]
8-
reveal_type(
9-
np.obj2sctype("foo", default=int) # E: Union[numpy.generic, Type[builtins.int*]]
10-
)
6+
reveal_type(np.issctype(np.float64)) # E: bool
7+
reveal_type(np.issctype("foo")) # E: Literal[False]
8+
9+
reveal_type(np.obj2sctype(np.float64)) # E: Union[None, Type[{float64}]]
10+
reveal_type(np.obj2sctype(np.float64, default=False)) # E: Union[builtins.bool, Type[{float64}]]
11+
reveal_type(np.obj2sctype("S8")) # E: Union[None, Type[Any]]
12+
reveal_type(np.obj2sctype("S8", default=None)) # E: Union[None, Type[Any]]
13+
reveal_type(np.obj2sctype("foo", default=False)) # E: Union[builtins.bool, Type[Any]]
14+
reveal_type(np.obj2sctype(1)) # E: None
15+
reveal_type(np.obj2sctype(1, default=False)) # E: bool
1116

1217
reveal_type(np.issubclass_(np.float64, float)) # E: bool
1318
reveal_type(np.issubclass_(np.float64, (int, float))) # E: bool
19+
reveal_type(np.issubclass_(1, 1)) # E: Literal[False]
1420

1521
reveal_type(np.sctype2char("S8")) # E: str
1622
reveal_type(np.sctype2char(list)) # E: str
1723

18-
reveal_type(np.find_common_type([np.int64], [np.int64])) # E: numpy.dtype
24+
reveal_type(np.find_common_type([np.int64], [np.int64])) # E: numpy.dtype[Any]
1925

2026
reveal_type(np.cast[int]) # E: _CastFunc
2127
reveal_type(np.cast["i8"]) # E: _CastFunc

0 commit comments

Comments
 (0)