Skip to content

Commit e1cc10a

Browse files
authored
Merge pull request numpy#27470 from ngoldbaum/stringdtype-typing
TYP: Add type stubs for stringdtype in np.char and np.strings
2 parents 855bed7 + 46844ca commit e1cc10a

File tree

8 files changed

+599
-98
lines changed

8 files changed

+599
-98
lines changed

numpy/_core/defchararray.pyi

Lines changed: 225 additions & 27 deletions
Large diffs are not rendered by default.

numpy/_core/strings.pyi

Lines changed: 194 additions & 47 deletions
Large diffs are not rendered by default.

numpy/_typing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@
137137
_ArrayLikeVoid_co as _ArrayLikeVoid_co,
138138
_ArrayLikeStr_co as _ArrayLikeStr_co,
139139
_ArrayLikeBytes_co as _ArrayLikeBytes_co,
140+
_ArrayLikeString_co as _ArrayLikeString_co,
141+
_ArrayLikeAnyString_co as _ArrayLikeAnyString_co,
140142
_ArrayLikeUnknown as _ArrayLikeUnknown,
141143
_UnknownType as _UnknownType,
142144
)

numpy/_typing/_array_like.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import sys
44
from collections.abc import Collection, Callable, Sequence
5-
from typing import Any, Protocol, TypeAlias, TypeVar, runtime_checkable
5+
from typing import Any, Protocol, TypeAlias, TypeVar, runtime_checkable, TYPE_CHECKING
66

77
import numpy as np
88
from numpy import (
@@ -24,6 +24,13 @@
2424
from ._nested_sequence import _NestedSequence
2525
from ._shape import _Shape
2626

27+
if TYPE_CHECKING:
28+
StringDType = np.dtypes.StringDType
29+
else:
30+
# at runtime outside of type checking importing this from numpy.dtypes
31+
# would lead to a circular import
32+
from numpy._core.multiarray import StringDType
33+
2734
_T = TypeVar("_T")
2835
_ScalarType = TypeVar("_ScalarType", bound=generic)
2936
_ScalarType_co = TypeVar("_ScalarType_co", bound=generic, covariant=True)
@@ -148,6 +155,15 @@ def __array_function__(
148155
dtype[bytes_],
149156
bytes,
150157
]
158+
_ArrayLikeString_co: TypeAlias = _DualArrayLike[
159+
StringDType,
160+
str
161+
]
162+
_ArrayLikeAnyString_co: TypeAlias = (
163+
_ArrayLikeStr_co |
164+
_ArrayLikeBytes_co |
165+
_ArrayLikeString_co
166+
)
151167

152168
# NOTE: This includes `builtins.bool`, but not `numpy.bool`.
153169
_ArrayLikeInt: TypeAlias = _DualArrayLike[

numpy/typing/tests/data/fail/strings.pyi

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@ np.strings.partition(AR_S, "a") # E: incompatible type
3939
np.strings.rpartition(AR_U, b"a") # E: incompatible type
4040
np.strings.rpartition(AR_S, "a") # E: incompatible type
4141

42-
np.strings.split(AR_U, b"_") # E: incompatible type
43-
np.strings.split(AR_S, "_") # E: incompatible type
44-
np.strings.rsplit(AR_U, b"_") # E: incompatible type
45-
np.strings.rsplit(AR_S, "_") # E: incompatible type
46-
4742
np.strings.count(AR_U, b"a", [1, 2, 3], [1, 2, 3]) # E: incompatible type
4843
np.strings.count(AR_S, "a", 0, 9) # E: incompatible type
4944

numpy/typing/tests/data/reveal/char.pyi

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,137 +1,208 @@
1-
from typing import Any
2-
31
import numpy as np
42
import numpy.typing as npt
3+
import numpy._typing as np_t
54

65
from typing_extensions import assert_type
6+
from typing import TypeAlias
77

88
AR_U: npt.NDArray[np.str_]
99
AR_S: npt.NDArray[np.bytes_]
10+
AR_T: np.ndarray[np_t._Shape, np.dtypes.StringDType]
11+
12+
AR_T_alias: TypeAlias = np.ndarray[np_t._Shape, np.dtypes.StringDType]
13+
AR_TU_alias: TypeAlias = AR_T_alias | npt.NDArray[np.str_]
1014

1115
assert_type(np.char.equal(AR_U, AR_U), npt.NDArray[np.bool])
1216
assert_type(np.char.equal(AR_S, AR_S), npt.NDArray[np.bool])
17+
assert_type(np.char.equal(AR_T, AR_T), npt.NDArray[np.bool])
1318

1419
assert_type(np.char.not_equal(AR_U, AR_U), npt.NDArray[np.bool])
1520
assert_type(np.char.not_equal(AR_S, AR_S), npt.NDArray[np.bool])
21+
assert_type(np.char.not_equal(AR_T, AR_T), npt.NDArray[np.bool])
1622

1723
assert_type(np.char.greater_equal(AR_U, AR_U), npt.NDArray[np.bool])
1824
assert_type(np.char.greater_equal(AR_S, AR_S), npt.NDArray[np.bool])
25+
assert_type(np.char.greater_equal(AR_T, AR_T), npt.NDArray[np.bool])
1926

2027
assert_type(np.char.less_equal(AR_U, AR_U), npt.NDArray[np.bool])
2128
assert_type(np.char.less_equal(AR_S, AR_S), npt.NDArray[np.bool])
29+
assert_type(np.char.less_equal(AR_T, AR_T), npt.NDArray[np.bool])
2230

2331
assert_type(np.char.greater(AR_U, AR_U), npt.NDArray[np.bool])
2432
assert_type(np.char.greater(AR_S, AR_S), npt.NDArray[np.bool])
33+
assert_type(np.char.greater(AR_T, AR_T), npt.NDArray[np.bool])
2534

2635
assert_type(np.char.less(AR_U, AR_U), npt.NDArray[np.bool])
2736
assert_type(np.char.less(AR_S, AR_S), npt.NDArray[np.bool])
37+
assert_type(np.char.less(AR_T, AR_T), npt.NDArray[np.bool])
2838

2939
assert_type(np.char.multiply(AR_U, 5), npt.NDArray[np.str_])
3040
assert_type(np.char.multiply(AR_S, [5, 4, 3]), npt.NDArray[np.bytes_])
41+
assert_type(np.char.multiply(AR_T, 5), AR_T_alias)
3142

3243
assert_type(np.char.mod(AR_U, "test"), npt.NDArray[np.str_])
3344
assert_type(np.char.mod(AR_S, "test"), npt.NDArray[np.bytes_])
45+
assert_type(np.char.mod(AR_T, "test"), AR_T_alias)
3446

3547
assert_type(np.char.capitalize(AR_U), npt.NDArray[np.str_])
3648
assert_type(np.char.capitalize(AR_S), npt.NDArray[np.bytes_])
49+
assert_type(np.char.capitalize(AR_T), AR_T_alias)
3750

3851
assert_type(np.char.center(AR_U, 5), npt.NDArray[np.str_])
3952
assert_type(np.char.center(AR_S, [2, 3, 4], b"a"), npt.NDArray[np.bytes_])
53+
assert_type(np.char.center(AR_T, 5), AR_T_alias)
4054

4155
assert_type(np.char.encode(AR_U), npt.NDArray[np.bytes_])
56+
assert_type(np.char.encode(AR_T), npt.NDArray[np.bytes_])
4257
assert_type(np.char.decode(AR_S), npt.NDArray[np.str_])
4358

4459
assert_type(np.char.expandtabs(AR_U), npt.NDArray[np.str_])
4560
assert_type(np.char.expandtabs(AR_S, tabsize=4), npt.NDArray[np.bytes_])
61+
assert_type(np.char.expandtabs(AR_T), AR_T_alias)
4662

4763
assert_type(np.char.join(AR_U, "_"), npt.NDArray[np.str_])
4864
assert_type(np.char.join(AR_S, [b"_", b""]), npt.NDArray[np.bytes_])
65+
assert_type(np.char.join(AR_T, "_"), AR_TU_alias)
4966

5067
assert_type(np.char.ljust(AR_U, 5), npt.NDArray[np.str_])
5168
assert_type(np.char.ljust(AR_S, [4, 3, 1], fillchar=[b"a", b"b", b"c"]), npt.NDArray[np.bytes_])
69+
assert_type(np.char.ljust(AR_T, 5), AR_T_alias)
70+
assert_type(np.char.ljust(AR_T, [4, 2, 1], fillchar=["a", "b", "c"]), AR_TU_alias)
71+
5272
assert_type(np.char.rjust(AR_U, 5), npt.NDArray[np.str_])
5373
assert_type(np.char.rjust(AR_S, [4, 3, 1], fillchar=[b"a", b"b", b"c"]), npt.NDArray[np.bytes_])
74+
assert_type(np.char.rjust(AR_T, 5), AR_T_alias)
75+
assert_type(np.char.rjust(AR_T, [4, 2, 1], fillchar=["a", "b", "c"]), AR_TU_alias)
5476

5577
assert_type(np.char.lstrip(AR_U), npt.NDArray[np.str_])
56-
assert_type(np.char.lstrip(AR_S, chars=b"_"), npt.NDArray[np.bytes_])
78+
assert_type(np.char.lstrip(AR_S, b"_"), npt.NDArray[np.bytes_])
79+
assert_type(np.char.lstrip(AR_T), AR_T_alias)
80+
assert_type(np.char.lstrip(AR_T, "_"), AR_TU_alias)
81+
5782
assert_type(np.char.rstrip(AR_U), npt.NDArray[np.str_])
58-
assert_type(np.char.rstrip(AR_S, chars=b"_"), npt.NDArray[np.bytes_])
83+
assert_type(np.char.rstrip(AR_S, b"_"), npt.NDArray[np.bytes_])
84+
assert_type(np.char.rstrip(AR_T), AR_T_alias)
85+
assert_type(np.char.rstrip(AR_T, "_"), AR_TU_alias)
86+
5987
assert_type(np.char.strip(AR_U), npt.NDArray[np.str_])
60-
assert_type(np.char.strip(AR_S, chars=b"_"), npt.NDArray[np.bytes_])
88+
assert_type(np.char.strip(AR_S, b"_"), npt.NDArray[np.bytes_])
89+
assert_type(np.char.strip(AR_T), AR_T_alias)
90+
assert_type(np.char.strip(AR_T, "_"), AR_TU_alias)
91+
92+
assert_type(np.char.count(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
93+
assert_type(np.char.count(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_])
94+
assert_type(np.char.count(AR_T, AR_T, start=[1, 2, 3]), npt.NDArray[np.int_])
95+
assert_type(np.char.count(AR_T, ["a", "b", "c"], end=9), npt.NDArray[np.int_])
6196

6297
assert_type(np.char.partition(AR_U, "\n"), npt.NDArray[np.str_])
6398
assert_type(np.char.partition(AR_S, [b"a", b"b", b"c"]), npt.NDArray[np.bytes_])
99+
assert_type(np.char.partition(AR_T, "\n"), AR_TU_alias)
100+
64101
assert_type(np.char.rpartition(AR_U, "\n"), npt.NDArray[np.str_])
65102
assert_type(np.char.rpartition(AR_S, [b"a", b"b", b"c"]), npt.NDArray[np.bytes_])
103+
assert_type(np.char.rpartition(AR_T, "\n"), AR_TU_alias)
66104

67105
assert_type(np.char.replace(AR_U, "_", "-"), npt.NDArray[np.str_])
68106
assert_type(np.char.replace(AR_S, [b"_", b""], [b"a", b"b"]), npt.NDArray[np.bytes_])
107+
assert_type(np.char.replace(AR_T, "_", "_"), AR_TU_alias)
69108

70109
assert_type(np.char.split(AR_U, "_"), npt.NDArray[np.object_])
71110
assert_type(np.char.split(AR_S, maxsplit=[1, 2, 3]), npt.NDArray[np.object_])
111+
assert_type(np.char.split(AR_T, "_"), npt.NDArray[np.object_])
112+
72113
assert_type(np.char.rsplit(AR_U, "_"), npt.NDArray[np.object_])
73114
assert_type(np.char.rsplit(AR_S, maxsplit=[1, 2, 3]), npt.NDArray[np.object_])
115+
assert_type(np.char.rsplit(AR_T, "_"), npt.NDArray[np.object_])
74116

75117
assert_type(np.char.splitlines(AR_U), npt.NDArray[np.object_])
76118
assert_type(np.char.splitlines(AR_S, keepends=[True, True, False]), npt.NDArray[np.object_])
119+
assert_type(np.char.splitlines(AR_T), npt.NDArray[np.object_])
120+
121+
assert_type(np.char.lower(AR_U), npt.NDArray[np.str_])
122+
assert_type(np.char.lower(AR_S), npt.NDArray[np.bytes_])
123+
assert_type(np.char.lower(AR_T), AR_T_alias)
124+
125+
assert_type(np.char.upper(AR_U), npt.NDArray[np.str_])
126+
assert_type(np.char.upper(AR_S), npt.NDArray[np.bytes_])
127+
assert_type(np.char.upper(AR_T), AR_T_alias)
77128

78129
assert_type(np.char.swapcase(AR_U), npt.NDArray[np.str_])
79130
assert_type(np.char.swapcase(AR_S), npt.NDArray[np.bytes_])
131+
assert_type(np.char.swapcase(AR_T), AR_T_alias)
80132

81133
assert_type(np.char.title(AR_U), npt.NDArray[np.str_])
82134
assert_type(np.char.title(AR_S), npt.NDArray[np.bytes_])
83-
84-
assert_type(np.char.upper(AR_U), npt.NDArray[np.str_])
85-
assert_type(np.char.upper(AR_S), npt.NDArray[np.bytes_])
135+
assert_type(np.char.title(AR_T), AR_T_alias)
86136

87137
assert_type(np.char.zfill(AR_U, 5), npt.NDArray[np.str_])
88138
assert_type(np.char.zfill(AR_S, [2, 3, 4]), npt.NDArray[np.bytes_])
89-
90-
assert_type(np.char.count(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
91-
assert_type(np.char.count(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_])
139+
assert_type(np.char.zfill(AR_T, 5), AR_T_alias)
92140

93141
assert_type(np.char.endswith(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.bool])
94142
assert_type(np.char.endswith(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.bool])
143+
assert_type(np.char.endswith(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.bool])
144+
95145
assert_type(np.char.startswith(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.bool])
96146
assert_type(np.char.startswith(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.bool])
147+
assert_type(np.char.startswith(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.bool])
97148

98149
assert_type(np.char.find(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
99150
assert_type(np.char.find(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_])
151+
assert_type(np.char.find(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
152+
100153
assert_type(np.char.rfind(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
101154
assert_type(np.char.rfind(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_])
155+
assert_type(np.char.rfind(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
102156

103157
assert_type(np.char.index(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
104158
assert_type(np.char.index(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_])
159+
assert_type(np.char.index(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
160+
105161
assert_type(np.char.rindex(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
106162
assert_type(np.char.rindex(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_])
163+
assert_type(np.char.rindex(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_])
107164

108165
assert_type(np.char.isalpha(AR_U), npt.NDArray[np.bool])
109166
assert_type(np.char.isalpha(AR_S), npt.NDArray[np.bool])
167+
assert_type(np.char.isalpha(AR_T), npt.NDArray[np.bool])
110168

111169
assert_type(np.char.isalnum(AR_U), npt.NDArray[np.bool])
112170
assert_type(np.char.isalnum(AR_S), npt.NDArray[np.bool])
171+
assert_type(np.char.isalnum(AR_T), npt.NDArray[np.bool])
113172

114173
assert_type(np.char.isdecimal(AR_U), npt.NDArray[np.bool])
174+
assert_type(np.char.isdecimal(AR_T), npt.NDArray[np.bool])
115175

116176
assert_type(np.char.isdigit(AR_U), npt.NDArray[np.bool])
117177
assert_type(np.char.isdigit(AR_S), npt.NDArray[np.bool])
178+
assert_type(np.char.isdigit(AR_T), npt.NDArray[np.bool])
118179

119180
assert_type(np.char.islower(AR_U), npt.NDArray[np.bool])
120181
assert_type(np.char.islower(AR_S), npt.NDArray[np.bool])
182+
assert_type(np.char.islower(AR_T), npt.NDArray[np.bool])
121183

122184
assert_type(np.char.isnumeric(AR_U), npt.NDArray[np.bool])
185+
assert_type(np.char.isnumeric(AR_T), npt.NDArray[np.bool])
123186

124187
assert_type(np.char.isspace(AR_U), npt.NDArray[np.bool])
125188
assert_type(np.char.isspace(AR_S), npt.NDArray[np.bool])
189+
assert_type(np.char.isspace(AR_T), npt.NDArray[np.bool])
126190

127191
assert_type(np.char.istitle(AR_U), npt.NDArray[np.bool])
128192
assert_type(np.char.istitle(AR_S), npt.NDArray[np.bool])
193+
assert_type(np.char.istitle(AR_T), npt.NDArray[np.bool])
129194

130195
assert_type(np.char.isupper(AR_U), npt.NDArray[np.bool])
131196
assert_type(np.char.isupper(AR_S), npt.NDArray[np.bool])
197+
assert_type(np.char.isupper(AR_T), npt.NDArray[np.bool])
132198

133199
assert_type(np.char.str_len(AR_U), npt.NDArray[np.int_])
134200
assert_type(np.char.str_len(AR_S), npt.NDArray[np.int_])
201+
assert_type(np.char.str_len(AR_T), npt.NDArray[np.int_])
202+
203+
assert_type(np.char.translate(AR_U, ""), npt.NDArray[np.str_])
204+
assert_type(np.char.translate(AR_S, ""), npt.NDArray[np.bytes_])
205+
assert_type(np.char.translate(AR_T, ""), AR_T_alias)
135206

136207
assert_type(np.char.array(AR_U), np.char.chararray[tuple[int, ...], np.dtype[np.str_]])
137208
assert_type(np.char.array(AR_S, order="K"), np.char.chararray[tuple[int, ...], np.dtype[np.bytes_]])

0 commit comments

Comments
 (0)