Skip to content

Commit 4f8c1e6

Browse files
authored
Merge pull request numpy#26353 from ngoldbaum/fix-text-padding-nulls
BUG: ensure text padding ufuncs handle stringdtype nan-like nulls
2 parents 3d33d5f + 197c915 commit 4f8c1e6

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1625,12 +1625,19 @@ center_ljust_rjust_strided_loop(PyArrayMethod_Context *context,
16251625
Buffer<ENCODING::UTF8> inbuf((char *)s1.buf, s1.size);
16261626
Buffer<ENCODING::UTF8> fill((char *)s2.buf, s2.size);
16271627

1628+
size_t num_codepoints = inbuf.num_codepoints();
1629+
npy_intp width = (npy_intp)*(npy_int64*)in2;
1630+
1631+
if (num_codepoints > (size_t)width) {
1632+
width = num_codepoints;
1633+
}
1634+
16281635
char *buf = NULL;
16291636
npy_intp newsize;
16301637
int overflowed = npy_mul_sizes_with_overflow(
16311638
&(newsize),
16321639
(npy_intp)num_bytes_for_utf8_character((unsigned char *)s2.buf),
1633-
(npy_intp)*(npy_int64*)in2 - inbuf.num_codepoints());
1640+
width - num_codepoints);
16341641
newsize += s1.size;
16351642

16361643
if (overflowed) {
@@ -1752,6 +1759,9 @@ zfill_strided_loop(PyArrayMethod_Context *context,
17521759
Buffer<ENCODING::UTF8> inbuf((char *)is.buf, is.size);
17531760
size_t in_codepoints = inbuf.num_codepoints();
17541761
size_t width = (size_t)*(npy_int64 *)in2;
1762+
if (in_codepoints > width) {
1763+
width = in_codepoints;
1764+
}
17551765
// number of leading one-byte characters plus the size of the
17561766
// original string
17571767
size_t outsize = (width - in_codepoints) + is.size;

numpy/_core/strings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,6 @@ def center(a, width, fillchar=' '):
626626
627627
"""
628628
a = np.asanyarray(a)
629-
width = np.maximum(str_len(a), width)
630629
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
631630

632631
if np.any(str_len(fillchar) != 1):
@@ -636,6 +635,7 @@ def center(a, width, fillchar=' '):
636635
if a.dtype.char == "T":
637636
return _center(a, width, fillchar)
638637

638+
width = np.maximum(str_len(a), width)
639639
out_dtype = f"{a.dtype.char}{width.max()}"
640640
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
641641
out = np.empty_like(a, shape=shape, dtype=out_dtype)
@@ -682,7 +682,6 @@ def ljust(a, width, fillchar=' '):
682682
683683
"""
684684
a = np.asanyarray(a)
685-
width = np.maximum(str_len(a), width)
686685
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
687686

688687
if np.any(str_len(fillchar) != 1):
@@ -692,6 +691,7 @@ def ljust(a, width, fillchar=' '):
692691
if a.dtype.char == "T":
693692
return _ljust(a, width, fillchar)
694693

694+
width = np.maximum(str_len(a), width)
695695
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
696696
out_dtype = f"{a.dtype.char}{width.max()}"
697697
out = np.empty_like(a, shape=shape, dtype=out_dtype)
@@ -738,7 +738,6 @@ def rjust(a, width, fillchar=' '):
738738
739739
"""
740740
a = np.asanyarray(a)
741-
width = np.maximum(str_len(a), width)
742741
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
743742

744743
if np.any(str_len(fillchar) != 1):
@@ -748,6 +747,7 @@ def rjust(a, width, fillchar=' '):
748747
if a.dtype.char == "T":
749748
return _rjust(a, width, fillchar)
750749

750+
width = np.maximum(str_len(a), width)
751751
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
752752
out_dtype = f"{a.dtype.char}{width.max()}"
753753
out = np.empty_like(a, shape=shape, dtype=out_dtype)
@@ -784,11 +784,11 @@ def zfill(a, width):
784784
785785
"""
786786
a = np.asanyarray(a)
787-
width = np.maximum(str_len(a), width)
788787

789788
if a.dtype.char == "T":
790789
return _zfill(a, width)
791790

791+
width = np.maximum(str_len(a), width)
792792
shape = np.broadcast_shapes(a.shape, width.shape)
793793
out_dtype = f"{a.dtype.char}{width.max()}"
794794
out = np.empty_like(a, shape=shape, dtype=out_dtype)

numpy/_core/tests/test_stringdtype.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,11 +1210,15 @@ def test_unary(string_array, unicode_array, function_name):
12101210

12111211
PASSES_THROUGH_NAN_NULLS = [
12121212
"add",
1213+
"center",
1214+
"ljust",
12131215
"multiply",
12141216
"replace",
1217+
"rjust",
12151218
"strip",
12161219
"lstrip",
12171220
"rstrip",
1221+
"zfill",
12181222
]
12191223

12201224
NULLS_ARE_FALSEY = [

0 commit comments

Comments
 (0)