Skip to content

Commit e007010

Browse files
committed
feat: len only once
1 parent 298e5db commit e007010

File tree

5 files changed

+185
-142
lines changed

5 files changed

+185
-142
lines changed

mypyc/ir/rtypes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,15 @@ def is_sequence_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
635635
)
636636

637637

638+
def is_immutable_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
    """Check whether rtype is one of the immutable builtin primitives.

    The primitives recognized here are str, bytes, tuple and frozenset.
    """
    immutable_checks = (
        is_str_rprimitive,
        is_bytes_rprimitive,
        is_tuple_rprimitive,
        is_frozenset_rprimitive,
    )
    # Same order and short-circuiting as a chained "or" of the predicates.
    return any(matches(rtype) for matches in immutable_checks)
645+
646+
638647
class TupleNameVisitor(RTypeVisitor[str]):
639648
"""Produce a tuple name based on the concrete representations of types."""
640649

mypyc/irbuild/for_helpers.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,19 @@
1111

1212
from mypy.nodes import (
1313
ARG_POS,
14+
BytesExpr,
1415
CallExpr,
1516
DictionaryComprehension,
1617
Expression,
1718
GeneratorExpr,
19+
ListExpr,
1820
Lvalue,
1921
MemberExpr,
2022
NameExpr,
23+
Optional,
2124
RefExpr,
2225
SetExpr,
26+
StrExpr,
2327
TupleExpr,
2428
TypeAlias,
2529
)
@@ -46,9 +50,9 @@
4650
bool_rprimitive,
4751
c_pyssize_t_rprimitive,
4852
int_rprimitive,
49-
is_bytes_rprimitive,
5053
is_dict_rprimitive,
5154
is_fixed_width_rtype,
55+
is_immutable_rprimitive,
5256
is_list_rprimitive,
5357
is_sequence_rprimitive,
5458
is_short_int_rprimitive,
@@ -152,6 +156,7 @@ def for_loop_helper_with_index(
152156
expr_reg: Value,
153157
body_insts: Callable[[Value], None],
154158
line: int,
159+
length: Value,
155160
) -> None:
156161
"""Generate IR for a sequence iteration.
157162
@@ -172,7 +177,7 @@ def for_loop_helper_with_index(
172177
exit_block = BasicBlock()
173178
condition_block = BasicBlock()
174179

175-
for_gen = ForSequence(builder, index, body_block, exit_block, line, False)
180+
for_gen = ForSequence(builder, index, body_block, exit_block, line, False, length)
176181
for_gen.init(expr_reg, target_type, reverse=False)
177182

178183
builder.push_loop_stack(step_block, exit_block)
@@ -227,15 +232,15 @@ def sequence_from_generator_preallocate_helper(
227232
rtype = builder.node_type(gen.sequences[0])
228233
if is_sequence_rprimitive(rtype):
229234
sequence = builder.accept(gen.sequences[0])
230-
length = builder.builder.builtin_len(sequence, gen.line, use_pyssize_t=True)
235+
length = get_expr_length_value(builder, gen.sequences[0], sequence, gen.line, use_pyssize_t=True)
231236
target_op = empty_op_llbuilder(length, gen.line)
232237

233238
def set_item(item_index: Value) -> None:
234239
e = builder.accept(gen.left_expr)
235240
builder.call_c(set_item_op, [target_op, item_index, e], gen.line)
236241

237242
for_loop_helper_with_index(
238-
builder, gen.indices[0], gen.sequences[0], sequence, set_item, gen.line
243+
builder, gen.indices[0], gen.sequences[0], sequence, set_item, gen.line, length
239244
)
240245

241246
return target_op
@@ -788,20 +793,34 @@ class ForSequence(ForGenerator):
788793

789794
length_reg: Value | AssignmentTarget | None
790795

796+
def __init__(
    self,
    builder: IRBuilder,
    index: Lvalue,
    body_block: BasicBlock,
    loop_exit: BasicBlock,
    line: int,
    nested: bool,
    length: Value | None = None,
) -> None:
    """Create the generator, optionally with an already-computed length.

    Every argument except `length` is forwarded unchanged to the base class
    constructor; `length` is only stored on the instance.
    """
    super().__init__(builder, index, body_block, loop_exit, line, nested)
    # A Value representing the length of the sequence, if known.
    self.length = length
809+
791810
def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None:
811+
assert is_sequence_rprimitive(expr_reg.type), expr_reg
792812
builder = self.builder
793813
self.reverse = reverse
794814
# Define target to contain the expression, along with the index that will be used
795815
# for the for-loop. If we are inside of a generator function, spill these into the
796816
# environment class.
797817
self.expr_target = builder.maybe_spill(expr_reg)
798-
if (
799-
is_tuple_rprimitive(expr_reg.type)
800-
or is_str_rprimitive(expr_reg.type)
801-
or is_bytes_rprimitive(expr_reg.type)
802-
):
803-
self.length_reg = builder.maybe_spill(self.load_len(self.expr_target))
818+
if is_immutable_rprimitive(expr_reg.type):
819+
# If the expression is an immutable type, we can load the length just once.
820+
self.length_reg = builder.maybe_spill(self.length or self.load_len(self.expr_target))
804821
else:
822+
# Otherwise, even if the length is known, we must recalculate the length
823+
# at every iteration for compatibility with python semantics.
805824
self.length_reg = None
806825
if not reverse:
807826
index_reg: Value = Integer(0, c_pyssize_t_rprimitive)
@@ -1166,3 +1185,24 @@ def gen_step(self) -> None:
11661185
def gen_cleanup(self) -> None:
11671186
for gen in self.gens:
11681187
gen.gen_cleanup()
1188+
1189+
1190+
def get_expr_length(expr: Expression) -> int | None:
    """Return the length of a literal expression if it is known at compile time.

    Handles str and bytes literals, and list/tuple displays whose items all
    have a known length themselves. Returns None for everything else.
    """
    if isinstance(expr, StrExpr):
        return len(expr.value)
    if isinstance(expr, BytesExpr):
        # BytesExpr.value is a str holding the escaped, human-readable text of
        # the literal (b"\n" is stored as the two characters "\\n"), so the
        # text has to be turned back into real bytes before it is measured.
        raw = bytes(expr.value, "utf8").decode("unicode-escape").encode("raw-unicode-escape")
        return len(raw)
    if isinstance(expr, (ListExpr, TupleExpr)):
        # Deliberately conservative: only report a length when every item is
        # itself an expression of known length.
        if all(get_expr_length(item) is not None for item in expr.items):
            return len(expr.items)
    # TODO: extend this, unrolling should come with a good performance boost
    return None
1198+
1199+
1200+
def get_expr_length_value(
    builder: IRBuilder, expr: Expression, expr_reg: Value, line: int, use_pyssize_t: bool
) -> Value:
    """Return a Value holding the length of the sequence expression `expr`.

    Args:
        builder: IR builder used to look up the type of `expr` and to emit
            the len() call when one is needed.
        expr: The sequence expression; its type must be a sequence primitive.
        expr_reg: The already-evaluated value of `expr`.
        line: Source line number attached to the emitted len() call.
        use_pyssize_t: If true, the result uses c_pyssize_t_rprimitive;
            otherwise a constant result uses short_int_rprimitive.

    Returns:
        An Integer constant when get_expr_length() can determine the length,
        otherwise the result of a builtin len() call on `expr_reg`.
    """
    rtype = builder.node_type(expr)
    assert is_sequence_rprimitive(rtype), rtype
    length = get_expr_length(expr)
    if length is None:
        # We cannot compute the length at compile time, so we will fetch it.
        return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)
    # The expression result is known at compile time, so we can use a constant.
    return Integer(length, c_pyssize_t_rprimitive if use_pyssize_t else short_int_rprimitive)

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def __getitem__(self, i: int) -> int: ...
171171
def __getitem__(self, i: slice) -> bytes: ...
172172
def join(self, x: Iterable[object]) -> bytes: ...
173173
def decode(self, x: str=..., y: str=...) -> str: ...
174+
def __iter__(self) -> Iterator[int]: ...
174175

175176
class bytearray:
176177
@overload

mypyc/test-data/irbuild-generics.test

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -689,94 +689,93 @@ def inner_deco_obj.__call__(__mypyc_self__, args, kwargs):
689689
r0 :: __main__.deco_env
690690
r1 :: native_int
691691
r2 :: list
692-
r3, r4 :: native_int
693-
r5 :: bit
694-
r6, x :: object
695-
r7 :: native_int
692+
r3 :: native_int
693+
r4 :: bit
694+
r5, x :: object
695+
r6 :: native_int
696696
can_listcomp :: list
697-
r8 :: dict
698-
r9 :: short_int
699-
r10 :: native_int
700-
r11 :: object
701-
r12 :: tuple[bool, short_int, object, object]
702-
r13 :: short_int
703-
r14 :: bool
704-
r15, r16 :: object
705-
r17, k :: str
697+
r7 :: dict
698+
r8 :: short_int
699+
r9 :: native_int
700+
r10 :: object
701+
r11 :: tuple[bool, short_int, object, object]
702+
r12 :: short_int
703+
r13 :: bool
704+
r14, r15 :: object
705+
r16, k :: str
706706
v :: object
707-
r18 :: i32
708-
r19, r20, r21 :: bit
707+
r17 :: i32
708+
r18, r19, r20 :: bit
709709
can_dictcomp :: dict
710-
r22, can_iter, r23, can_use_keys, r24, can_use_values :: list
711-
r25 :: object
712-
r26 :: list
713-
r27 :: object
714-
r28 :: dict
715-
r29 :: i32
716-
r30 :: bit
717-
r31 :: tuple
718-
r32 :: object
719-
r33 :: int
710+
r21, can_iter, r22, can_use_keys, r23, can_use_values :: list
711+
r24 :: object
712+
r25 :: list
713+
r26 :: object
714+
r27 :: dict
715+
r28 :: i32
716+
r29 :: bit
717+
r30 :: tuple
718+
r31 :: object
719+
r32 :: int
720720
L0:
721721
r0 = __mypyc_self__.__mypyc_env__
722722
r1 = var_object_size args
723723
r2 = PyList_New(r1)
724-
r3 = var_object_size args
725-
r4 = 0
724+
r3 = 0
726725
L1:
727-
r5 = r4 < r3 :: signed
728-
if r5 goto L2 else goto L4 :: bool
726+
r4 = r3 < r1 :: signed
727+
if r4 goto L2 else goto L4 :: bool
729728
L2:
730-
r6 = CPySequenceTuple_GetItemUnsafe(args, r4)
731-
x = r6
732-
CPyList_SetItemUnsafe(r2, r4, x)
729+
r5 = CPySequenceTuple_GetItemUnsafe(args, r3)
730+
x = r5
731+
CPyList_SetItemUnsafe(r2, r3, x)
733732
L3:
734-
r7 = r4 + 1
735-
r4 = r7
733+
r6 = r3 + 1
734+
r3 = r6
736735
goto L1
737736
L4:
738737
can_listcomp = r2
739-
r8 = PyDict_New()
740-
r9 = 0
741-
r10 = PyDict_Size(kwargs)
742-
r11 = CPyDict_GetItemsIter(kwargs)
738+
r7 = PyDict_New()
739+
r8 = 0
740+
r9 = PyDict_Size(kwargs)
741+
r10 = CPyDict_GetItemsIter(kwargs)
743742
L5:
744-
r12 = CPyDict_NextItem(r11, r9)
745-
r13 = r12[1]
746-
r9 = r13
747-
r14 = r12[0]
748-
if r14 goto L6 else goto L8 :: bool
743+
r11 = CPyDict_NextItem(r10, r8)
744+
r12 = r11[1]
745+
r8 = r12
746+
r13 = r11[0]
747+
if r13 goto L6 else goto L8 :: bool
749748
L6:
750-
r15 = r12[2]
751-
r16 = r12[3]
752-
r17 = cast(str, r15)
753-
k = r17
754-
v = r16
755-
r18 = CPyDict_SetItem(r8, k, v)
756-
r19 = r18 >= 0 :: signed
749+
r14 = r11[2]
750+
r15 = r11[3]
751+
r16 = cast(str, r14)
752+
k = r16
753+
v = r15
754+
r17 = CPyDict_SetItem(r7, k, v)
755+
r18 = r17 >= 0 :: signed
757756
L7:
758-
r20 = CPyDict_CheckSize(kwargs, r10)
757+
r19 = CPyDict_CheckSize(kwargs, r9)
759758
goto L5
760759
L8:
761-
r21 = CPy_NoErrOccurred()
760+
r20 = CPy_NoErrOccurred()
762761
L9:
763-
can_dictcomp = r8
764-
r22 = PySequence_List(kwargs)
765-
can_iter = r22
766-
r23 = CPyDict_Keys(kwargs)
767-
can_use_keys = r23
768-
r24 = CPyDict_Values(kwargs)
769-
can_use_values = r24
770-
r25 = r0.func
771-
r26 = PyList_New(0)
772-
r27 = CPyList_Extend(r26, args)
773-
r28 = PyDict_New()
774-
r29 = CPyDict_UpdateInDisplay(r28, kwargs)
775-
r30 = r29 >= 0 :: signed
776-
r31 = PyList_AsTuple(r26)
777-
r32 = PyObject_Call(r25, r31, r28)
778-
r33 = unbox(int, r32)
779-
return r33
762+
can_dictcomp = r7
763+
r21 = PySequence_List(kwargs)
764+
can_iter = r21
765+
r22 = CPyDict_Keys(kwargs)
766+
can_use_keys = r22
767+
r23 = CPyDict_Values(kwargs)
768+
can_use_values = r23
769+
r24 = r0.func
770+
r25 = PyList_New(0)
771+
r26 = CPyList_Extend(r25, args)
772+
r27 = PyDict_New()
773+
r28 = CPyDict_UpdateInDisplay(r27, kwargs)
774+
r29 = r28 >= 0 :: signed
775+
r30 = PyList_AsTuple(r25)
776+
r31 = PyObject_Call(r24, r30, r27)
777+
r32 = unbox(int, r31)
778+
return r32
780779
def deco(func):
781780
func :: object
782781
r0 :: __main__.deco_env
@@ -795,3 +794,4 @@ def f(x):
795794
x :: int
796795
L0:
797796
return x
797+

0 commit comments

Comments
 (0)