Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypyc/doc/str_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Methods
* ``s.encode(encoding: str)``
* ``s.encode(encoding: str, errors: str)``
* ``s1.endswith(s2: str)``
* ``s1.endswith(t: tuple[str, ...])``
* ``s.join(x: Iterable)``
* ``s.removeprefix(prefix: str)``
* ``s.removesuffix(suffix: str)``
Expand All @@ -43,6 +44,7 @@ Methods
* ``s.split(sep: str)``
* ``s.split(sep: str, maxsplit: int)``
* ``s1.startswith(s2: str)``
* ``s1.startswith(t: tuple[str, ...])``

.. note::

Expand Down
4 changes: 2 additions & 2 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
bool CPyStr_Startswith(PyObject *self, PyObject *subobj);
bool CPyStr_Endswith(PyObject *self, PyObject *subobj);
int CPyStr_Startswith(PyObject *self, PyObject *subobj);
int CPyStr_Endswith(PyObject *self, PyObject *subobj);
PyObject *CPyStr_Removeprefix(PyObject *self, PyObject *prefix);
PyObject *CPyStr_Removesuffix(PyObject *self, PyObject *suffix);
bool CPyStr_IsTrue(PyObject *obj);
Expand Down
40 changes: 38 additions & 2 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,51 @@ PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
return PyUnicode_Replace(str, old_substr, new_substr, temp_max_replace);
}

bool CPyStr_Startswith(PyObject *self, PyObject *subobj) {
// Check whether the string `self` starts with `subobj`.
//
// `subobj` may be a single str or a tuple of str; with a tuple, the result
// is true as soon as any item is a prefix of `self`.
//
// Returns 1 on a match, 0 on no match, and -1 with an exception set on
// error (a tuple item that is not a str, or a failure reported by
// PyUnicode_Tailmatch).
//
// NOTE(review): `self` is passed to PyUnicode_GET_LENGTH without a type
// check, so callers are assumed to guarantee it is a str.
int CPyStr_Startswith(PyObject *self, PyObject *subobj) {
    Py_ssize_t start = 0;
    Py_ssize_t end = PyUnicode_GET_LENGTH(self);
    if (PyTuple_Check(subobj)) {
        Py_ssize_t i;
        Py_ssize_t size = PyTuple_GET_SIZE(subobj);
        for (i = 0; i < size; i++) {
            // Borrowed reference; no DECREF needed.
            PyObject *substring = PyTuple_GET_ITEM(subobj, i);
            if (!PyUnicode_Check(substring)) {
                PyErr_Format(PyExc_TypeError,
                             "tuple for startswith must only contain str, "
                             "not %.100s",
                             Py_TYPE(substring)->tp_name);
                return -1;
            }
            int result = PyUnicode_Tailmatch(self, substring, start, end, -1);
            if (result != 0) {
                // 1 means this item matched; -1 means PyUnicode_Tailmatch
                // failed with an exception set. Pass either value through
                // unchanged so an error is never reported as a match.
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(self, subobj, start, end, -1);
}

bool CPyStr_Endswith(PyObject *self, PyObject *subobj) {
// Check whether the string `self` ends with `subobj`.
//
// `subobj` may be a single str or a tuple of str; with a tuple, the result
// is true as soon as any item is a suffix of `self`.
//
// Returns 1 on a match, 0 on no match, and -1 with an exception set on
// error (a tuple item that is not a str, or a failure reported by
// PyUnicode_Tailmatch).
//
// NOTE(review): `self` is passed to PyUnicode_GET_LENGTH without a type
// check, so callers are assumed to guarantee it is a str.
int CPyStr_Endswith(PyObject *self, PyObject *subobj) {
    Py_ssize_t start = 0;
    Py_ssize_t end = PyUnicode_GET_LENGTH(self);
    if (PyTuple_Check(subobj)) {
        Py_ssize_t i;
        Py_ssize_t size = PyTuple_GET_SIZE(subobj);
        for (i = 0; i < size; i++) {
            // Borrowed reference; no DECREF needed.
            PyObject *substring = PyTuple_GET_ITEM(subobj, i);
            if (!PyUnicode_Check(substring)) {
                PyErr_Format(PyExc_TypeError,
                             "tuple for endswith must only contain str, "
                             "not %.100s",
                             Py_TYPE(substring)->tp_name);
                return -1;
            }
            int result = PyUnicode_Tailmatch(self, substring, start, end, 1);
            if (result != 0) {
                // 1 means this item matched; -1 means PyUnicode_Tailmatch
                // failed with an exception set. Pass either value through
                // unchanged so an error is never reported as a match.
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(self, subobj, start, end, 1);
}

Expand Down
27 changes: 25 additions & 2 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
object_rprimitive,
pointer_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
from mypyc.primitives.registry import (
ERR_NEG_INT,
Expand Down Expand Up @@ -104,20 +105,42 @@
method_op(
name="startswith",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
return_type=c_int_rprimitive,
c_function_name="CPyStr_Startswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

# str.startswith(tuple) (return -1/0/1)
# CPyStr_Startswith returns a C int: 1 when some tuple item is a prefix,
# 0 when none is, and -1 after raising TypeError for a non-str tuple item.
# That -1 case is why this variant uses ERR_NEG_INT rather than ERR_NEVER.
# NOTE(review): truncated_type presumably narrows the C int to a bool for
# callers -- confirm against the primitives registry.
method_op(
name="startswith",
arg_types=[str_rprimitive, tuple_rprimitive],
return_type=c_int_rprimitive,
c_function_name="CPyStr_Startswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEG_INT,
)

# str.endswith(str)
method_op(
name="endswith",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
return_type=c_int_rprimitive,
c_function_name="CPyStr_Endswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

# str.endswith(tuple) (return -1/0/1)
# CPyStr_Endswith returns a C int: 1 when some tuple item is a suffix,
# 0 when none is, and -1 after raising TypeError for a non-str tuple item.
# That -1 case is why this variant uses ERR_NEG_INT rather than ERR_NEVER.
# NOTE(review): truncated_type presumably narrows the C int to a bool for
# callers -- confirm against the primitives registry.
method_op(
name="endswith",
arg_types=[str_rprimitive, tuple_rprimitive],
return_type=c_int_rprimitive,
c_function_name="CPyStr_Endswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEG_INT,
)

# str.removeprefix(str)
method_op(
name="removeprefix",
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def strip (self, item: str) -> str: pass
def join(self, x: Iterable[str]) -> str: pass
def format(self, *args: Any, **kwargs: Any) -> str: ...
def upper(self) -> str: ...
def startswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def endswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def startswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def endswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
def encode(self, encoding: str=..., errors: str=...) -> bytes: ...
def removeprefix(self, prefix: str, /) -> str: ...
Expand Down
8 changes: 7 additions & 1 deletion mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ def eq(x: str) -> int:
return 2
def match(x: str, y: str) -> Tuple[bool, bool]:
return (x.startswith(y), x.endswith(y))
def match_tuple(x: str, y: Tuple[str, ...]) -> Tuple[bool, bool]:
return (x.startswith(y), x.endswith(y))
def remove_prefix_suffix(x: str, y: str) -> Tuple[str, str]:
return (x.removeprefix(y), x.removesuffix(y))

[file driver.py]
from native import f, g, tostr, booltostr, concat, eq, match, remove_prefix_suffix
from native import f, g, tostr, booltostr, concat, eq, match, match_tuple, remove_prefix_suffix
import sys

assert f() == 'some string'
Expand All @@ -45,6 +47,10 @@ assert match('abc', '') == (True, True)
assert match('abc', 'a') == (True, False)
assert match('abc', 'c') == (False, True)
assert match('', 'abc') == (False, False)
assert match_tuple('abc', ('d', 'e')) == (False, False)
assert match_tuple('abc', ('a', 'c')) == (True, True)
assert match_tuple('abc', ('a',)) == (True, False)
assert match_tuple('abc', ('c',)) == (False, True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test case where startswith matches a non-first tuple item. Also add a test for the error case (the tuple contains a non-string item).

It would be good to add an irbuild test for a tuple literal argument, since it's easy to imagine how a fixed-length tuple literal wouldn't match against a variable-length tuple in the primitive arg type.


assert remove_prefix_suffix('', '') == ('', '')
assert remove_prefix_suffix('abc', 'a') == ('bc', 'abc')
Expand Down