Skip to content

Commit d39eacc

Browse files
authored
[mypyc] Fixing index variable in for-loop with builtins.enumerate. (#18202)
Fixes [mypyc/mypyc#1046](mypyc/mypyc#1046) This change fixes two problems: 1. The index variable was getting instantiated even while enumerating an empty iterable. 2. After exiting the for-loop, the value of the index variable is off by 1 (see issue linked above). This change fixes both problems by assigning the temporary register to the index variable at the beginning of the for-loop body. Before this change, this assignment was happening before the for-loop and at the end of the for-loop body.
1 parent 2842e8f commit d39eacc

File tree

3 files changed

+32
-11
lines changed

3 files changed

+32
-11
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,6 @@ def init(self) -> None:
985985
zero = Integer(0)
986986
self.index_reg = builder.maybe_spill_assignable(zero)
987987
self.index_target: Register | AssignmentTarget = builder.get_assignment_target(self.index)
988-
builder.assign(self.index_target, zero, self.line)
989988

990989
def gen_step(self) -> None:
991990
builder = self.builder
@@ -997,7 +996,9 @@ def gen_step(self) -> None:
997996
short_int_rprimitive, builder.read(self.index_reg, line), Integer(1), IntOp.ADD, line
998997
)
999998
builder.assign(self.index_reg, new_val, line)
1000-
builder.assign(self.index_target, new_val, line)
999+
1000+
def begin_body(self) -> None:
1001+
self.builder.assign(self.index_target, self.builder.read(self.index_reg), self.line)
10011002

10021003

10031004
class ForEnumerate(ForGenerator):

mypyc/test-data/irbuild-statements.test

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -864,33 +864,31 @@ def g(x: Iterable[int]) -> None:
864864
[out]
865865
def f(a):
866866
a :: list
867-
r0 :: short_int
868-
i :: int
869-
r1 :: short_int
867+
r0, r1 :: short_int
870868
r2 :: native_int
871869
r3 :: short_int
872870
r4 :: bit
871+
i :: int
873872
r5 :: object
874873
r6, x, r7 :: int
875874
r8, r9 :: short_int
876875
L0:
877876
r0 = 0
878-
i = 0
879877
r1 = 0
880878
L1:
881879
r2 = var_object_size a
882880
r3 = r2 << 1
883881
r4 = int_lt r1, r3
884882
if r4 goto L2 else goto L4 :: bool
885883
L2:
884+
i = r0
886885
r5 = CPyList_GetItemUnsafe(a, r1)
887886
r6 = unbox(int, r5)
888887
x = r6
889888
r7 = CPyTagged_Add(i, x)
890889
L3:
891890
r8 = r0 + 2
892891
r0 = r8
893-
i = r8
894892
r9 = r1 + 2
895893
r1 = r9
896894
goto L1
@@ -900,25 +898,23 @@ L5:
900898
def g(x):
901899
x :: object
902900
r0 :: short_int
903-
i :: int
904901
r1, r2 :: object
905-
r3, n :: int
902+
i, r3, n :: int
906903
r4 :: short_int
907904
r5 :: bit
908905
L0:
909906
r0 = 0
910-
i = 0
911907
r1 = PyObject_GetIter(x)
912908
L1:
913909
r2 = PyIter_Next(r1)
914910
if is_error(r2) goto L4 else goto L2
915911
L2:
912+
i = r0
916913
r3 = unbox(int, r2)
917914
n = r3
918915
L3:
919916
r4 = r0 + 2
920917
r0 = r4
921-
i = r4
922918
goto L1
923919
L4:
924920
r5 = CPy_NoErrOccured()

mypyc/test-data/run-loops.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def nested_enumerate() -> None:
228228
assert i == inner
229229
inner += 1
230230
outer += 1
231+
assert i == 2
231232
assert outer_seen == l1
232233

233234
def nested_range() -> None:
@@ -465,6 +466,29 @@ assert g([6, 7], ['a', 'b']) == [(0, 6, 'a'), (1, 7, 'b')]
465466
assert f([6, 7], [8]) == [(0, 6, 8)]
466467
assert f([6], [8, 9]) == [(0, 6, 8)]
467468

469+
[case testEnumerateEmptyList]
470+
from typing import List
471+
472+
def get_enumerate_locals(iterable: List[int]) -> int:
473+
for i, j in enumerate(iterable):
474+
pass
475+
try:
476+
return i
477+
except NameError:
478+
return -100
479+
480+
[file driver.py]
481+
from native import get_enumerate_locals
482+
483+
print(get_enumerate_locals([]))
484+
print(get_enumerate_locals([55]))
485+
print(get_enumerate_locals([551, 552]))
486+
487+
[out]
488+
-100
489+
0
490+
1
491+
468492
[case testIterTypeTrickiness]
469493
# Test inferring the type of a for loop body doesn't cause us grief
470494
# Extracted from somethings that broke in mypy

0 commit comments

Comments
 (0)