Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,13 +717,27 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {

// Str operations

// Macros for strip type. These values are copied from CPython.
#define LEFTSTRIP 0
#define RIGHTSTRIP 1
#define BOTHSTRIP 2

PyObject *CPyStr_Build(Py_ssize_t len, ...);
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);
CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction);
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep);
static inline PyObject *CPyStr_Strip(PyObject *self, PyObject *sep) {
return _CPyStr_Strip(self, BOTHSTRIP, sep);
}
static inline PyObject *CPyStr_LStrip(PyObject *self, PyObject *sep) {
return _CPyStr_Strip(self, LEFTSTRIP, sep);
}
static inline PyObject *CPyStr_RStrip(PyObject *self, PyObject *sep) {
return _CPyStr_Strip(self, RIGHTSTRIP, sep);
}
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
Expand Down
171 changes: 171 additions & 0 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,59 @@
#include <Python.h>
#include "CPy.h"

// Copied from cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e.
#define BLOOM_MASK unsigned long
#define BLOOM(mask, ch) ((mask & (1UL << ((ch) & (BLOOM_WIDTH - 1)))))
#if LONG_BIT >= 128
#define BLOOM_WIDTH 128
#elif LONG_BIT >= 64
#define BLOOM_WIDTH 64
#elif LONG_BIT >= 32
#define BLOOM_WIDTH 32
#else
#error "LONG_BIT is smaller than 32"
#endif

// Copied from cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e.
// This is needed for str.strip("...").
static inline BLOOM_MASK
make_bloom_mask(int kind, const void* ptr, Py_ssize_t len)
{
#define BLOOM_UPDATE(TYPE, MASK, PTR, LEN) \
do { \
TYPE *data = (TYPE *)PTR; \
TYPE *end = data + LEN; \
Py_UCS4 ch; \
for (; data != end; data++) { \
ch = *data; \
MASK |= (1UL << (ch & (BLOOM_WIDTH - 1))); \
} \
break; \
} while (0)

/* calculate simple bloom-style bitmask for a given unicode string */

BLOOM_MASK mask;

mask = 0;
switch (kind) {
case PyUnicode_1BYTE_KIND:
BLOOM_UPDATE(Py_UCS1, mask, ptr, len);
break;
case PyUnicode_2BYTE_KIND:
BLOOM_UPDATE(Py_UCS2, mask, ptr, len);
break;
case PyUnicode_4BYTE_KIND:
BLOOM_UPDATE(Py_UCS4, mask, ptr, len);
break;
default:
Py_UNREACHABLE();
}
return mask;

#undef BLOOM_UPDATE
}

PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
if (PyUnicode_READY(str) != -1) {
if (CPyTagged_CheckShort(index)) {
Expand Down Expand Up @@ -174,6 +227,124 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split) {
return PyUnicode_RSplit(str, sep, temp_max_split);
}

// This function has been copied from _PyUnicode_XStrip in cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e.
static PyObject *_PyStr_XStrip(PyObject *self, int striptype, PyObject *sepobj) {
const void *data;
int kind;
Py_ssize_t i, j, len;
BLOOM_MASK sepmask;
Py_ssize_t seplen;

// This check is needed from Python 3.9 and earlier.
if (PyUnicode_READY(self) == -1 || PyUnicode_READY(sepobj) == -1)
return NULL;

kind = PyUnicode_KIND(self);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to call PyUnicode_READY on self and sepobj (and check the return values) on 3.9 at least, before you can call PyUnicode_KIND etc. You can check the relevant function in Python 3.9 to see how it's used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

data = PyUnicode_DATA(self);
len = PyUnicode_GET_LENGTH(self);
seplen = PyUnicode_GET_LENGTH(sepobj);
sepmask = make_bloom_mask(PyUnicode_KIND(sepobj),
PyUnicode_DATA(sepobj),
seplen);

i = 0;
if (striptype != RIGHTSTRIP) {
while (i < len) {
Py_UCS4 ch = PyUnicode_READ(kind, data, i);
if (!BLOOM(sepmask, ch))
break;
if (PyUnicode_FindChar(sepobj, ch, 0, seplen, 1) < 0)
break;
i++;
}
}

j = len;
if (striptype != LEFTSTRIP) {
j--;
while (j >= i) {
Py_UCS4 ch = PyUnicode_READ(kind, data, j);
if (!BLOOM(sepmask, ch))
break;
if (PyUnicode_FindChar(sepobj, ch, 0, seplen, 1) < 0)
break;
j--;
}

j++;
}

return PyUnicode_Substring(self, i, j);
}

// Copied from do_strip function in cpython.git/Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e.
PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep) {
if (sep == NULL || sep == Py_None) {
Py_ssize_t len, i, j;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above, I think you'll need to call PyUnicode_READY.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


// This check is needed from Python 3.9 and earlier.
if (PyUnicode_READY(self) == -1)
return NULL;

len = PyUnicode_GET_LENGTH(self);

if (PyUnicode_IS_ASCII(self)) {
const Py_UCS1 *data = PyUnicode_1BYTE_DATA(self);

i = 0;
if (strip_type != RIGHTSTRIP) {
while (i < len) {
Py_UCS1 ch = data[i];
if (!_Py_ascii_whitespace[ch])
break;
i++;
}
}

j = len;
if (strip_type != LEFTSTRIP) {
j--;
while (j >= i) {
Py_UCS1 ch = data[j];
if (!_Py_ascii_whitespace[ch])
break;
j--;
}
j++;
}
}
else {
int kind = PyUnicode_KIND(self);
const void *data = PyUnicode_DATA(self);

i = 0;
if (strip_type != RIGHTSTRIP) {
while (i < len) {
Py_UCS4 ch = PyUnicode_READ(kind, data, i);
if (!Py_UNICODE_ISSPACE(ch))
break;
i++;
}
}

j = len;
if (strip_type != LEFTSTRIP) {
j--;
while (j >= i) {
Py_UCS4 ch = PyUnicode_READ(kind, data, j);
if (!Py_UNICODE_ISSPACE(ch))
break;
j--;
}
j++;
}
}

return PyUnicode_Substring(self, i, j);
}
return _PyStr_XStrip(self, strip_type, sep);
}

PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
PyObject *new_substr, CPyTagged max_replace) {
Py_ssize_t temp_max_replace = CPyTagged_AsSsize_t(max_replace);
Expand Down
19 changes: 19 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,25 @@
var_arg_type=str_rprimitive,
)

# str.strip, str.lstrip, str.rstrip
for strip_prefix in ["l", "r", ""]:
method_op(
name=f"{strip_prefix}strip",
arg_types=[str_rprimitive, str_rprimitive],
return_type=str_rprimitive,
c_function_name=f"CPyStr_{strip_prefix.upper()}Strip",
error_kind=ERR_NEVER,
)
method_op(
name=f"{strip_prefix}strip",
arg_types=[str_rprimitive],
return_type=str_rprimitive,
c_function_name=f"CPyStr_{strip_prefix.upper()}Strip",
# This 0 below is implicitly treated as NULL in C.
extra_int_constants=[(0, c_int_rprimitive)],
error_kind=ERR_NEVER,
)

# str.startswith(str)
method_op(
name="startswith",
Expand Down
4 changes: 3 additions & 1 deletion mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def rfind(self, sub: str, start: Optional[int] = None, end: Optional[int] = None
def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
def splitlines(self, keepends: bool = False) -> List[str]: ...
def strip (self, item: str) -> str: pass
def strip (self, item: Optional[str] = None) -> str: pass
def lstrip(self, item: Optional[str] = None) -> str: pass
def rstrip(self, item: Optional[str] = None) -> str: pass
def join(self, x: Iterable[str]) -> str: pass
def format(self, *args: Any, **kwargs: Any) -> str: ...
def upper(self) -> str: ...
Expand Down
23 changes: 23 additions & 0 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,26 @@ L0:
keep_alive x
r6 = unbox(int, r5)
return r6

[case testStrip]
def do_strip(s: str) -> None:
s.lstrip("x")
s.strip("y")
s.rstrip("z")
s.lstrip()
s.strip()
s.rstrip()
[out]
def do_strip(s):
s, r0, r1, r2, r3, r4, r5, r6, r7, r8 :: str
L0:
r0 = 'x'
r1 = CPyStr_LStrip(s, r0)
r2 = 'y'
r3 = CPyStr_Strip(s, r2)
r4 = 'z'
r5 = CPyStr_RStrip(s, r4)
r6 = CPyStr_LStrip(s, 0)
r7 = CPyStr_Strip(s, 0)
r8 = CPyStr_RStrip(s, 0)
return 1
19 changes: 19 additions & 0 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -774,3 +774,22 @@ def test_surrogate() -> None:
assert ord(f()) == 0xd800
assert ord("\udfff") == 0xdfff
assert repr("foobar\x00\xab\ud912\U00012345") == r"'foobar\x00«\ud912𒍅'"

[case testStrip]
def test_all_strips_default() -> None:
s = " a1\t"
assert s.lstrip() == "a1\t"
assert s.strip() == "a1"
assert s.rstrip() == " a1"
def test_all_strips() -> None:
s = "xxb2yy"
assert s.lstrip("xy") == "b2yy"
assert s.strip("xy") == "b2"
assert s.rstrip("xy") == "xxb2"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you test all string kinds and different character code ranges, such as these (and mixing these):

  • Character codes between 128 and 255 (0x80 to 0xff)
  • Character codes between 256 and 65535 (0x100 to 0xffff)
  • Character codes 65536+ (0x10000+)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

def test_unicode_whitespace() -> None:
assert "\u200A\u000D\u2009\u2020\u000Dtt\u0085\u000A".strip() == "\u2020\u000Dtt"
def test_unicode_range() -> None:
assert "\u2029 \U00107581 ".lstrip() == "\U00107581 "
assert "\u2029 \U0010AAAA\U00104444B\u205F ".strip() == "\U0010AAAA\U00104444B"
assert " \u3000\u205F ".strip() == ""
assert "\u2029 \U00102865\u205F ".rstrip() == "\u2029 \U00102865"