Skip to content

Commit 4bb11f5

Browse files
committed
Add efficient primitives for str.strip().
1 parent e93f06c commit 4bb11f5

File tree

6 files changed

+235
-1
lines changed

6 files changed

+235
-1
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,13 +717,27 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
717717

718718
// Str operations
719719

720+
// Macros for strip type. These values are copied from CPython.
721+
#define LEFTSTRIP 0
722+
#define RIGHTSTRIP 1
723+
#define BOTHSTRIP 2
720724

721725
PyObject *CPyStr_Build(Py_ssize_t len, ...);
722726
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
723727
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);
724728
CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction);
725729
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
726730
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
731+
PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep);
732+
static inline PyObject *CPyStr_Strip(PyObject *self, PyObject *sep) {
733+
return _CPyStr_Strip(self, BOTHSTRIP, sep);
734+
}
735+
static inline PyObject *CPyStr_LStrip(PyObject *self, PyObject *sep) {
736+
return _CPyStr_Strip(self, LEFTSTRIP, sep);
737+
}
738+
static inline PyObject *CPyStr_RStrip(PyObject *self, PyObject *sep) {
739+
return _CPyStr_Strip(self, RIGHTSTRIP, sep);
740+
}
727741
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
728742
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
729743
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);

mypyc/lib-rt/str_ops.c

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,58 @@
55
#include <Python.h>
66
#include "CPy.h"
77

8+
// Copied from cpython.git:Objects/unicodeobject.c.
9+
#define BLOOM_MASK unsigned long
10+
#define BLOOM(mask, ch) ((mask & (1UL << ((ch) & (BLOOM_WIDTH - 1)))))
11+
#if LONG_BIT >= 128
12+
#define BLOOM_WIDTH 128
13+
#elif LONG_BIT >= 64
14+
#define BLOOM_WIDTH 64
15+
#elif LONG_BIT >= 32
16+
#define BLOOM_WIDTH 32
17+
#else
18+
#error "LONG_BIT is smaller than 32"
19+
#endif
20+
21+
// Copied from cpython.git:Objects/unicodeobject.c. This is needed for str.strip("...").
22+
static inline BLOOM_MASK
23+
make_bloom_mask(int kind, const void* ptr, Py_ssize_t len)
24+
{
25+
#define BLOOM_UPDATE(TYPE, MASK, PTR, LEN) \
26+
do { \
27+
TYPE *data = (TYPE *)PTR; \
28+
TYPE *end = data + LEN; \
29+
Py_UCS4 ch; \
30+
for (; data != end; data++) { \
31+
ch = *data; \
32+
MASK |= (1UL << (ch & (BLOOM_WIDTH - 1))); \
33+
} \
34+
break; \
35+
} while (0)
36+
37+
/* calculate simple bloom-style bitmask for a given unicode string */
38+
39+
BLOOM_MASK mask;
40+
41+
mask = 0;
42+
switch (kind) {
43+
case PyUnicode_1BYTE_KIND:
44+
BLOOM_UPDATE(Py_UCS1, mask, ptr, len);
45+
break;
46+
case PyUnicode_2BYTE_KIND:
47+
BLOOM_UPDATE(Py_UCS2, mask, ptr, len);
48+
break;
49+
case PyUnicode_4BYTE_KIND:
50+
BLOOM_UPDATE(Py_UCS4, mask, ptr, len);
51+
break;
52+
default:
53+
Py_UNREACHABLE();
54+
}
55+
return mask;
56+
57+
#undef BLOOM_UPDATE
58+
}
59+
860
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
961
if (PyUnicode_READY(str) != -1) {
1062
if (CPyTagged_CheckShort(index)) {
@@ -174,6 +226,116 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split) {
174226
return PyUnicode_RSplit(str, sep, temp_max_split);
175227
}
176228

229+
// Copied from _PyUnicode_XStrip in cpython.git:Objects/unicodeobject.c. Strips the characters contained in sepobj from the left and/or right end of self, as selected by striptype.
230+
static PyObject *_PyStr_XStrip(PyObject *self, int striptype, PyObject *sepobj) {
231+
const void *data;
232+
int kind;
233+
Py_ssize_t i, j, len;
234+
BLOOM_MASK sepmask;
235+
Py_ssize_t seplen;
236+
237+
kind = PyUnicode_KIND(self);
238+
data = PyUnicode_DATA(self);
239+
len = PyUnicode_GET_LENGTH(self);
240+
seplen = PyUnicode_GET_LENGTH(sepobj);
241+
sepmask = make_bloom_mask(PyUnicode_KIND(sepobj),
242+
PyUnicode_DATA(sepobj),
243+
seplen);
244+
245+
i = 0;
246+
if (striptype != RIGHTSTRIP) {
247+
while (i < len) {
248+
Py_UCS4 ch = PyUnicode_READ(kind, data, i);
249+
if (!BLOOM(sepmask, ch))
250+
break;
251+
if (PyUnicode_FindChar(sepobj, ch, 0, seplen, 1) < 0)
252+
break;
253+
i++;
254+
}
255+
}
256+
257+
j = len;
258+
if (striptype != LEFTSTRIP) {
259+
j--;
260+
while (j >= i) {
261+
Py_UCS4 ch = PyUnicode_READ(kind, data, j);
262+
if (!BLOOM(sepmask, ch))
263+
break;
264+
if (PyUnicode_FindChar(sepobj, ch, 0, seplen, 1) < 0)
265+
break;
266+
j--;
267+
}
268+
269+
j++;
270+
}
271+
272+
return PyUnicode_Substring(self, i, j);
273+
}
274+
275+
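// Copied from the do_strip function in cpython.git:Objects/unicodeobject.c. Strips whitespace when sep is NULL or None; otherwise strips the characters in sep via _PyStr_XStrip.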
276+
PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep) {
277+
if (sep == NULL || sep == Py_None) {
278+
Py_ssize_t len, i, j;
279+
280+
len = PyUnicode_GET_LENGTH(self);
281+
282+
if (PyUnicode_IS_ASCII(self)) {
283+
const Py_UCS1 *data = PyUnicode_1BYTE_DATA(self);
284+
285+
i = 0;
286+
if (strip_type != RIGHTSTRIP) {
287+
while (i < len) {
288+
Py_UCS1 ch = data[i];
289+
if (!_Py_ascii_whitespace[ch])
290+
break;
291+
i++;
292+
}
293+
}
294+
295+
j = len;
296+
if (strip_type != LEFTSTRIP) {
297+
j--;
298+
while (j >= i) {
299+
Py_UCS1 ch = data[j];
300+
if (!_Py_ascii_whitespace[ch])
301+
break;
302+
j--;
303+
}
304+
j++;
305+
}
306+
}
307+
else {
308+
int kind = PyUnicode_KIND(self);
309+
const void *data = PyUnicode_DATA(self);
310+
311+
i = 0;
312+
if (strip_type != RIGHTSTRIP) {
313+
while (i < len) {
314+
Py_UCS4 ch = PyUnicode_READ(kind, data, i);
315+
if (!Py_UNICODE_ISSPACE(ch))
316+
break;
317+
i++;
318+
}
319+
}
320+
321+
j = len;
322+
if (strip_type != LEFTSTRIP) {
323+
j--;
324+
while (j >= i) {
325+
Py_UCS4 ch = PyUnicode_READ(kind, data, j);
326+
if (!Py_UNICODE_ISSPACE(ch))
327+
break;
328+
j--;
329+
}
330+
j++;
331+
}
332+
}
333+
334+
return PyUnicode_Substring(self, i, j);
335+
}
336+
return _PyStr_XStrip(self, strip_type, sep);
337+
}
338+
177339
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
178340
PyObject *new_substr, CPyTagged max_replace) {
179341
Py_ssize_t temp_max_replace = CPyTagged_AsSsize_t(max_replace);

mypyc/primitives/str_ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,26 @@
135135
var_arg_type=str_rprimitive,
136136
)
137137

138+
# str.strip, str.lstrip, str.rstrip
139+
# Each prefix maps to its own C wrapper (CPyStr_LStrip, CPyStr_RStrip, CPyStr_Strip), which passes the matching LEFTSTRIP, RIGHTSTRIP or BOTHSTRIP macro defined in CPy.h.
140+
for strip_prefix in ["l", "r", ""]:
141+
method_op(
142+
name=f"{strip_prefix}strip",
143+
arg_types=[str_rprimitive, str_rprimitive],
144+
return_type=str_rprimitive,
145+
c_function_name=f"CPyStr_{strip_prefix.upper()}Strip",
146+
error_kind=ERR_NEVER,
147+
)
148+
method_op(
149+
name=f"{strip_prefix}strip",
150+
arg_types=[str_rprimitive],
151+
return_type=str_rprimitive,
152+
c_function_name=f"CPyStr_{strip_prefix.upper()}Strip",
153+
# This 0 below is implicitly treated as NULL in C.
154+
extra_int_constants=[(0, c_int_rprimitive)],
155+
error_kind=ERR_NEVER,
156+
)
157+
138158
# str.startswith(str)
139159
method_op(
140160
name="startswith",

mypyc/test-data/fixtures/ir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def rfind(self, sub: str, start: Optional[int] = None, end: Optional[int] = None
107107
def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
108108
def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
109109
def splitlines(self, keepends: bool = False) -> List[str]: ...
110-
def strip (self, item: str) -> str: pass
110+
def strip (self, item: Optional[str] = None) -> str: pass
111+
def lstrip(self, item: Optional[str] = None) -> str: pass
112+
def rstrip(self, item: Optional[str] = None) -> str: pass
111113
def join(self, x: Iterable[str]) -> str: pass
112114
def format(self, *args: Any, **kwargs: Any) -> str: ...
113115
def upper(self) -> str: ...

mypyc/test-data/irbuild-str.test

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,3 +481,26 @@ L0:
481481
keep_alive x
482482
r6 = unbox(int, r5)
483483
return r6
484+
485+
[case testStrip]
486+
def do_strip(s: str) -> None:
487+
s.lstrip("x")
488+
s.strip("y")
489+
s.rstrip("z")
490+
s.lstrip()
491+
s.strip()
492+
s.rstrip()
493+
[out]
494+
def do_strip(s):
495+
s, r0, r1, r2, r3, r4, r5, r6, r7, r8 :: str
496+
L0:
497+
r0 = 'x'
498+
r1 = CPyStr_LStrip(s, r0)
499+
r2 = 'y'
500+
r3 = CPyStr_Strip(s, r2)
501+
r4 = 'z'
502+
r5 = CPyStr_RStrip(s, r4)
503+
r6 = CPyStr_LStrip(s, 0)
504+
r7 = CPyStr_Strip(s, 0)
505+
r8 = CPyStr_RStrip(s, 0)
506+
return 1

mypyc/test-data/run-strings.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,3 +774,16 @@ def test_surrogate() -> None:
774774
assert ord(f()) == 0xd800
775775
assert ord("\udfff") == 0xdfff
776776
assert repr("foobar\x00\xab\ud912\U00012345") == r"'foobar\x00«\ud912𒍅'"
777+
778+
[case testStrip]
779+
# Strip variants without args strip whitespace; variants with an argument strip the given characters. Both forms use efficient primitives.
780+
def test_all_strips_default() -> None:
781+
s = " a1\t"
782+
assert s.lstrip() == "a1\t"
783+
assert s.strip() == "a1"
784+
assert s.rstrip() == " a1"
785+
def test_all_strips() -> None:
786+
s = "xxb2yy"
787+
assert s.lstrip("xy") == "b2yy"
788+
assert s.strip("xy") == "b2"
789+
assert s.rstrip("xy") == "xxb2"

0 commit comments

Comments
 (0)