Skip to content

Commit 8b9d772

Browse files
committed
TYP: fix mypy
1 parent 76396a3 commit 8b9d772

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

numpy/_core/defchararray.pyi

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ from numpy import (
1919
_OrderKACF,
2020
_ShapeType_co,
2121
_SupportsBuffer,
22+
_SupportsArray
2223
)
2324
from numpy._typing import (
2425
NDArray,
@@ -98,6 +99,8 @@ _CharDType_co = TypeVar(
9899
_CharArray: TypeAlias = chararray[tuple[int, ...], dtype[_SCT]]
99100

100101
_StringDTypeArray: TypeAlias = np.ndarray[_Shape, np.dtypes.StringDType]
102+
_StringDTypeSupportsArray: TypeAlias = _SupportsArray[np.dtypes.StringDType]
103+
_StringDTypeOrUnicodeArray: TypeAlias = np.ndarray[_Shape, np.dtype[np.str_] | np.dtypes.StringDType]
101104

102105
class chararray(ndarray[_ShapeType_co, _CharDType_co]):
103106
@overload
@@ -564,14 +567,19 @@ def add(x1: U_co, x2: U_co) -> NDArray[np.str_]: ...
564567
@overload
565568
def add(x1: S_co, x2: S_co) -> NDArray[np.bytes_]: ...
566569
@overload
567-
def add(x1: T_co, x2: T_co) -> _StringDTypeArray: ...
570+
def add(x1: _StringDTypeSupportsArray, x2: _StringDTypeSupportsArray) -> _StringDTypeArray: ...
571+
@overload
572+
def add(x1: T_co, T_co) -> _StringDTypeOrUnicodeArray: ...
568573

569574
@overload
570575
def multiply(a: U_co, i: i_co) -> NDArray[np.str_]: ...
571576
@overload
572577
def multiply(a: S_co, i: i_co) -> NDArray[np.bytes_]: ...
573578
@overload
574-
def multiply(a: T_co, i: i_co) -> _StringDTypeArray: ...
579+
def multiply(a: _StringDTypeSupportsArray, i: i_co) -> _StringDTypeArray: ...
580+
@overload
581+
def multiply(a: T_co, i: i_co) -> _StringDTypeOrUnicodeArray: ...
582+
575583

576584
@overload
577585
def mod(a: U_co, value: Any) -> NDArray[np.str_]: ...

numpy/_typing/_array_like.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
str_,
2222
bytes_,
2323
)
24-
from numpy._core.multiarray import StringDType
2524
from ._nested_sequence import _NestedSequence
2625
from ._shape import _Shape
2726

@@ -150,7 +149,7 @@ def __array_function__(
150149
bytes,
151150
]
152151
_ArrayLikeString_co: TypeAlias = _DualArrayLike[
153-
StringDType,
152+
np.dtypes.StringDType,
154153
str
155154
]
156155
_ArrayLikeAnyString_co: TypeAlias = _ArrayLikeStr_co | _ArrayLikeBytes_co | _ArrayLikeString_co

numpy/typing/tests/data/fail/strings.pyi

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@ np.strings.partition(AR_S, "a") # E: incompatible type
3939
np.strings.rpartition(AR_U, b"a") # E: incompatible type
4040
np.strings.rpartition(AR_S, "a") # E: incompatible type
4141

42-
np.strings.split(AR_U, b"_") # E: incompatible type
43-
np.strings.split(AR_S, "_") # E: incompatible type
44-
np.strings.rsplit(AR_U, b"_") # E: incompatible type
45-
np.strings.rsplit(AR_S, "_") # E: incompatible type
46-
4742
np.strings.count(AR_U, b"a", [1, 2, 3], [1, 2, 3]) # E: incompatible type
4843
np.strings.count(AR_S, "a", 0, 9) # E: incompatible type
4944

0 commit comments

Comments
 (0)