Skip to content

Commit 09140b8

Browse files
committed
[mypyc] Add lowered primitive for unsafe list get item op
This inlines the list get item op in loops like `for x in <list>`. I estimated the impact using two microbenchmarks that iterate over `list[int]` objects. One of them was 1.3x faster, while the other was 1.09x faster. Since we now generate detailed IR for the op, instead of using a C primitive function, this also opens up further IR optimization opportunities in the future.
1 parent 3b00002 commit 09140b8

File tree

10 files changed

+77
-36
lines changed

10 files changed

+77
-36
lines changed

mypyc/irbuild/builder.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
IntOp,
8585
LoadStatic,
8686
Op,
87+
PrimitiveDescription,
8788
RaiseStandardError,
8889
Register,
8990
SetAttr,
@@ -381,6 +382,15 @@ def load_module(self, name: str) -> Value:
381382
def call_c(self, desc: CFunctionDescription, args: list[Value], line: int) -> Value:
382383
return self.builder.call_c(desc, args, line)
383384

385+
def primitive_op(
386+
self,
387+
desc: PrimitiveDescription,
388+
args: list[Value],
389+
line: int,
390+
result_type: RType | None = None,
391+
) -> Value:
392+
return self.builder.primitive_op(desc, args, line, result_type)
393+
384394
def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value:
385395
return self.builder.int_op(type, lhs, rhs, op, line)
386396

@@ -739,7 +749,7 @@ def process_sequence_assignment(
739749
item = target.items[i]
740750
index = self.builder.load_int(i)
741751
if is_list_rprimitive(rvalue.type):
742-
item_value = self.call_c(list_get_item_unsafe_op, [rvalue, index], line)
752+
item_value = self.primitive_op(list_get_item_unsafe_op, [rvalue, index], line)
743753
else:
744754
item_value = self.builder.gen_method_call(
745755
rvalue, "__getitem__", [index], item.type, line

mypyc/irbuild/for_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) ->
693693
# since we want to use __getitem__ if we don't have an unsafe version,
694694
# so we just check manually.
695695
if is_list_rprimitive(target.type):
696-
return builder.call_c(list_get_item_unsafe_op, [target, index], line)
696+
return builder.primitive_op(list_get_item_unsafe_op, [target, index], line)
697697
else:
698698
return builder.gen_method_call(target, "__getitem__", [index], None, line)
699699

mypyc/lower/list_ops.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from mypyc.common import PLATFORM_SIZE
4-
from mypyc.ir.ops import GetElementPtr, Integer, IntOp, LoadMem, SetMem, Value
4+
from mypyc.ir.ops import GetElementPtr, IncRef, Integer, IntOp, LoadMem, SetMem, Value
55
from mypyc.ir.rtypes import (
66
PyListObject,
77
c_pyssize_t_rprimitive,
@@ -43,3 +43,24 @@ def buf_init_item(builder: LowLevelIRBuilder, args: list[Value], line: int) -> V
4343
def list_items(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
4444
ob_item_ptr = builder.add(GetElementPtr(args[0], PyListObject, "ob_item", line))
4545
return builder.add(LoadMem(pointer_rprimitive, ob_item_ptr, line))
46+
47+
48+
def list_item_ptr(builder: LowLevelIRBuilder, obj: Value, index: Value, line: int) -> Value:
49+
"""Get a pointer to a list item (index must be valid and non-negative).
50+
51+
Type of index must be c_pyssize_t_rprimitive.
52+
"""
53+
items = list_items(builder, [obj], line)
54+
delta = builder.add(
55+
IntOp(c_pyssize_t_rprimitive, index, Integer(8, c_pyssize_t_rprimitive), IntOp.MUL)
56+
)
57+
return builder.add(IntOp(pointer_rprimitive, items, delta, IntOp.ADD))
58+
59+
60+
@lower_primitive_op("list_get_item_unsafe")
61+
def list_get_item_unsafe(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
62+
index = builder.coerce(args[1], c_pyssize_t_rprimitive, line)
63+
item_ptr = list_item_ptr(builder, args[0], index, line)
64+
value = builder.add(LoadMem(object_rprimitive, item_ptr, line))
65+
builder.add(IncRef(value))
66+
return value

mypyc/primitives/list_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@
134134

135135
# This is unsafe because it assumes that the index is a non-negative short integer
136136
# that is in-bounds for the list.
137-
list_get_item_unsafe_op = custom_op(
137+
list_get_item_unsafe_op = custom_primitive_op(
138+
name="list_get_item_unsafe",
138139
arg_types=[list_rprimitive, short_int_rprimitive],
139140
return_type=object_rprimitive,
140-
c_function_name="CPyList_GetItemUnsafe",
141141
error_kind=ERR_NEVER,
142142
)
143143

mypyc/test-data/irbuild-basic.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,7 +1874,7 @@ L1:
18741874
r9 = int_lt r6, r8
18751875
if r9 goto L2 else goto L8 :: bool
18761876
L2:
1877-
r10 = CPyList_GetItemUnsafe(r1, r6)
1877+
r10 = list_get_item_unsafe r1, r6
18781878
r11 = unbox(int, r10)
18791879
x = r11
18801880
r12 = int_ne x, 4
@@ -1938,7 +1938,7 @@ L1:
19381938
r9 = int_lt r6, r8
19391939
if r9 goto L2 else goto L8 :: bool
19401940
L2:
1941-
r10 = CPyList_GetItemUnsafe(r1, r6)
1941+
r10 = list_get_item_unsafe r1, r6
19421942
r11 = unbox(int, r10)
19431943
x = r11
19441944
r12 = int_ne x, 4
@@ -2000,7 +2000,7 @@ L1:
20002000
r3 = int_lt r0, r2
20012001
if r3 goto L2 else goto L4 :: bool
20022002
L2:
2003-
r4 = CPyList_GetItemUnsafe(l, r0)
2003+
r4 = list_get_item_unsafe l, r0
20042004
r5 = unbox(tuple[int, int, int], r4)
20052005
r6 = r5[0]
20062006
x = r6
@@ -2022,7 +2022,7 @@ L5:
20222022
r15 = int_lt r12, r14
20232023
if r15 goto L6 else goto L8 :: bool
20242024
L6:
2025-
r16 = CPyList_GetItemUnsafe(l, r12)
2025+
r16 = list_get_item_unsafe l, r12
20262026
r17 = unbox(tuple[int, int, int], r16)
20272027
r18 = r17[0]
20282028
x_2 = r18

mypyc/test-data/irbuild-lists.test

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ L1:
341341
r5 = int_lt r2, r4
342342
if r5 goto L2 else goto L4 :: bool
343343
L2:
344-
r6 = CPyList_GetItemUnsafe(source, r2)
344+
r6 = list_get_item_unsafe source, r2
345345
r7 = unbox(int, r6)
346346
x = r7
347347
r8 = CPyTagged_Add(x, 2)
@@ -362,7 +362,7 @@ L5:
362362
r17 = int_lt r14, r16
363363
if r17 goto L6 else goto L8 :: bool
364364
L6:
365-
r18 = CPyList_GetItemUnsafe(source, r14)
365+
r18 = list_get_item_unsafe source, r14
366366
r19 = unbox(int, r18)
367367
x_2 = r19
368368
r20 = CPyTagged_Add(x_2, 2)
@@ -403,7 +403,7 @@ L1:
403403
r3 = int_lt r0, r2
404404
if r3 goto L2 else goto L4 :: bool
405405
L2:
406-
r4 = CPyList_GetItemUnsafe(x, r0)
406+
r4 = list_get_item_unsafe x, r0
407407
r5 = unbox(int, r4)
408408
i = r5
409409
r6 = box(int, i)
@@ -476,7 +476,7 @@ L1:
476476
r3 = int_lt r0, r2
477477
if r3 goto L2 else goto L4 :: bool
478478
L2:
479-
r4 = CPyList_GetItemUnsafe(a, r0)
479+
r4 = list_get_item_unsafe a, r0
480480
r5 = cast(union[str, bytes], r4)
481481
x = r5
482482
L3:
@@ -502,7 +502,7 @@ L1:
502502
r3 = int_lt r0, r2
503503
if r3 goto L2 else goto L4 :: bool
504504
L2:
505-
r4 = CPyList_GetItemUnsafe(a, r0)
505+
r4 = list_get_item_unsafe a, r0
506506
r5 = cast(union[str, None], r4)
507507
x = r5
508508
L3:

mypyc/test-data/irbuild-set.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ L1:
115115
r9 = int_lt r6, r8
116116
if r9 goto L2 else goto L4 :: bool
117117
L2:
118-
r10 = CPyList_GetItemUnsafe(tmp_list, r6)
118+
r10 = list_get_item_unsafe tmp_list, r6
119119
r11 = unbox(int, r10)
120120
x = r11
121121
r12 = f(x)
@@ -361,7 +361,7 @@ L1:
361361
r13 = int_lt r10, r12
362362
if r13 goto L2 else goto L6 :: bool
363363
L2:
364-
r14 = CPyList_GetItemUnsafe(tmp_list, r10)
364+
r14 = list_get_item_unsafe tmp_list, r10
365365
r15 = unbox(int, r14)
366366
z = r15
367367
r16 = int_lt z, 8

mypyc/test-data/irbuild-statements.test

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ L1:
246246
r3 = int_lt r0, r2
247247
if r3 goto L2 else goto L4 :: bool
248248
L2:
249-
r4 = CPyList_GetItemUnsafe(ls, r0)
249+
r4 = list_get_item_unsafe ls, r0
250250
r5 = unbox(int, r4)
251251
x = r5
252252
r6 = CPyTagged_Add(y, x)
@@ -594,8 +594,8 @@ def f(l, t):
594594
L0:
595595
r0 = CPySequence_CheckUnpackCount(l, 2)
596596
r1 = r0 >= 0 :: signed
597-
r2 = CPyList_GetItemUnsafe(l, 0)
598-
r3 = CPyList_GetItemUnsafe(l, 2)
597+
r2 = list_get_item_unsafe l, 0
598+
r3 = list_get_item_unsafe l, 2
599599
x = r2
600600
r4 = unbox(int, r3)
601601
y = r4
@@ -883,7 +883,7 @@ L1:
883883
r4 = int_lt r1, r3
884884
if r4 goto L2 else goto L4 :: bool
885885
L2:
886-
r5 = CPyList_GetItemUnsafe(a, r1)
886+
r5 = list_get_item_unsafe a, r1
887887
r6 = unbox(int, r5)
888888
x = r6
889889
r7 = CPyTagged_Add(i, x)
@@ -965,7 +965,7 @@ L2:
965965
r5 = PyIter_Next(r1)
966966
if is_error(r5) goto L7 else goto L3
967967
L3:
968-
r6 = CPyList_GetItemUnsafe(a, r0)
968+
r6 = list_get_item_unsafe a, r0
969969
r7 = unbox(int, r6)
970970
x = r7
971971
r8 = unbox(bool, r5)
@@ -1019,7 +1019,7 @@ L3:
10191019
L4:
10201020
r8 = unbox(bool, r3)
10211021
x = r8
1022-
r9 = CPyList_GetItemUnsafe(b, r1)
1022+
r9 = list_get_item_unsafe b, r1
10231023
r10 = unbox(int, r9)
10241024
y = r10
10251025
x = 0

mypyc/test-data/irbuild-tuple.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ L1:
265265
r10 = int_lt r7, r9
266266
if r10 goto L2 else goto L4 :: bool
267267
L2:
268-
r11 = CPyList_GetItemUnsafe(source, r7)
268+
r11 = list_get_item_unsafe source, r7
269269
r12 = unbox(int, r11)
270270
x = r12
271271
r13 = f(x)

mypyc/test-data/lowering-int.test

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -346,10 +346,14 @@ def f(l):
346346
r2 :: native_int
347347
r3 :: short_int
348348
r4 :: bit
349-
r5 :: object
350-
r6, x :: int
351-
r7 :: short_int
352-
r8 :: None
349+
r5 :: native_int
350+
r6, r7 :: ptr
351+
r8 :: native_int
352+
r9 :: ptr
353+
r10 :: object
354+
r11, x :: int
355+
r12 :: short_int
356+
r13 :: None
353357
L0:
354358
r0 = 0
355359
L1:
@@ -359,19 +363,25 @@ L1:
359363
r4 = r0 < r3 :: signed
360364
if r4 goto L2 else goto L5 :: bool
361365
L2:
362-
r5 = CPyList_GetItemUnsafe(l, r0)
363-
r6 = unbox(int, r5)
364-
dec_ref r5
365-
if is_error(r6) goto L6 (error at f:4) else goto L3
366+
r5 = r0 >> 1
367+
r6 = get_element_ptr l ob_item :: PyListObject
368+
r7 = load_mem r6 :: ptr*
369+
r8 = r5 * 8
370+
r9 = r7 + r8
371+
r10 = load_mem r9 :: builtins.object*
372+
inc_ref r10
373+
r11 = unbox(int, r10)
374+
dec_ref r10
375+
if is_error(r11) goto L6 (error at f:4) else goto L3
366376
L3:
367-
x = r6
377+
x = r11
368378
dec_ref x :: int
369379
L4:
370-
r7 = r0 + 2
371-
r0 = r7
380+
r12 = r0 + 2
381+
r0 = r12
372382
goto L1
373383
L5:
374384
return 1
375385
L6:
376-
r8 = <error> :: None
377-
return r8
386+
r13 = <error> :: None
387+
return r13

0 commit comments

Comments
 (0)