Skip to content

Commit 1f8c545

Browse files
committed
TYP: Add support for StringDType in numpy.dtype
1 parent 087a4d8 commit 1f8c545

File tree

4 files changed

+34
-8
lines changed

4 files changed

+34
-8
lines changed

numpy/__init__.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ from numpy._typing import (
123123
_BytesCodes,
124124
_VoidCodes,
125125
_ObjectCodes,
126+
_StringCodes,
126127

127128
_UnsignedIntegerCodes,
128129
_SignedIntegerCodes,
@@ -942,6 +943,16 @@ class dtype(Generic[_DTypeScalar_co]):
942943
@overload
943944
def __new__(cls, dtype: _ObjectCodes | type[ct.py_object[Any]], align: builtins.bool = ..., copy: builtins.bool = ..., metadata: dict[builtins.str, Any] = ...) -> dtype[object_]: ...
944945

946+
# `StringDType` requires special treatment because it has no scalar type
947+
@overload
948+
def __new__(
949+
cls,
950+
dtype: dtypes.StringDType | _StringCodes,
951+
align: builtins.bool = ...,
952+
copy: builtins.bool = ...,
953+
metadata: dict[builtins.str, Any] = ...
954+
) -> dtypes.StringDType: ...
955+
945956
# Combined char-codes and ctypes, analogous to the scalar-type hierarchy
946957
@overload
947958
def __new__(

numpy/_typing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
_BytesCodes as _BytesCodes,
7272
_VoidCodes as _VoidCodes,
7373
_ObjectCodes as _ObjectCodes,
74+
_StringCodes as _StringCodes,
7475
_UnsignedIntegerCodes as _UnsignedIntegerCodes,
7576
_SignedIntegerCodes as _SignedIntegerCodes,
7677
_IntegerCodes as _IntegerCodes,

numpy/_typing/_char_codes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@
140140
"m8[as]", "|m8[as]", "=m8[as]", "<m8[as]", ">m8[as]",
141141
]
142142

143+
# NOTE: `StringDType' has no scalar type, and therefore has no name that can
144+
# be passed to the `dtype` constructor
145+
_StringCodes = Literal["T", "|T", "=T", "<T", ">T"]
143146

144147
# NOTE: Nested literals get flattened and de-duplicated at runtime, which isn't
145148
# the case for a `Union` of `Literal`s.
@@ -202,4 +205,6 @@
202205
_DT64Codes,
203206
_TD64Codes,
204207
_ObjectCodes,
208+
# TODO: add `_StringCodes` once it has a scalar type
209+
# _StringCodes,
205210
]

numpy/typing/tests/data/reveal/dtype.pyi

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from fractions import Fraction
66
from typing import Any, Literal, TypeAlias
77

88
import numpy as np
9+
from numpy.dtypes import StringDType
910

1011
if sys.version_info >= (3, 11):
1112
from typing import assert_type
@@ -31,12 +32,13 @@ ct_floating: type[ct.c_float | ct.c_double | ct.c_longdouble]
3132
ct_number: type[ct.c_uint8 | ct.c_float]
3233
ct_generic: type[ct.c_bool | ct.c_char]
3334

34-
cs_integer: Literal['u1', '<i2', 'L']
35-
cs_number: Literal['=L' ,'i', 'c16']
36-
cs_flex: Literal['>V', 'S']
37-
cs_generic: Literal['H', 'U', 'h', '|M8[Y]', '?']
35+
cs_integer: Literal["u1", "<i2", "L"]
36+
cs_number: Literal["=L" ,"i", "c16"]
37+
cs_flex: Literal[">V", "S"]
38+
cs_generic: Literal["H", "U", "h", "|M8[Y]", "?"]
3839

3940
dt_inexact: np.dtype[np.inexact[Any]]
41+
dt_string: StringDType
4042

4143

4244
assert_type(np.dtype(np.float64), np.dtype[np.float64])
@@ -74,10 +76,10 @@ assert_type(np.dtype(Decimal), np.dtype[np.object_])
7476
assert_type(np.dtype(Fraction), np.dtype[np.object_])
7577

7678
# char-codes
77-
assert_type(np.dtype('u1'), np.dtype[np.uint8])
78-
assert_type(np.dtype('l'), np.dtype[np.long])
79-
assert_type(np.dtype('longlong'), np.dtype[np.longlong])
80-
assert_type(np.dtype('>g'), np.dtype[np.longdouble])
79+
assert_type(np.dtype("u1"), np.dtype[np.uint8])
80+
assert_type(np.dtype("l"), np.dtype[np.long])
81+
assert_type(np.dtype("longlong"), np.dtype[np.longlong])
82+
assert_type(np.dtype(">g"), np.dtype[np.longdouble])
8183
assert_type(np.dtype(cs_integer), np.dtype[np.integer[Any]])
8284
assert_type(np.dtype(cs_number), np.dtype[np.number[Any]])
8385
assert_type(np.dtype(cs_flex), np.dtype[np.flexible])
@@ -104,6 +106,13 @@ assert_type(np.dtype("S8"), np.dtype[Any])
104106
# Void
105107
assert_type(np.dtype(("U", 10)), np.dtype[np.void])
106108

109+
# StringDType
110+
assert_type(np.dtype(dt_string), StringDType)
111+
assert_type(np.dtype("T"), StringDType)
112+
assert_type(np.dtype("=T"), StringDType)
113+
assert_type(np.dtype("|T"), StringDType)
114+
115+
107116
# Methods and attributes
108117
assert_type(dtype_U.base, np.dtype[Any])
109118
assert_type(dtype_U.subdtype, None | tuple[np.dtype[Any], tuple[int, ...]])

0 commit comments

Comments
 (0)