From 4bb11f504bffaed722a15d8bd335850f3405ce0a Mon Sep 17 00:00:00 2001 From: Advait Dixit Date: Wed, 26 Feb 2025 14:36:07 -0800 Subject: [PATCH 1/2] Add efficient primitives for str.strip(). --- mypyc/lib-rt/CPy.h | 14 +++ mypyc/lib-rt/str_ops.c | 162 +++++++++++++++++++++++++++++++ mypyc/primitives/str_ops.py | 20 ++++ mypyc/test-data/fixtures/ir.py | 4 +- mypyc/test-data/irbuild-str.test | 23 +++++ mypyc/test-data/run-strings.test | 13 +++ 6 files changed, 235 insertions(+), 1 deletion(-) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 1c8b59855fc7..fda7ff4eb09c 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -717,6 +717,10 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) { // Str operations +// Macros for strip type. These values are copied from CPython. +#define LEFTSTRIP 0 +#define RIGHTSTRIP 1 +#define BOTHSTRIP 2 PyObject *CPyStr_Build(Py_ssize_t len, ...); PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index); @@ -724,6 +728,16 @@ CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int dire CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction); PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split); PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split); +PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep); +static inline PyObject *CPyStr_Strip(PyObject *self, PyObject *sep) { + return _CPyStr_Strip(self, BOTHSTRIP, sep); +} +static inline PyObject *CPyStr_LStrip(PyObject *self, PyObject *sep) { + return _CPyStr_Strip(self, LEFTSTRIP, sep); +} +static inline PyObject *CPyStr_RStrip(PyObject *self, PyObject *sep) { + return _CPyStr_Strip(self, RIGHTSTRIP, sep); +} PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace); PyObject *CPyStr_Append(PyObject *o1, PyObject *o2); PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 5b295f84440b..3749a1600f0a 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -5,6 +5,58 @@ #include #include "CPy.h" +// Copied from cpython.git:Objects/unicodeobject.c. +#define BLOOM_MASK unsigned long +#define BLOOM(mask, ch) ((mask & (1UL << ((ch) & (BLOOM_WIDTH - 1))))) +#if LONG_BIT >= 128 +#define BLOOM_WIDTH 128 +#elif LONG_BIT >= 64 +#define BLOOM_WIDTH 64 +#elif LONG_BIT >= 32 +#define BLOOM_WIDTH 32 +#else +#error "LONG_BIT is smaller than 32" +#endif + +// Copied from cpython.git:Objects/unicodeobject.c. This is needed for str.strip("..."). +static inline BLOOM_MASK +make_bloom_mask(int kind, const void* ptr, Py_ssize_t len) +{ +#define BLOOM_UPDATE(TYPE, MASK, PTR, LEN) \ + do { \ + TYPE *data = (TYPE *)PTR; \ + TYPE *end = data + LEN; \ + Py_UCS4 ch; \ + for (; data != end; data++) { \ + ch = *data; \ + MASK |= (1UL << (ch & (BLOOM_WIDTH - 1))); \ + } \ + break; \ + } while (0) + + /* calculate simple bloom-style bitmask for a given unicode string */ + + BLOOM_MASK mask; + + mask = 0; + switch (kind) { + case PyUnicode_1BYTE_KIND: + BLOOM_UPDATE(Py_UCS1, mask, ptr, len); + break; + case PyUnicode_2BYTE_KIND: + BLOOM_UPDATE(Py_UCS2, mask, ptr, len); + break; + case PyUnicode_4BYTE_KIND: + BLOOM_UPDATE(Py_UCS4, mask, ptr, len); + break; + default: + Py_UNREACHABLE(); + } + return mask; + +#undef BLOOM_UPDATE +} + PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) { if (PyUnicode_READY(str) != -1) { if (CPyTagged_CheckShort(index)) { @@ -174,6 +226,116 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split) { return PyUnicode_RSplit(str, sep, temp_max_split); } +// This function has been copied from _PyUnicode_XStrip in cpython.git:Objects/unicodeobject.c. +static PyObject *_PyStr_XStrip(PyObject *self, int striptype, PyObject *sepobj) { + const void *data; + int kind; + Py_ssize_t i, j, len; + BLOOM_MASK sepmask; + Py_ssize_t seplen; + + kind = PyUnicode_KIND(self); + data = PyUnicode_DATA(self); + len = PyUnicode_GET_LENGTH(self); + seplen = PyUnicode_GET_LENGTH(sepobj); + sepmask = make_bloom_mask(PyUnicode_KIND(sepobj), + PyUnicode_DATA(sepobj), + seplen); + + i = 0; + if (striptype != RIGHTSTRIP) { + while (i < len) { + Py_UCS4 ch = PyUnicode_READ(kind, data, i); + if (!BLOOM(sepmask, ch)) + break; + if (PyUnicode_FindChar(sepobj, ch, 0, seplen, 1) < 0) + break; + i++; + } + } + + j = len; + if (striptype != LEFTSTRIP) { + j--; + while (j >= i) { + Py_UCS4 ch = PyUnicode_READ(kind, data, j); + if (!BLOOM(sepmask, ch)) + break; + if (PyUnicode_FindChar(sepobj, ch, 0, seplen, 1) < 0) + break; + j--; + } + + j++; + } + + return PyUnicode_Substring(self, i, j); +} + +// Copied from do_strip function in cpython.git/Objects/unicodeobject.c. +PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep) { + if (sep == NULL || sep == Py_None) { + Py_ssize_t len, i, j; + + len = PyUnicode_GET_LENGTH(self); + + if (PyUnicode_IS_ASCII(self)) { + const Py_UCS1 *data = PyUnicode_1BYTE_DATA(self); + + i = 0; + if (strip_type != RIGHTSTRIP) { + while (i < len) { + Py_UCS1 ch = data[i]; + if (!_Py_ascii_whitespace[ch]) + break; + i++; + } + } + + j = len; + if (strip_type != LEFTSTRIP) { + j--; + while (j >= i) { + Py_UCS1 ch = data[j]; + if (!_Py_ascii_whitespace[ch]) + break; + j--; + } + j++; + } + } + else { + int kind = PyUnicode_KIND(self); + const void *data = PyUnicode_DATA(self); + + i = 0; + if (strip_type != RIGHTSTRIP) { + while (i < len) { + Py_UCS4 ch = PyUnicode_READ(kind, data, i); + if (!Py_UNICODE_ISSPACE(ch)) + break; + i++; + } + } + + j = len; + if (strip_type != LEFTSTRIP) { + j--; + while (j >= i) { + Py_UCS4 ch = PyUnicode_READ(kind, data, j); + if (!Py_UNICODE_ISSPACE(ch)) + break; + j--; + } + j++; + } + } + + return PyUnicode_Substring(self, i, j); + } + return _PyStr_XStrip(self, strip_type, sep); +} + PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace) { Py_ssize_t temp_max_replace = CPyTagged_AsSsize_t(max_replace); diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index e4c644470ba4..4ea4edef9be1 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -135,6 +135,26 @@ var_arg_type=str_rprimitive, ) +# str.strip, str.lstrip, str.rstrip +# Order of iteration matters. It should correspond with LEFTSTRIP, RIGHTSTRIP and BOTHSTRIP macros defined in CPy.h. +for strip_prefix in ["l", "r", ""]: + method_op( + name=f"{strip_prefix}strip", + arg_types=[str_rprimitive, str_rprimitive], + return_type=str_rprimitive, + c_function_name=f"CPyStr_{strip_prefix.upper()}Strip", + error_kind=ERR_NEVER, + ) + method_op( + name=f"{strip_prefix}strip", + arg_types=[str_rprimitive], + return_type=str_rprimitive, + c_function_name=f"CPyStr_{strip_prefix.upper()}Strip", + # This 0 below is implicitly treated as NULL in C. + extra_int_constants=[(0, c_int_rprimitive)], + error_kind=ERR_NEVER, + ) + # str.startswith(str) method_op( name="startswith", diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 38fecbc20c65..e651e7adc384 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -107,7 +107,9 @@ def rfind(self, sub: str, start: Optional[int] = None, end: Optional[int] = None def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass def splitlines(self, keepends: bool = False) -> List[str]: ... - def strip (self, item: str) -> str: pass + def strip (self, item: Optional[str] = None) -> str: pass + def lstrip(self, item: Optional[str] = None) -> str: pass + def rstrip(self, item: Optional[str] = None) -> str: pass def join(self, x: Iterable[str]) -> str: pass def format(self, *args: Any, **kwargs: Any) -> str: ... def upper(self) -> str: ... diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 352fb6cf72d9..ad495dddcb15 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -481,3 +481,26 @@ L0: keep_alive x r6 = unbox(int, r5) return r6 + +[case testStrip] +def do_strip(s: str) -> None: + s.lstrip("x") + s.strip("y") + s.rstrip("z") + s.lstrip() + s.strip() + s.rstrip() +[out] +def do_strip(s): + s, r0, r1, r2, r3, r4, r5, r6, r7, r8 :: str +L0: + r0 = 'x' + r1 = CPyStr_LStrip(s, r0) + r2 = 'y' + r3 = CPyStr_Strip(s, r2) + r4 = 'z' + r5 = CPyStr_RStrip(s, r4) + r6 = CPyStr_LStrip(s, 0) + r7 = CPyStr_Strip(s, 0) + r8 = CPyStr_RStrip(s, 0) + return 1 diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index ce5c85059aed..cd5def316e24 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -774,3 +774,16 @@ def test_surrogate() -> None: assert ord(f()) == 0xd800 assert ord("\udfff") == 0xdfff assert repr("foobar\x00\xab\ud912\U00012345") == r"'foobar\x00«\ud912𒍅'" + +[case testStrip] +# This is a negative test. strip variants without args does not use efficient primitives. +def test_all_strips_default() -> None: + s = " a1\t" + assert s.lstrip() == "a1\t" + assert s.strip() == "a1" + assert s.rstrip() == " a1" +def test_all_strips() -> None: + s = "xxb2yy" + assert s.lstrip("xy") == "b2yy" + assert s.strip("xy") == "b2" + assert s.rstrip("xy") == "xxb2" From f2ccd8a4681dc82a7b324a6c3505d6ce59662aef Mon Sep 17 00:00:00 2001 From: Advait Dixit Date: Tue, 4 Mar 2025 15:43:55 -0800 Subject: [PATCH 2/2] Addressing review comments. * Fixing code comments. * Adding tests with more unicode chars. * Adding commit ID for code copied from cpython.git. --- mypyc/lib-rt/str_ops.c | 17 +++++++++++++---- mypyc/primitives/str_ops.py | 1 - mypyc/test-data/run-strings.test | 8 +++++++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 3749a1600f0a..130840cf4e08 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -5,7 +5,7 @@ #include #include "CPy.h" -// Copied from cpython.git:Objects/unicodeobject.c. +// Copied from cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e. #define BLOOM_MASK unsigned long #define BLOOM(mask, ch) ((mask & (1UL << ((ch) & (BLOOM_WIDTH - 1))))) #if LONG_BIT >= 128 @@ -18,7 +18,8 @@ #error "LONG_BIT is smaller than 32" #endif -// Copied from cpython.git:Objects/unicodeobject.c. This is needed for str.strip("..."). +// Copied from cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e. +// This is needed for str.strip("..."). static inline BLOOM_MASK make_bloom_mask(int kind, const void* ptr, Py_ssize_t len) { @@ -226,7 +227,7 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split) { return PyUnicode_RSplit(str, sep, temp_max_split); } -// This function has been copied from _PyUnicode_XStrip in cpython.git:Objects/unicodeobject.c. +// This function has been copied from _PyUnicode_XStrip in cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e. static PyObject *_PyStr_XStrip(PyObject *self, int striptype, PyObject *sepobj) { const void *data; int kind; @@ -234,6 +235,10 @@ static PyObject *_PyStr_XStrip(PyObject *self, int striptype, PyObject *sepobj) BLOOM_MASK sepmask; Py_ssize_t seplen; + // This check is needed from Python 3.9 and earlier. + if (PyUnicode_READY(self) == -1 || PyUnicode_READY(sepobj) == -1) + return NULL; + kind = PyUnicode_KIND(self); data = PyUnicode_DATA(self); len = PyUnicode_GET_LENGTH(self); @@ -272,11 +277,15 @@ static PyObject *_PyStr_XStrip(PyObject *self, int striptype, PyObject *sepobj) return PyUnicode_Substring(self, i, j); } -// Copied from do_strip function in cpython.git/Objects/unicodeobject.c. +// Copied from do_strip function in cpython.git/Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e. PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep) { if (sep == NULL || sep == Py_None) { Py_ssize_t len, i, j; + // This check is needed from Python 3.9 and earlier. + if (PyUnicode_READY(self) == -1) + return NULL; + len = PyUnicode_GET_LENGTH(self); if (PyUnicode_IS_ASCII(self)) { diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 4ea4edef9be1..75d47b0f0e7a 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -136,7 +136,6 @@ ) # str.strip, str.lstrip, str.rstrip -# Order of iteration matters. It should correspond with LEFTSTRIP, RIGHTSTRIP and BOTHSTRIP macros defined in CPy.h. for strip_prefix in ["l", "r", ""]: method_op( name=f"{strip_prefix}strip", diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index cd5def316e24..07122c2707ac 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -776,7 +776,6 @@ def test_surrogate() -> None: assert repr("foobar\x00\xab\ud912\U00012345") == r"'foobar\x00«\ud912𒍅'" [case testStrip] -# This is a negative test. strip variants without args does not use efficient primitives. def test_all_strips_default() -> None: s = " a1\t" assert s.lstrip() == "a1\t" @@ -787,3 +786,10 @@ def test_all_strips() -> None: assert s.lstrip("xy") == "b2yy" assert s.strip("xy") == "b2" assert s.rstrip("xy") == "xxb2" +def test_unicode_whitespace() -> None: + assert "\u200A\u000D\u2009\u2020\u000Dtt\u0085\u000A".strip() == "\u2020\u000Dtt" +def test_unicode_range() -> None: + assert "\u2029 \U00107581 ".lstrip() == "\U00107581 " + assert "\u2029 \U0010AAAA\U00104444B\u205F ".strip() == "\U0010AAAA\U00104444B" + assert " \u3000\u205F ".strip() == "" + assert "\u2029 \U00102865\u205F ".rstrip() == "\u2029 \U00102865"