Skip to content

Commit 3a44d3d

Browse files
committed
[mypyc] Simplify IR generated for "for" loops over strings
Add unsafe list get item primitive. The new primitive just calls the primary get item primitive, but we could later provide an optimized primitive if this turns out to be a performance bottleneck.
1 parent db67888 commit 3a44d3d

File tree

5 files changed

+30
-26
lines changed

5 files changed

+30
-26
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from mypyc.primitives.misc_ops import stop_async_iteration_op
7777
from mypyc.primitives.registry import CFunctionDescription
7878
from mypyc.primitives.set_ops import set_add_op
79+
from mypyc.primitives.str_ops import str_get_item_unsafe_op
7980
from mypyc.primitives.tuple_ops import tuple_get_item_unsafe_op
8081

8182
GenFunc = Callable[[], None]
@@ -772,6 +773,8 @@ def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) ->
772773
return builder.primitive_op(list_get_item_unsafe_op, [target, index], line)
773774
elif is_tuple_rprimitive(target.type):
774775
return builder.call_c(tuple_get_item_unsafe_op, [target, index], line)
776+
elif is_str_rprimitive(target.type):
777+
return builder.call_c(str_get_item_unsafe_op, [target, index], line)
775778
else:
776779
return builder.gen_method_call(target, "__getitem__", [index], None, line)
777780

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) {
727727
char CPyStr_Equal(PyObject *str1, PyObject *str2);
728728
PyObject *CPyStr_Build(Py_ssize_t len, ...);
729729
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
730+
PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index);
730731
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);
731732
CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction);
732733
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);

mypyc/lib-rt/str_ops.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
117117
}
118118
}
119119

120+
PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index) {
121+
// This is unsafe since we don't check for overflow when doing <<.
122+
return CPyStr_GetItem(str, index << 1);
123+
}
124+
120125
// A simplification of _PyUnicode_JoinArray() from CPython 3.9.6
121126
PyObject *CPyStr_Build(Py_ssize_t len, ...) {
122127
Py_ssize_t i;

mypyc/primitives/str_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@
9595
error_kind=ERR_MAGIC,
9696
)
9797

98+
# This is unsafe since it assumes that the index is within reasonable bounds.
99+
# In the future this might do no bounds checking at all.
100+
str_get_item_unsafe_op = custom_op(
101+
arg_types=[str_rprimitive, c_pyssize_t_rprimitive],
102+
return_type=str_rprimitive,
103+
c_function_name="CPyStr_GetItemUnsafe",
104+
error_kind=ERR_MAGIC,
105+
)
106+
98107
# str[begin:end]
99108
str_slice_op = custom_op(
100109
arg_types=[str_rprimitive, int_rprimitive, int_rprimitive],

mypyc/test-data/irbuild-tuple.test

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ L4:
272272
a = r6
273273
return 1
274274

275-
[case testTupleBuiltFromStr_64bit]
275+
[case testTupleBuiltFromStr]
276276
def f2(val: str) -> str:
277277
return val + "f2"
278278

@@ -292,10 +292,9 @@ def test():
292292
r2 :: bit
293293
r3 :: tuple
294294
r4, r5 :: native_int
295-
r6, r7, r8, r9 :: bit
296-
r10, r11, r12 :: int
297-
r13, x, r14 :: str
298-
r15 :: native_int
295+
r6, r7 :: bit
296+
r8, x, r9 :: str
297+
r10 :: native_int
299298
a :: tuple
300299
L0:
301300
r0 = 'abc'
@@ -308,30 +307,17 @@ L1:
308307
r5 = CPyStr_Size_size_t(source)
309308
r6 = r5 >= 0 :: signed
310309
r7 = r4 < r5 :: signed
311-
if r7 goto L2 else goto L8 :: bool
310+
if r7 goto L2 else goto L4 :: bool
312311
L2:
313-
r8 = r4 <= 4611686018427387903 :: signed
314-
if r8 goto L3 else goto L4 :: bool
312+
r8 = CPyStr_GetItemUnsafe(source, r4)
313+
x = r8
314+
r9 = f2(x)
315+
CPySequenceTuple_SetItemUnsafe(r3, r4, r9)
315316
L3:
316-
r9 = r4 >= -4611686018427387904 :: signed
317-
if r9 goto L5 else goto L4 :: bool
318-
L4:
319-
r10 = CPyTagged_FromInt64(r4)
320-
r11 = r10
321-
goto L6
322-
L5:
323-
r12 = r4 << 1
324-
r11 = r12
325-
L6:
326-
r13 = CPyStr_GetItem(source, r11)
327-
x = r13
328-
r14 = f2(x)
329-
CPySequenceTuple_SetItemUnsafe(r3, r4, r14)
330-
L7:
331-
r15 = r4 + 1
332-
r4 = r15
317+
r10 = r4 + 1
318+
r4 = r10
333319
goto L1
334-
L8:
320+
L4:
335321
a = r3
336322
return 1
337323

0 commit comments

Comments
 (0)