diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 358f7cb76ba8..a7ed97ac8eab 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -76,6 +76,7 @@ from mypyc.primitives.misc_ops import stop_async_iteration_op from mypyc.primitives.registry import CFunctionDescription from mypyc.primitives.set_ops import set_add_op +from mypyc.primitives.str_ops import str_get_item_unsafe_op from mypyc.primitives.tuple_ops import tuple_get_item_unsafe_op GenFunc = Callable[[], None] @@ -772,6 +773,8 @@ def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) -> return builder.primitive_op(list_get_item_unsafe_op, [target, index], line) elif is_tuple_rprimitive(target.type): return builder.call_c(tuple_get_item_unsafe_op, [target, index], line) + elif is_str_rprimitive(target.type): + return builder.call_c(str_get_item_unsafe_op, [target, index], line) else: return builder.gen_method_call(target, "__getitem__", [index], None, line) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index dba84d44f363..698e65155da4 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -727,6 +727,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) { char CPyStr_Equal(PyObject *str1, PyObject *str2); PyObject *CPyStr_Build(Py_ssize_t len, ...); PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index); +PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index); CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction); CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction); PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 5fd376f21cfa..a2d10aacea46 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -117,6 +117,11 @@ PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) { } } +PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index) { + // This is unsafe since we don't check for overflow when doing <<. + return CPyStr_GetItem(str, index << 1); +} + // A simplification of _PyUnicode_JoinArray() from CPython 3.9.6 PyObject *CPyStr_Build(Py_ssize_t len, ...) { Py_ssize_t i; diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 37dbdf21bb5d..e3f0b9dbbc2a 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -95,6 +95,15 @@ error_kind=ERR_MAGIC, ) +# This is unsafe since it assumes that the index is within reasonable bounds. +# In the future this might do no bounds checking at all. +str_get_item_unsafe_op = custom_op( + arg_types=[str_rprimitive, c_pyssize_t_rprimitive], + return_type=str_rprimitive, + c_function_name="CPyStr_GetItemUnsafe", + error_kind=ERR_MAGIC, +) + # str[begin:end] str_slice_op = custom_op( arg_types=[str_rprimitive, int_rprimitive, int_rprimitive], diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index c39968fc139e..5c5ec27b1882 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -272,7 +272,7 @@ L4: a = r6 return 1 -[case testTupleBuiltFromStr_64bit] +[case testTupleBuiltFromStr] def f2(val: str) -> str: return val + "f2" @@ -292,10 +292,9 @@ def test(): r2 :: bit r3 :: tuple r4, r5 :: native_int - r6, r7, r8, r9 :: bit - r10, r11, r12 :: int - r13, x, r14 :: str - r15 :: native_int + r6, r7 :: bit + r8, x, r9 :: str + r10 :: native_int a :: tuple L0: r0 = 'abc' @@ -308,30 +307,17 @@ L1: r5 = CPyStr_Size_size_t(source) r6 = r5 >= 0 :: signed r7 = r4 < r5 :: signed - if r7 goto L2 else goto L8 :: bool + if r7 goto L2 else goto L4 :: bool L2: - r8 = r4 <= 4611686018427387903 :: signed - if r8 goto L3 else goto L4 :: bool + r8 = CPyStr_GetItemUnsafe(source, r4) + x = r8 + r9 = f2(x) + CPySequenceTuple_SetItemUnsafe(r3, r4, r9) L3: - r9 = r4 >= -4611686018427387904 :: signed - if r9 goto L5 else goto L4 :: bool -L4: - r10 = CPyTagged_FromInt64(r4) - r11 = r10 - goto L6 -L5: - r12 = r4 << 1 - r11 = r12 -L6: - r13 = CPyStr_GetItem(source, r11) - x = r13 - r14 = f2(x) - CPySequenceTuple_SetItemUnsafe(r3, r4, r14) -L7: - r15 = r4 + 1 - r4 = r15 + r10 = r4 + 1 + r4 = r10 goto L1 -L8: +L4: a = r3 return 1