Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mypyc/doc/str_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,18 @@ Methods
* ``s.encode(encoding: str, errors: str)``
* ``s1.endswith(s2: str)``
* ``s1.endswith(t: tuple[str, ...])``
* ``s1.find(s2: str)``
* ``s1.find(s2: str, start: int)``
* ``s1.find(s2: str, start: int, end: int)``
* ``s.join(x: Iterable)``
* ``s.partition(sep: str)``
* ``s.removeprefix(prefix: str)``
* ``s.removesuffix(suffix: str)``
* ``s.replace(old: str, new: str)``
* ``s.replace(old: str, new: str, count: int)``
* ``s1.rfind(s2: str)``
* ``s1.rfind(s2: str, start: int)``
* ``s1.rfind(s2: str, start: int, end: int)``
* ``s.rpartition(sep: str)``
* ``s.rsplit()``
* ``s.rsplit(sep: str)``
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,8 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {

PyObject *CPyStr_Build(Py_ssize_t len, ...);
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);
CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction);
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
Expand Down
23 changes: 23 additions & 0 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,29 @@ PyObject *CPyStr_Build(Py_ssize_t len, ...) {
return res;
}

CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction) {
CPyTagged end = PyUnicode_GET_LENGTH(str) << 1;
return CPyStr_FindWithEnd(str, substr, start, end, direction);
}

CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction) {
Py_ssize_t temp_start = CPyTagged_AsSsize_t(start);
if (temp_start == -1 && PyErr_Occurred()) {
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
return CPY_INT_TAG;
}
Py_ssize_t temp_end = CPyTagged_AsSsize_t(end);
if (temp_end == -1 && PyErr_Occurred()) {
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
return CPY_INT_TAG;
}
Py_ssize_t index = PyUnicode_Find(str, substr, temp_start, temp_end, direction);
if (unlikely(index == -2)) {
return CPY_INT_TAG;
}
return index << 1;
}

PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split) {
Py_ssize_t temp_max_split = CPyTagged_AsSsize_t(max_split);
if (temp_max_split == -1 && PyErr_Occurred()) {
Expand Down
23 changes: 23 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,29 @@
ordering=[1, 0],
)

# str.find(...) and str.rfind(...)
str_find_types: list[RType] = [str_rprimitive, str_rprimitive, int_rprimitive, int_rprimitive]
str_find_functions = ["CPyStr_Find", "CPyStr_Find", "CPyStr_FindWithEnd"]
str_find_constants: list[list[tuple[int, RType]]] = [[(0, c_int_rprimitive)], [], []]
str_rfind_constants: list[list[tuple[int, RType]]] = [[(0, c_int_rprimitive)], [], []]
for i in range(len(str_find_types) - 1):
method_op(
name="find",
arg_types=str_find_types[0 : i + 2],
return_type=int_rprimitive,
c_function_name=str_find_functions[i],
extra_int_constants=str_find_constants[i] + [(1, c_int_rprimitive)],
error_kind=ERR_MAGIC,
)
method_op(
name="rfind",
arg_types=str_find_types[0 : i + 2],
return_type=int_rprimitive,
c_function_name=str_find_functions[i],
extra_int_constants=str_rfind_constants[i] + [(-1, c_int_rprimitive)],
error_kind=ERR_MAGIC,
)

# str.join(obj)
method_op(
name="join",
Expand Down
2 changes: 2 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def __getitem__(self, i: int) -> str: pass
def __getitem__(self, i: slice) -> str: pass
def __contains__(self, item: str) -> bool: pass
def __iter__(self) -> Iterator[str]: ...
def find(self, sub: str, start: Optional[int] = None, end: Optional[int] = None, /) -> int: ...
def rfind(self, sub: str, start: Optional[int] = None, end: Optional[int] = None, /) -> int: ...
def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
def splitlines(self, keepends: bool = False) -> List[str]: ...
Expand Down
34 changes: 34 additions & 0 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,20 @@ def contains(s: str, o: str) -> bool:
def getitem(s: str, index: int) -> str:
return s[index]

def find(s: str, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int:
if start is not None:
if end is not None:
return s.find(substr, start, end)
return s.find(substr, start)
return s.find(substr)

def rfind(s: str, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int:
if start is not None:
if end is not None:
return s.rfind(substr, start, end)
return s.rfind(substr, start)
return s.rfind(substr)

s = "abc"

def test_contains() -> None:
Expand All @@ -170,6 +184,26 @@ def test_getitem() -> None:
with assertRaises(IndexError, "string index out of range"):
getitem(s, -4)

def test_find() -> None:
s = "abcab"
assert find(s, "Hello") == -1
assert find(s, "abc") == 0
assert find(s, "b") == 1
assert find(s, "b", 1) == 1
assert find(s, "b", 1, 2) == 1
assert find(s, "b", 3) == 4
assert find(s, "b", 3, 5) == 4
assert find(s, "b", 3, 4) == -1

assert rfind(s, "Hello") == -1
assert rfind(s, "abc") == 0
assert rfind(s, "b") == 4
assert rfind(s, "b", 1) == 4
assert rfind(s, "b", 1, 2) == 1
assert rfind(s, "b", 3) == 4
assert rfind(s, "b", 3, 5) == 4
assert rfind(s, "b", 3, 4) == -1

def str_to_int(s: str, base: Optional[int] = None) -> int:
if base:
return int(s, base)
Expand Down