Skip to content

Commit 866c896

Browse files
committed
[mypyc] Fixing index variable in for-loop with builtins.enumerate.
1 parent 15cd6d3 commit 866c896

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,6 @@ def init(self) -> None:
985985
zero = Integer(0)
986986
self.index_reg = builder.maybe_spill_assignable(zero)
987987
self.index_target: Register | AssignmentTarget = builder.get_assignment_target(self.index)
988-
builder.assign(self.index_target, zero, self.line)
989988

990989
def gen_step(self) -> None:
991990
builder = self.builder
@@ -997,7 +996,9 @@ def gen_step(self) -> None:
997996
short_int_rprimitive, builder.read(self.index_reg, line), Integer(1), IntOp.ADD, line
998997
)
999998
builder.assign(self.index_reg, new_val, line)
1000-
builder.assign(self.index_target, new_val, line)
999+
1000+
def begin_body(self) -> None:
1001+
self.builder.assign(self.index_target, self.builder.read(self.index_reg), self.line)
10011002

10021003

10031004
class ForEnumerate(ForGenerator):

mypyc/test-data/irbuild-statements.test

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -864,33 +864,31 @@ def g(x: Iterable[int]) -> None:
864864
[out]
865865
def f(a):
866866
a :: list
867-
r0 :: short_int
868-
i :: int
869-
r1 :: short_int
867+
r0, r1 :: short_int
870868
r2 :: native_int
871869
r3 :: short_int
872870
r4 :: bit
871+
i :: int
873872
r5 :: object
874873
r6, x, r7 :: int
875874
r8, r9 :: short_int
876875
L0:
877876
r0 = 0
878-
i = 0
879877
r1 = 0
880878
L1:
881879
r2 = var_object_size a
882880
r3 = r2 << 1
883881
r4 = int_lt r1, r3
884882
if r4 goto L2 else goto L4 :: bool
885883
L2:
884+
i = r0
886885
r5 = CPyList_GetItemUnsafe(a, r1)
887886
r6 = unbox(int, r5)
888887
x = r6
889888
r7 = CPyTagged_Add(i, x)
890889
L3:
891890
r8 = r0 + 2
892891
r0 = r8
893-
i = r8
894892
r9 = r1 + 2
895893
r1 = r9
896894
goto L1
@@ -900,25 +898,23 @@ L5:
900898
def g(x):
901899
x :: object
902900
r0 :: short_int
903-
i :: int
904901
r1, r2 :: object
905-
r3, n :: int
902+
i, r3, n :: int
906903
r4 :: short_int
907904
r5 :: bit
908905
L0:
909906
r0 = 0
910-
i = 0
911907
r1 = PyObject_GetIter(x)
912908
L1:
913909
r2 = PyIter_Next(r1)
914910
if is_error(r2) goto L4 else goto L2
915911
L2:
912+
i = r0
916913
r3 = unbox(int, r2)
917914
n = r3
918915
L3:
919916
r4 = r0 + 2
920917
r0 = r4
921-
i = r4
922918
goto L1
923919
L4:
924920
r5 = CPy_NoErrOccured()

mypyc/test-data/run-loops.test

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def nested_enumerate() -> None:
228228
assert i == inner
229229
inner += 1
230230
outer += 1
231+
assert i == 2
231232
assert outer_seen == l1
232233

233234
def nested_range() -> None:
@@ -465,6 +466,27 @@ assert g([6, 7], ['a', 'b']) == [(0, 6, 'a'), (1, 7, 'b')]
465466
assert f([6, 7], [8]) == [(0, 6, 8)]
466467
assert f([6], [8, 9]) == [(0, 6, 8)]
467468

469+
[case testEnumerateEmptyList]
470+
def get_enumerate_locals(iterable: list[int]) -> int:
471+
for i, j in enumerate(iterable):
472+
pass
473+
try:
474+
return i
475+
except NameError:
476+
return -100
477+
478+
[file driver.py]
479+
from native import get_enumerate_locals
480+
481+
print(get_enumerate_locals([]))
482+
print(get_enumerate_locals([55]))
483+
print(get_enumerate_locals([551, 552]))
484+
485+
[out]
486+
-100
487+
0
488+
1
489+
468490
[case testIterTypeTrickiness]
469491
# Test inferring the type of a for loop body doesn't cause us grief
470492
# Extracted from somethings that broke in mypy

0 commit comments

Comments
 (0)