From 3a44d3d4719ad89be95eda7e5e7cfce5df1afa5f Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Sat, 17 Sep 2022 14:48:06 +0100 Subject: [PATCH] [mypyc] Simplify IR generated for "for" loops over strings Add unsafe str get item primitive. The new primitive just calls the primary get item primitive, but we could later provide an optimized primitive if this turns out to be a performance bottleneck. --- mypyc/irbuild/for_helpers.py | 3 +++ mypyc/lib-rt/CPy.h | 1 + mypyc/lib-rt/str_ops.c | 5 ++++ mypyc/primitives/str_ops.py | 9 +++++++ mypyc/test-data/irbuild-tuple.test | 38 ++++++++++-------------------- 5 files changed, 30 insertions(+), 26 deletions(-) diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 358f7cb76ba8..a7ed97ac8eab 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -76,6 +76,7 @@ from mypyc.primitives.misc_ops import stop_async_iteration_op from mypyc.primitives.registry import CFunctionDescription from mypyc.primitives.set_ops import set_add_op +from mypyc.primitives.str_ops import str_get_item_unsafe_op from mypyc.primitives.tuple_ops import tuple_get_item_unsafe_op GenFunc = Callable[[], None] @@ -772,6 +773,8 @@ def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) -> return builder.primitive_op(list_get_item_unsafe_op, [target, index], line) elif is_tuple_rprimitive(target.type): return builder.call_c(tuple_get_item_unsafe_op, [target, index], line) + elif is_str_rprimitive(target.type): + return builder.call_c(str_get_item_unsafe_op, [target, index], line) else: return builder.gen_method_call(target, "__getitem__", [index], None, line) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index dba84d44f363..698e65155da4 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -727,6 +727,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) { char CPyStr_Equal(PyObject *str1, PyObject *str2); PyObject *CPyStr_Build(Py_ssize_t len, ...); 
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index); +PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index); CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction); CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction); PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 5fd376f21cfa..a2d10aacea46 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -117,6 +117,11 @@ PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) { } } +PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index) { + // This is unsafe since we don't check for overflow when doing <<. + return CPyStr_GetItem(str, index << 1); +} + // A simplification of _PyUnicode_JoinArray() from CPython 3.9.6 PyObject *CPyStr_Build(Py_ssize_t len, ...) { Py_ssize_t i; diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 37dbdf21bb5d..e3f0b9dbbc2a 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -95,6 +95,15 @@ error_kind=ERR_MAGIC, ) +# This is unsafe since it assumes that the index is within reasonable bounds. +# In the future this might do no bounds checking at all. 
+str_get_item_unsafe_op = custom_op( + arg_types=[str_rprimitive, c_pyssize_t_rprimitive], + return_type=str_rprimitive, + c_function_name="CPyStr_GetItemUnsafe", + error_kind=ERR_MAGIC, +) + # str[begin:end] str_slice_op = custom_op( arg_types=[str_rprimitive, int_rprimitive, int_rprimitive], diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index c39968fc139e..5c5ec27b1882 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -272,7 +272,7 @@ L4: a = r6 return 1 -[case testTupleBuiltFromStr_64bit] +[case testTupleBuiltFromStr] def f2(val: str) -> str: return val + "f2" @@ -292,10 +292,9 @@ def test(): r2 :: bit r3 :: tuple r4, r5 :: native_int - r6, r7, r8, r9 :: bit - r10, r11, r12 :: int - r13, x, r14 :: str - r15 :: native_int + r6, r7 :: bit + r8, x, r9 :: str + r10 :: native_int a :: tuple L0: r0 = 'abc' @@ -308,30 +307,17 @@ L1: r5 = CPyStr_Size_size_t(source) r6 = r5 >= 0 :: signed r7 = r4 < r5 :: signed - if r7 goto L2 else goto L8 :: bool + if r7 goto L2 else goto L4 :: bool L2: - r8 = r4 <= 4611686018427387903 :: signed - if r8 goto L3 else goto L4 :: bool + r8 = CPyStr_GetItemUnsafe(source, r4) + x = r8 + r9 = f2(x) + CPySequenceTuple_SetItemUnsafe(r3, r4, r9) L3: - r9 = r4 >= -4611686018427387904 :: signed - if r9 goto L5 else goto L4 :: bool -L4: - r10 = CPyTagged_FromInt64(r4) - r11 = r10 - goto L6 -L5: - r12 = r4 << 1 - r11 = r12 -L6: - r13 = CPyStr_GetItem(source, r11) - x = r13 - r14 = f2(x) - CPySequenceTuple_SetItemUnsafe(r3, r4, r14) -L7: - r15 = r4 + 1 - r4 = r15 + r10 = r4 + 1 + r4 = r10 goto L1 -L8: +L4: a = r3 return 1