diff --git a/mypyc/doc/str_operations.rst b/mypyc/doc/str_operations.rst index ef109f5bca8a..1419d56b0647 100644 --- a/mypyc/doc/str_operations.rst +++ b/mypyc/doc/str_operations.rst @@ -36,6 +36,9 @@ Methods * ``s.removesuffix(suffix: str)`` * ``s.replace(old: str, new: str)`` * ``s.replace(old: str, new: str, count: int)`` +* ``s.rsplit()`` +* ``s.rsplit(sep: str)`` +* ``s.rsplit(sep: str, maxsplit: int)`` * ``s.split()`` * ``s.split(sep: str)`` * ``s.split(sep: str, maxsplit: int)`` diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 5abe35fb689b..93d79a37aaf8 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -721,6 +721,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) { PyObject *CPyStr_Build(Py_ssize_t len, ...); PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index); PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split); +PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split); PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace); PyObject *CPyStr_Append(PyObject *o1, PyObject *o2); PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 5b02dd33df31..46458f9b57dc 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -142,6 +142,15 @@ PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split) { return PyUnicode_Split(str, sep, temp_max_split); } +PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split) { + Py_ssize_t temp_max_split = CPyTagged_AsSsize_t(max_split); + if (temp_max_split == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG); + return NULL; + } + return PyUnicode_RSplit(str, sep, temp_max_split); +} + PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace) { Py_ssize_t temp_max_replace = CPyTagged_AsSsize_t(max_replace); diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 65c60bff8c6e..fed471cd9a4e 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -136,9 +136,10 @@ error_kind=ERR_NEVER, ) -# str.split(...) +# str.split(...) and str.rsplit(...) str_split_types: list[RType] = [str_rprimitive, str_rprimitive, int_rprimitive] str_split_functions = ["PyUnicode_Split", "PyUnicode_Split", "CPyStr_Split"] +str_rsplit_functions = ["PyUnicode_RSplit", "PyUnicode_RSplit", "CPyStr_RSplit"] str_split_constants: list[list[tuple[int, RType]]] = [ [(0, pointer_rprimitive), (-1, c_int_rprimitive)], [(-1, c_int_rprimitive)], @@ -153,6 +154,14 @@ extra_int_constants=str_split_constants[i], error_kind=ERR_MAGIC, ) + method_op( + name="rsplit", + arg_types=str_split_types[0 : i + 1], + return_type=list_rprimitive, + c_function_name=str_rsplit_functions[i], + extra_int_constants=str_split_constants[i], + error_kind=ERR_MAGIC, + ) # str.replace(old, new) method_op( diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index ffd425aab049..01f189b4f08b 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -102,7 +102,8 @@ def __getitem__(self, i: int) -> str: pass def __getitem__(self, i: slice) -> str: pass def __contains__(self, item: str) -> bool: pass def __iter__(self) -> Iterator[str]: ... - def split(self, sep: Optional[str] = None, max: Optional[int] = None) -> List[str]: pass + def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass + def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass def strip (self, item: str) -> str: pass def join(self, x: Iterable[str]) -> str: pass def format(self, *args: Any, **kwargs: Any) -> str: ... diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 69422cb824d4..3998f6f7dbc4 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -61,6 +61,14 @@ def do_split(s: str, sep: Optional[str] = None, max_split: Optional[int] = None) return s.split(sep) return s.split() +def do_rsplit(s: str, sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]: + if sep is not None: + if max_split is not None: + return s.rsplit(sep, max_split) + else: + return s.rsplit(sep) + return s.rsplit() + ss = "abc abcd abcde abcdef" def test_split() -> None: @@ -72,6 +80,15 @@ def test_split() -> None: assert do_split(ss, " ", 1) == ["abc", "abcd abcde abcdef"] assert do_split(ss, " ", 2) == ["abc", "abcd", "abcde abcdef"] +def test_rsplit() -> None: + assert do_rsplit(ss) == ["abc", "abcd", "abcde", "abcdef"] + assert do_rsplit(ss, " ") == ["abc", "abcd", "abcde", "abcdef"] + assert do_rsplit(ss, "-") == ["abc abcd abcde abcdef"] + assert do_rsplit(ss, " ", -1) == ["abc", "abcd", "abcde", "abcdef"] + assert do_rsplit(ss, " ", 0) == ["abc abcd abcde abcdef"] + assert do_rsplit(ss, " ", 1) == ["abc abcd abcde", "abcdef"] # different to do_split + assert do_rsplit(ss, " ", 2) == ["abc abcd", "abcde", "abcdef"] # different to do_split + def getitem(s: str, index: int) -> str: return s[index]