Skip to content

Commit 87a01ae

Browse files
committed
BUG: fix more issues with string ufunc promotion
1 parent 26cdf63 commit 87a01ae

File tree

2 files changed

+107
-51
lines changed

2 files changed

+107
-51
lines changed

numpy/_core/strings.py

Lines changed: 94 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -669,20 +669,29 @@ def center(a, width, fillchar=' '):
669669
array(['a1b2', '1b2a', 'b2a1', '2a1b'], dtype='<U4')
670670
671671
"""
672+
width = np.asanyarray(width)
673+
if not np.issubdtype(width.dtype, np.integer):
674+
raise TypeError(f"unsupported type {width.dtype} for operand 'width'")
675+
672676
a = np.asanyarray(a)
673-
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
677+
fillchar = np.asanyarray(fillchar)
678+
679+
try_out_dt = np.result_type(a, fillchar)
680+
if try_out_dt.char == "T":
681+
a = a.astype(try_out_dt, copy=False)
682+
fillchar = fillchar.astype(try_out_dt, copy=False)
683+
out = None
684+
else:
685+
fillchar = fillchar.astype(a.dtype, copy=False)
686+
width = np.maximum(str_len(a), width)
687+
out_dtype = f"{a.dtype.char}{width.max()}"
688+
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
689+
out = np.empty_like(a, shape=shape, dtype=out_dtype)
674690

675691
if np.any(str_len(fillchar) != 1):
676692
raise TypeError(
677693
"The fill character must be exactly one character long")
678694

679-
if a.dtype.char == "T":
680-
return _center(a, width, fillchar)
681-
682-
width = np.maximum(str_len(a), width)
683-
out_dtype = f"{a.dtype.char}{width.max()}"
684-
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
685-
out = np.empty_like(a, shape=shape, dtype=out_dtype)
686695
return _center(a, width, fillchar, out=out)
687696

688697

@@ -726,20 +735,29 @@ def ljust(a, width, fillchar=' '):
726735
array(['aAaAaA ', ' aA ', 'abBABba '], dtype='<U9')
727736
728737
"""
738+
width = np.asanyarray(width)
739+
if not np.issubdtype(width.dtype, np.integer):
740+
raise TypeError(f"unsupported type {width.dtype} for operand 'width'")
741+
729742
a = np.asanyarray(a)
730-
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
743+
fillchar = np.asanyarray(fillchar)
744+
745+
try_out_dt = np.result_type(a, fillchar)
746+
if try_out_dt.char == "T":
747+
a = a.astype(try_out_dt, copy=False)
748+
fillchar = fillchar.astype(try_out_dt, copy=False)
749+
out = None
750+
else:
751+
fillchar = fillchar.astype(a.dtype, copy=False)
752+
width = np.maximum(str_len(a), width)
753+
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
754+
out_dtype = f"{a.dtype.char}{width.max()}"
755+
out = np.empty_like(a, shape=shape, dtype=out_dtype)
731756

732757
if np.any(str_len(fillchar) != 1):
733758
raise TypeError(
734759
"The fill character must be exactly one character long")
735760

736-
if a.dtype.char == "T":
737-
return _ljust(a, width, fillchar)
738-
739-
width = np.maximum(str_len(a), width)
740-
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
741-
out_dtype = f"{a.dtype.char}{width.max()}"
742-
out = np.empty_like(a, shape=shape, dtype=out_dtype)
743761
return _ljust(a, width, fillchar, out=out)
744762

745763

@@ -783,20 +801,29 @@ def rjust(a, width, fillchar=' '):
783801
array([' aAaAaA', ' aA ', ' abBABba'], dtype='<U9')
784802
785803
"""
804+
width = np.asanyarray(width)
805+
if not np.issubdtype(width.dtype, np.integer):
806+
raise TypeError(f"unsupported type {width.dtype} for operand 'width'")
807+
786808
a = np.asanyarray(a)
787-
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
809+
fillchar = np.asanyarray(fillchar)
810+
811+
try_out_dt = np.result_type(a, fillchar)
812+
if try_out_dt.char == "T":
813+
a = a.astype(try_out_dt, copy=False)
814+
fillchar = fillchar.astype(try_out_dt, copy=False)
815+
out = None
816+
else:
817+
fillchar = fillchar.astype(a.dtype, copy=False)
818+
width = np.maximum(str_len(a), width)
819+
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
820+
out_dtype = f"{a.dtype.char}{width.max()}"
821+
out = np.empty_like(a, shape=shape, dtype=out_dtype)
788822

789823
if np.any(str_len(fillchar) != 1):
790824
raise TypeError(
791825
"The fill character must be exactly one character long")
792826

793-
if a.dtype.char == "T":
794-
return _rjust(a, width, fillchar)
795-
796-
width = np.maximum(str_len(a), width)
797-
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
798-
out_dtype = f"{a.dtype.char}{width.max()}"
799-
out = np.empty_like(a, shape=shape, dtype=out_dtype)
800827
return _rjust(a, width, fillchar, out=out)
801828

802829

@@ -830,6 +857,10 @@ def zfill(a, width):
830857
array(['001', '-01', '+01'], dtype='<U3')
831858
832859
"""
860+
width = np.asanyarray(width)
861+
if not np.issubdtype(width.dtype, np.integer):
862+
raise TypeError(f"unsupported type {width.dtype} for operand 'width'")
863+
833864
a = np.asanyarray(a)
834865

835866
if a.dtype.char == "T":
@@ -1205,22 +1236,33 @@ def replace(a, old, new, count=-1):
12051236
array(['The dwash was fresh', 'Thwas was it'], dtype='<U19')
12061237
12071238
"""
1208-
arr = np.asanyarray(a)
1209-
a_dt = arr.dtype
1210-
old = np.asanyarray(old, dtype=getattr(old, 'dtype', a_dt))
1211-
new = np.asanyarray(new, dtype=getattr(new, 'dtype', a_dt))
12121239
count = np.asanyarray(count)
1240+
if not np.issubdtype(count.dtype, np.integer):
1241+
raise TypeError(f"unsupported type {count.dtype} for operand 'count'")
12131242

1214-
if arr.dtype.char == "T":
1215-
return _replace(arr, old, new, count)
1216-
1217-
max_int64 = np.iinfo(np.int64).max
1218-
counts = _count_ufunc(arr, old, 0, max_int64)
1219-
counts = np.where(count < 0, counts, np.minimum(counts, count))
1220-
1221-
buffersizes = str_len(arr) + counts * (str_len(new) - str_len(old))
1222-
out_dtype = f"{arr.dtype.char}{buffersizes.max()}"
1223-
out = np.empty_like(arr, shape=buffersizes.shape, dtype=out_dtype)
1243+
arr = np.asanyarray(a)
1244+
old_dtype = getattr(old, 'dtype', None)
1245+
old = np.asanyarray(old)
1246+
new_dtype = getattr(new, 'dtype', None)
1247+
new = np.asanyarray(new)
1248+
1249+
try_out_dt = np.result_type(arr, old, new)
1250+
if try_out_dt.char == "T":
1251+
arr = a.astype(try_out_dt, copy=False)
1252+
old = old.astype(try_out_dt, copy=False)
1253+
new = new.astype(try_out_dt, copy=False)
1254+
counts = count
1255+
out = None
1256+
else:
1257+
a_dt = arr.dtype
1258+
old = old.astype(old_dtype if old_dtype else a_dt, copy=False)
1259+
new = new.astype(new_dtype if new_dtype else a_dt, copy=False)
1260+
max_int64 = np.iinfo(np.int64).max
1261+
counts = _count_ufunc(arr, old, 0, max_int64)
1262+
counts = np.where(count < 0, counts, np.minimum(counts, count))
1263+
buffersizes = str_len(arr) + counts * (str_len(new) - str_len(old))
1264+
out_dtype = f"{arr.dtype.char}{buffersizes.max()}"
1265+
out = np.empty_like(arr, shape=buffersizes.shape, dtype=out_dtype)
12241266
return _replace(arr, old, new, counts, out=out)
12251267

12261268

@@ -1429,11 +1471,15 @@ def partition(a, sep):
14291471
14301472
"""
14311473
a = np.asanyarray(a)
1432-
# TODO switch to copy=False when issues around views are fixed
1433-
sep = np.array(sep, dtype=a.dtype, copy=True, subok=True)
1434-
if a.dtype.char == "T":
1474+
sep = np.asanyarray(sep)
1475+
1476+
try_out_dt = np.result_type(a, sep)
1477+
if try_out_dt.char == "T":
1478+
a = a.astype(try_out_dt, copy=False)
1479+
sep = sep.astype(try_out_dt, copy=False)
14351480
return _partition(a, sep)
14361481

1482+
sep = sep.astype(a.dtype, copy=False)
14371483
pos = _find_ufunc(a, sep, 0, MAX)
14381484
a_len = str_len(a)
14391485
sep_len = str_len(sep)
@@ -1495,11 +1541,15 @@ def rpartition(a, sep):
14951541
14961542
"""
14971543
a = np.asanyarray(a)
1498-
# TODO switch to copy=False when issues around views are fixed
1499-
sep = np.array(sep, dtype=a.dtype, copy=True, subok=True)
1500-
if a.dtype.char == "T":
1544+
sep = np.asanyarray(sep)
1545+
1546+
try_out_dt = np.result_type(a, sep)
1547+
if try_out_dt.char == "T":
1548+
a = a.astype(try_out_dt, copy=False)
1549+
sep = sep.astype(try_out_dt, copy=False)
15011550
return _rpartition(a, sep)
15021551

1552+
sep = sep.astype(a.dtype, copy=False)
15031553
pos = _rfind_ufunc(a, sep, 0, MAX)
15041554
a_len = str_len(a)
15051555
sep_len = str_len(sep)

numpy/_core/tests/test_stringdtype.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,13 +1012,16 @@ def test_findlike_promoters():
10121012

10131013

10141014
def test_strip_promoter():
1015-
arg = "Hello!!!!"
1015+
arg = ["Hello!!!!", "Hello??!!"]
10161016
strip_char = "!"
1017-
answer = "Hello"
1017+
answer = ["Hello", "Hello??"]
10181018
for dtypes in [("T", "U"), ("U", "T")]:
1019-
assert answer == np.strings.strip(
1020-
np.array(arg, dtype=dtypes[0]), np.array(strip_char, dtype=dtypes[1])
1019+
result = np.strings.strip(
1020+
np.array(arg, dtype=dtypes[0]),
1021+
np.array(strip_char, dtype=dtypes[1])
10211022
)
1023+
assert_array_equal(result, answer)
1024+
assert result.dtype.char == "T"
10221025

10231026

10241027
def test_replace_promoter():
@@ -1035,15 +1038,18 @@ def test_replace_promoter():
10351038
np.array(new, dtype=dtypes[2]),
10361039
)
10371040
assert_array_equal(answer_arr, answer)
1041+
assert answer_arr.dtype.char == "T"
10381042

10391043

10401044
def test_center_promoter():
1041-
arg = "Hello, planet!"
1045+
arg = ["Hello", "planet!"]
10421046
fillchar = "/"
10431047
for dtypes in [("T", "U"), ("U", "T")]:
1044-
assert "/Hello, planet!/" == np.strings.center(
1045-
np.array(arg, dtype=dtypes[0]), 16, np.array(fillchar, dtype=dtypes[1])
1048+
answer = np.strings.center(
1049+
np.array(arg, dtype=dtypes[0]), 9, np.array(fillchar, dtype=dtypes[1])
10461050
)
1051+
assert_array_equal(answer, ["//Hello//", "/planet!/"])
1052+
assert answer.dtype.char == "T"
10471053

10481054

10491055
DATETIME_INPUT = [

0 commit comments

Comments
 (0)