Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypyc/doc/str_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Methods
* ``s.encode(encoding: str)``
* ``s.encode(encoding: str, errors: str)``
* ``s1.endswith(s2: str)``
* ``s1.endswith(t: tuple[str, ...])``
* ``s.join(x: Iterable)``
* ``s.removeprefix(prefix: str)``
* ``s.removesuffix(suffix: str)``
Expand All @@ -43,6 +44,7 @@ Methods
* ``s.split(sep: str)``
* ``s.split(sep: str, maxsplit: int)``
* ``s1.startswith(s2: str)``
* ``s1.startswith(t: tuple[str, ...])``

.. note::

Expand Down
4 changes: 2 additions & 2 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
bool CPyStr_Startswith(PyObject *self, PyObject *subobj);
bool CPyStr_Endswith(PyObject *self, PyObject *subobj);
int CPyStr_Startswith(PyObject *self, PyObject *subobj);
int CPyStr_Endswith(PyObject *self, PyObject *subobj);
PyObject *CPyStr_Removeprefix(PyObject *self, PyObject *prefix);
PyObject *CPyStr_Removesuffix(PyObject *self, PyObject *suffix);
bool CPyStr_IsTrue(PyObject *obj);
Expand Down
40 changes: 38 additions & 2 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,51 @@ PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
return PyUnicode_Replace(str, old_substr, new_substr, temp_max_replace);
}

bool CPyStr_Startswith(PyObject *self, PyObject *subobj) {
// Implementation of str.startswith(subobj) for a str or a tuple of str.
//
// Returns 1 if 'self' starts with 'subobj' (or with any item of the tuple),
// 0 if it does not, and -1 with an exception set on error. A TypeError is
// raised if the tuple contains an item that is not a str.
int CPyStr_Startswith(PyObject *self, PyObject *subobj) {
    Py_ssize_t start = 0;
    // NOTE(review): assumes 'self' is always a str object -- confirm that
    // callers guarantee this before the length is read.
    Py_ssize_t end = PyUnicode_GET_LENGTH(self);
    if (PyTuple_Check(subobj)) {
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(subobj); i++) {
            // Borrowed reference; no decref needed.
            PyObject *substring = PyTuple_GET_ITEM(subobj, i);
            if (!PyUnicode_Check(substring)) {
                PyErr_Format(PyExc_TypeError,
                             "tuple for startswith must only contain str, "
                             "not %.100s",
                             Py_TYPE(substring)->tp_name);
                return -1;
            }
            int result = PyUnicode_Tailmatch(self, substring, start, end, -1);
            // PyUnicode_Tailmatch returns -1 on error. Propagate any non-zero
            // result as-is so that an error is not reported as a match while
            // an exception is pending.
            if (result != 0) {
                return result;
            }
        }
        return 0;
    }
    // Direction -1 requests a prefix match.
    return PyUnicode_Tailmatch(self, subobj, start, end, -1);
}

bool CPyStr_Endswith(PyObject *self, PyObject *subobj) {
// Implementation of str.endswith(subobj) for a str or a tuple of str.
//
// Returns 1 if 'self' ends with 'subobj' (or with any item of the tuple),
// 0 if it does not, and -1 with an exception set on error. A TypeError is
// raised if the tuple contains an item that is not a str.
int CPyStr_Endswith(PyObject *self, PyObject *subobj) {
    Py_ssize_t start = 0;
    // NOTE(review): assumes 'self' is always a str object -- confirm that
    // callers guarantee this before the length is read.
    Py_ssize_t end = PyUnicode_GET_LENGTH(self);
    if (PyTuple_Check(subobj)) {
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(subobj); i++) {
            // Borrowed reference; no decref needed.
            PyObject *substring = PyTuple_GET_ITEM(subobj, i);
            if (!PyUnicode_Check(substring)) {
                PyErr_Format(PyExc_TypeError,
                             "tuple for endswith must only contain str, "
                             "not %.100s",
                             Py_TYPE(substring)->tp_name);
                return -1;
            }
            int result = PyUnicode_Tailmatch(self, substring, start, end, 1);
            // PyUnicode_Tailmatch returns -1 on error. Propagate any non-zero
            // result as-is so that an error is not reported as a match while
            // an exception is pending.
            if (result != 0) {
                return result;
            }
        }
        return 0;
    }
    // Direction 1 requests a suffix match.
    return PyUnicode_Tailmatch(self, subobj, start, end, 1);
}

Expand Down
27 changes: 25 additions & 2 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
object_rprimitive,
pointer_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
from mypyc.primitives.registry import (
ERR_NEG_INT,
Expand Down Expand Up @@ -104,20 +105,42 @@
method_op(
name="startswith",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
return_type=c_int_rprimitive,
c_function_name="CPyStr_Startswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

# str.startswith(tuple)
# Registers the C helper CPyStr_Startswith for a tuple argument. The helper's
# result is declared as a C int and truncated to bool for the caller.
# NOTE(review): ERR_NEG_INT presumably makes a negative return value (-1)
# signal a raised exception, with 0/1 as the boolean result -- confirm against
# the primitives registry.
method_op(
name="startswith",
arg_types=[str_rprimitive, tuple_rprimitive],
return_type=c_int_rprimitive,
c_function_name="CPyStr_Startswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEG_INT,
)

# str.endswith(str)
method_op(
name="endswith",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
return_type=c_int_rprimitive,
c_function_name="CPyStr_Endswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

# str.endswith(tuple)
# Registers the C helper CPyStr_Endswith for a tuple argument. The helper's
# result is declared as a C int and truncated to bool for the caller.
# NOTE(review): ERR_NEG_INT presumably makes a negative return value (-1)
# signal a raised exception, with 0/1 as the boolean result -- confirm against
# the primitives registry.
method_op(
name="endswith",
arg_types=[str_rprimitive, tuple_rprimitive],
return_type=c_int_rprimitive,
c_function_name="CPyStr_Endswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEG_INT,
)

# str.removeprefix(str)
method_op(
name="removeprefix",
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def strip (self, item: str) -> str: pass
def join(self, x: Iterable[str]) -> str: pass
def format(self, *args: Any, **kwargs: Any) -> str: ...
def upper(self) -> str: ...
def startswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def endswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def startswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def endswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
def encode(self, encoding: str=..., errors: str=...) -> bytes: ...
def removeprefix(self, prefix: str, /) -> str: ...
Expand Down
32 changes: 32 additions & 0 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,38 @@ L4:
L5:
unreachable

[case testStrStartswithEndswithTuple]
from typing import Tuple

def do_startswith(s1: str, s2: Tuple[str, ...]) -> bool:
return s1.startswith(s2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm more interested in what happens when startswith is given a tuple literal argument (e.g. return s1.startswith(('x', 'y'))), since this is probably the most common use case. Can you also test this? It would also be good to have a run test for this.


def do_endswith(s1: str, s2: Tuple[str, ...]) -> bool:
return s1.endswith(s2)
[out]
def do_startswith(s1, s2):
s1 :: str
s2 :: tuple
r0 :: i32
r1 :: bit
r2 :: bool
L0:
r0 = CPyStr_Startswith(s1, s2)
r1 = r0 >= 0 :: signed
r2 = truncate r0: i32 to builtins.bool
return r2
def do_endswith(s1, s2):
s1 :: str
s2 :: tuple
r0 :: i32
r1 :: bit
r2 :: bool
L0:
r0 = CPyStr_Endswith(s1, s2)
r1 = r0 >= 0 :: signed
r2 = truncate r0: i32 to builtins.bool
return r2

[case testStrToBool]
def is_true(x: str) -> bool:
if x:
Expand Down
15 changes: 14 additions & 1 deletion mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ def eq(x: str) -> int:
return 2
def match(x: str, y: str) -> Tuple[bool, bool]:
return (x.startswith(y), x.endswith(y))
def match_tuple(x: str, y: Tuple[str, ...]) -> Tuple[bool, bool]:
return (x.startswith(y), x.endswith(y))
def remove_prefix_suffix(x: str, y: str) -> Tuple[str, str]:
return (x.removeprefix(y), x.removesuffix(y))

[file driver.py]
from native import f, g, tostr, booltostr, concat, eq, match, remove_prefix_suffix
from native import f, g, tostr, booltostr, concat, eq, match, match_tuple, remove_prefix_suffix
import sys
from testutil import assertRaises

assert f() == 'some string'
assert f() is sys.intern('some string')
Expand All @@ -45,6 +48,16 @@ assert match('abc', '') == (True, True)
assert match('abc', 'a') == (True, False)
assert match('abc', 'c') == (False, True)
assert match('', 'abc') == (False, False)
assert match_tuple('abc', ('d', 'e')) == (False, False)
assert match_tuple('abc', ('a', 'c')) == (True, True)
assert match_tuple('abc', ('a',)) == (True, False)
assert match_tuple('abc', ('c',)) == (False, True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test case where startswith matches a non-first tuple item. Also add a test for the error case (the tuple contains a non-string).

It would be good to add an irbuild test for a tuple literal argument, since it's easy to imagine how a fixed-length tuple literal wouldn't match against a variable-length tuple in the primitive arg type.

assert match_tuple('abc', ('x', 'y', 'z')) == (False, False)
assert match_tuple('abc', ('x', 'y', 'z', 'a', 'c')) == (True, True)
with assertRaises(TypeError, "tuple for startswith must only contain str"):
assert match_tuple('abc', (None,))
with assertRaises(TypeError, "tuple for endswith must only contain str"):
assert match_tuple('abc', ('a', None))

assert remove_prefix_suffix('', '') == ('', '')
assert remove_prefix_suffix('abc', 'a') == ('bc', 'abc')
Expand Down