From 3d9bbc6b7034796b59886396f85b800f53531b8b Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 15 Feb 2025 20:06:58 +0100 Subject: [PATCH] [mypyc] Optimize str.find and str.rfind --- mypyc/doc/str_operations.rst | 6 ++++++ mypyc/lib-rt/CPy.h | 2 ++ mypyc/lib-rt/str_ops.c | 23 +++++++++++++++++++++ mypyc/primitives/str_ops.py | 23 +++++++++++++++++++++ mypyc/test-data/fixtures/ir.py | 2 ++ mypyc/test-data/run-strings.test | 34 ++++++++++++++++++++++++++++++++ 6 files changed, 90 insertions(+) diff --git a/mypyc/doc/str_operations.rst b/mypyc/doc/str_operations.rst index a7e9ccc58cd1..5b18c0c927d6 100644 --- a/mypyc/doc/str_operations.rst +++ b/mypyc/doc/str_operations.rst @@ -33,12 +33,18 @@ Methods * ``s.encode(encoding: str, errors: str)`` * ``s1.endswith(s2: str)`` * ``s1.endswith(t: tuple[str, ...])`` +* ``s1.find(s2: str)`` +* ``s1.find(s2: str, start: int)`` +* ``s1.find(s2: str, start: int, end: int)`` * ``s.join(x: Iterable)`` * ``s.partition(sep: str)`` * ``s.removeprefix(prefix: str)`` * ``s.removesuffix(suffix: str)`` * ``s.replace(old: str, new: str)`` * ``s.replace(old: str, new: str, count: int)`` +* ``s1.rfind(s2: str)`` +* ``s1.rfind(s2: str, start: int)`` +* ``s1.rfind(s2: str, start: int, end: int)`` * ``s.rpartition(sep: str)`` * ``s.rsplit()`` * ``s.rsplit(sep: str)`` diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 22ab0f253ed7..1c8b59855fc7 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -720,6 +720,8 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) { PyObject *CPyStr_Build(Py_ssize_t len, ...); PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index); +CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction); +CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction); PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split); PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split); PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 00759166df35..5b295f84440b 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -133,6 +133,29 @@ PyObject *CPyStr_Build(Py_ssize_t len, ...) { return res; } +CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction) { + CPyTagged end = PyUnicode_GET_LENGTH(str) << 1; + return CPyStr_FindWithEnd(str, substr, start, end, direction); +} + +CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction) { + Py_ssize_t temp_start = CPyTagged_AsSsize_t(start); + if (temp_start == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG); + return CPY_INT_TAG; + } + Py_ssize_t temp_end = CPyTagged_AsSsize_t(end); + if (temp_end == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG); + return CPY_INT_TAG; + } + Py_ssize_t index = PyUnicode_Find(str, substr, temp_start, temp_end, direction); + if (unlikely(index == -2)) { + return CPY_INT_TAG; + } + return index << 1; +} + PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split) { Py_ssize_t temp_max_split = CPyTagged_AsSsize_t(max_split); if (temp_max_split == -1 && PyErr_Occurred()) { diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index aef3575d8eb4..e4c644470ba4 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -95,6 +95,29 @@ ordering=[1, 0], ) +# str.find(...) and str.rfind(...) +str_find_types: list[RType] = [str_rprimitive, str_rprimitive, int_rprimitive, int_rprimitive] +str_find_functions = ["CPyStr_Find", "CPyStr_Find", "CPyStr_FindWithEnd"] +str_find_constants: list[list[tuple[int, RType]]] = [[(0, c_int_rprimitive)], [], []] +str_rfind_constants: list[list[tuple[int, RType]]] = [[(0, c_int_rprimitive)], [], []] +for i in range(len(str_find_types) - 1): + method_op( + name="find", + arg_types=str_find_types[0 : i + 2], + return_type=int_rprimitive, + c_function_name=str_find_functions[i], + extra_int_constants=str_find_constants[i] + [(1, c_int_rprimitive)], + error_kind=ERR_MAGIC, + ) + method_op( + name="rfind", + arg_types=str_find_types[0 : i + 2], + return_type=int_rprimitive, + c_function_name=str_find_functions[i], + extra_int_constants=str_rfind_constants[i] + [(-1, c_int_rprimitive)], + error_kind=ERR_MAGIC, + ) + # str.join(obj) method_op( name="join", diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 1c7346791c68..38fecbc20c65 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -102,6 +102,8 @@ def __getitem__(self, i: int) -> str: pass def __getitem__(self, i: slice) -> str: pass def __contains__(self, item: str) -> bool: pass def __iter__(self) -> Iterator[str]: ... + def find(self, sub: str, start: Optional[int] = None, end: Optional[int] = None, /) -> int: ... + def rfind(self, sub: str, start: Optional[int] = None, end: Optional[int] = None, /) -> int: ... def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass def splitlines(self, keepends: bool = False) -> List[str]: ... diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 7eadaeee0707..ce5c85059aed 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -146,6 +146,20 @@ def contains(s: str, o: str) -> bool: def getitem(s: str, index: int) -> str: return s[index] +def find(s: str, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int: + if start is not None: + if end is not None: + return s.find(substr, start, end) + return s.find(substr, start) + return s.find(substr) + +def rfind(s: str, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int: + if start is not None: + if end is not None: + return s.rfind(substr, start, end) + return s.rfind(substr, start) + return s.rfind(substr) + s = "abc" def test_contains() -> None: @@ -170,6 +184,26 @@ def test_getitem() -> None: with assertRaises(IndexError, "string index out of range"): getitem(s, -4) +def test_find() -> None: + s = "abcab" + assert find(s, "Hello") == -1 + assert find(s, "abc") == 0 + assert find(s, "b") == 1 + assert find(s, "b", 1) == 1 + assert find(s, "b", 1, 2) == 1 + assert find(s, "b", 3) == 4 + assert find(s, "b", 3, 5) == 4 + assert find(s, "b", 3, 4) == -1 + + assert rfind(s, "Hello") == -1 + assert rfind(s, "abc") == 0 + assert rfind(s, "b") == 4 + assert rfind(s, "b", 1) == 4 + assert rfind(s, "b", 1, 2) == 1 + assert rfind(s, "b", 3) == 4 + assert rfind(s, "b", 3, 5) == 4 + assert rfind(s, "b", 3, 4) == -1 + def str_to_int(s: str, base: Optional[int] = None) -> int: if base: return int(s, base)