Skip to content

Commit c3df28b

Browse files
committed
fix: rtuple and feat: general literal sequence
1 parent 126f300 commit c3df28b

File tree

3 files changed

+188
-85
lines changed

3 files changed

+188
-85
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 132 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@
77

88
from __future__ import annotations
99

10-
from typing import Callable, ClassVar
10+
from typing import Any, Callable, ClassVar, Union
1111

1212
from mypy.nodes import (
1313
ARG_POS,
14+
BytesExpr,
1415
CallExpr,
1516
DictionaryComprehension,
1617
Expression,
18+
FloatExpr,
1719
GeneratorExpr,
20+
IntExpr,
21+
ListExpr,
1822
Lvalue,
1923
MemberExpr,
2024
NameExpr,
@@ -385,6 +389,27 @@ def is_range_ref(expr: RefExpr) -> bool:
385389
)
386390

387391

392+
def is_literal_expr(expr: Expression) -> bool:
393+
# Add other literal types as needed
394+
if isinstance(expr, (IntExpr, StrExpr, FloatExpr, BytesExpr)):
395+
return True
396+
if isinstance(expr, NameExpr) and expr.fullname in {"builtins.None", "builtins.True", "builtins.False"}:
397+
return True
398+
return False
399+
400+
401+
def is_iterable_expr_with_literal_mambers(expr: Expression) -> bool:
402+
return (
403+
isinstance(expr, (ListExpr, SetExpr, TupleExpr))
404+
and not isinstance(expr, MemberExpr)
405+
and all(
406+
is_literal_expr(item)
407+
or is_iterable_expr_with_literal_mambers(item)
408+
for item in expr.items
409+
)
410+
)
411+
412+
388413
def make_for_loop_generator(
389414
builder: IRBuilder,
390415
index: Lvalue,
@@ -413,21 +438,22 @@ def make_for_loop_generator(
413438
rtyp = builder.node_type(expr)
414439

415440
# Special case: tuple literal (unroll the loop)
416-
if isinstance(expr, TupleExpr):
417-
return ForUnrolledLiteral(builder, index, body_block, loop_exit, line, expr.items, expr, body_insts)
441+
if is_iterable_expr_with_literal_mambers(expr):
442+
return ForUnrolledSequenceLiteral(builder, index, body_block, loop_exit, line, expr, body_insts)
418443

419-
# Special case: RTuple (known-length tuple, index-based iteration)
444+
# Special case: RTuple (known-length tuple, struct field iteration)
420445
if isinstance(rtyp, RTuple):
421446
expr_reg = builder.accept(expr)
422-
target_type = builder.get_sequence_type(expr)
423-
for_tuple = ForSequence(builder, index, body_block, loop_exit, line, nested)
424-
for_tuple.init(expr_reg, target_type, reverse=False)
425-
return for_tuple
447+
return ForUnrolledRTuple(builder, index, body_block, loop_exit, line, rtyp, expr_reg, expr, body_insts)
426448

427449
# Special case: string literal (unroll the loop)
428450
if isinstance(expr, StrExpr):
429451
return ForUnrolledStringLiteral(builder, index, body_block, loop_exit, line, expr.value, expr, body_insts)
430452

453+
# Special case: string literal (unroll the loop)
454+
if isinstance(expr, BytesExpr):
455+
return ForUnrolledBytesLiteral(builder, index, body_block, loop_exit, line, expr.value, expr, body_insts)
456+
431457
if is_sequence_rprimitive(rtyp):
432458
# Special case "for x in <list>".
433459
expr_reg = builder.accept(expr)
@@ -790,7 +816,28 @@ def gen_step(self) -> None:
790816
pass
791817

792818

793-
class ForUnrolledLiteral(ForGenerator):
819+
class _ForUnrolled(ForGenerator):
820+
"""Generate IR for a for loop over a value known at compile time by unrolling the loop.
821+
822+
This class emits the loop body for each element of the value literal directly,
823+
avoiding any runtime iteration logic and generator handling.
824+
"""
825+
826+
def __init__(self, *args: Any, **kwargs: Any):
827+
if type(self) is _ForUnrolled:
828+
raise NotImplementedError("This is a base class and should not be initialized directly.")
829+
super().__init__(*args, **kwargs)
830+
831+
def gen_condition(self) -> None:
832+
# Unrolled: nothing to do here.
833+
pass
834+
835+
def gen_step(self) -> None:
836+
# Unrolled: nothing to do here.
837+
pass
838+
839+
840+
class ForUnrolledSequenceLiteral(_ForUnrolled):
794841
"""Generate IR for a for loop over a tuple literal by unrolling the loop.
795842
796843
This class emits the loop body for each element of the tuple literal directly,
@@ -805,31 +852,25 @@ def __init__(
805852
body_block: BasicBlock,
806853
loop_exit: BasicBlock,
807854
line: int,
808-
items: list[Expression],
809-
expr: Expression,
855+
expr: Union[ListExpr, SetExpr, TupleExpr],
810856
body_insts: GenFunc,
811857
) -> None:
812858
super().__init__(builder, index, body_block, loop_exit, line, nested=False)
813-
self.items = items
814859
self.expr = expr
860+
self.items = expr.items
815861
self.body_insts = body_insts
816-
817-
def gen_condition(self) -> None:
818-
# Unrolled: nothing to do here.
819-
pass
862+
self.item_types = [builder.node_type(item) for item in self.items]
820863

821864
def begin_body(self) -> None:
822865
builder = self.builder
823-
for item in self.items:
824-
builder.assign(builder.get_assignment_target(self.index), builder.accept(item), self.line)
866+
for item, item_type in zip(self.items, self.item_types):
867+
value = builder.accept(item)
868+
value = builder.coerce(value, item_type, self.line)
869+
builder.assign(builder.get_assignment_target(self.index), value, self.line)
825870
self.body_insts()
826871

827-
def gen_step(self) -> None:
828-
# Unrolled: nothing to do here.
829-
pass
830-
831872

832-
class ForUnrolledStringLiteral(ForGenerator):
873+
class ForUnrolledStringLiteral(_ForUnrolled):
833874
"""Generate IR for a for loop over a string literal by unrolling the loop.
834875
835876
This class emits the loop body for each character of the string literal directly,
@@ -853,10 +894,6 @@ def __init__(
853894
self.expr = expr
854895
self.body_insts = body_insts
855896

856-
def gen_condition(self) -> None:
857-
# Unrolled: nothing to do here.
858-
pass
859-
860897
def begin_body(self) -> None:
861898
builder = self.builder
862899
for c in self.value:
@@ -867,9 +904,74 @@ def begin_body(self) -> None:
867904
)
868905
self.body_insts()
869906

870-
def gen_step(self) -> None:
871-
# Unrolled: nothing to do here.
872-
pass
907+
908+
class ForUnrolledBytesLiteral(_ForUnrolled):
909+
"""Generate IR for a for loop over a string literal by unrolling the loop.
910+
911+
This class emits the loop body for each character of the string literal directly,
912+
avoiding any runtime iteration logic.
913+
"""
914+
handles_body_insts = True
915+
916+
def __init__(
917+
self,
918+
builder: IRBuilder,
919+
index: Lvalue,
920+
body_block: BasicBlock,
921+
loop_exit: BasicBlock,
922+
line: int,
923+
value: bytes,
924+
expr: Expression,
925+
body_insts: GenFunc,
926+
) -> None:
927+
super().__init__(builder, index, body_block, loop_exit, line, nested=False)
928+
self.value = value
929+
self.expr = expr
930+
self.body_insts = body_insts
931+
932+
def begin_body(self) -> None:
933+
builder = self.builder
934+
for c in self.value:
935+
builder.assign(
936+
builder.get_assignment_target(self.index),
937+
builder.accept(IntExpr(c)),
938+
self.line,
939+
)
940+
self.body_insts()
941+
942+
943+
class ForUnrolledRTuple(_ForUnrolled):
944+
"""Generate IR for a for loop over an RTuple by directly accessing struct fields."""
945+
946+
handles_body_insts = True
947+
948+
def __init__(
949+
self,
950+
builder: IRBuilder,
951+
index: Lvalue,
952+
body_block: BasicBlock,
953+
loop_exit: BasicBlock,
954+
line: int,
955+
rtuple_type: RTuple,
956+
expr_reg: Value,
957+
expr: Expression,
958+
body_insts: GenFunc,
959+
) -> None:
960+
super().__init__(builder, index, body_block, loop_exit, line, nested=False)
961+
self.rtuple_type = rtuple_type
962+
self.expr_reg = expr_reg
963+
self.expr = expr
964+
self.body_insts = body_insts
965+
966+
def begin_body(self) -> None:
967+
builder = self.builder
968+
line = self.line
969+
for i, item_type in enumerate(self.rtuple_type.types):
970+
# Directly access the struct field for each RTuple element
971+
value = builder.add(TupleGet(self.expr_reg, i, line))
972+
value = builder.coerce(value, item_type, line)
973+
builder.assign(builder.get_assignment_target(self.index), value, line)
974+
self.body_insts()
873975

874976

875977
def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) -> Value:

mypyc/test-data/irbuild-basic.test

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3605,15 +3605,17 @@ def f(t: Tuple[int, int]) -> int:
36053605
[out]
36063606
def f(t):
36073607
t :: tuple[int, int]
3608-
s, x :: int
3609-
r0 :: int
3608+
s, r0, x, r1, r2, r3 :: int
3609+
L0:
36103610
s = 0
3611-
x = t.f0
3612-
L1:
3613-
s = CPyTagged_Add(s, x)
3614-
x = t.f1
3615-
L2:
3616-
s = CPyTagged_Add(s, x)
3611+
r0 = t[0]
3612+
x = r0
3613+
r1 = CPyTagged_Add(s, x)
3614+
s = r1
3615+
r2 = t[1]
3616+
x = r2
3617+
r3 = CPyTagged_Add(s, x)
3618+
s = r3
36173619
return s
36183620

36193621
[case testForOverStringVar]
@@ -3640,3 +3642,12 @@ def f(s):
36403642
goto L1
36413643
L4:
36423644
return out
3645+
3646+
[case TestForOverCompledTupleExpr]
3647+
def f() -> None:
3648+
abc = (1, 2, str(3))
3649+
for x in abc:
3650+
y = x
3651+
[out]
3652+
def f():
3653+
abc :: tuple[int, int, str]

mypyc/test-data/irbuild-set.test

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -129,41 +129,41 @@ L4:
129129
def test2():
130130
r0, tmp_tuple :: tuple[int, int, int]
131131
r1 :: set
132-
r2 :: native_int
133-
r3 :: object
134-
r4 :: native_int
135-
r5, r6 :: bit
136-
r7 :: object
137-
r8 :: bit
138-
r9, r10, r12 :: int
139-
r13 :: object
140-
r14, x, r15 :: int
141-
r16 :: object
142-
r17 :: i32
143-
r18 :: bit
144-
r19 :: native_int
132+
r2, x, r3 :: int
133+
r4 :: object
134+
r5 :: i32
135+
r6 :: bit
136+
r7, r8 :: int
137+
r9 :: object
138+
r10 :: i32
139+
r11 :: bit
140+
r12, r13 :: int
141+
r14 :: object
142+
r15 :: i32
143+
r16 :: bit
145144
b :: set
146145
L0:
147146
r0 = (2, 6, 10)
148147
tmp_tuple = r0
149148
r1 = PySet_New(0)
150-
r2 = box(tuple[int, int, int], tmp_tuple)
151-
r3 = PyObject_GetIter(r2)
152-
L1:
153-
r4 = PyIter_Next(r3)
154-
if is_error(r4) goto L4 else goto L2
155-
L2:
156-
r5 = unbox(int, r4)
157-
x = r5
158-
r6 = f(x)
159-
r7 = box(int, r6)
160-
r8 = PySet_Add(r1, r7)
161-
r9 = r8 >= 0 :: signed
162-
L3:
163-
goto L1
164-
L4:
165-
r10 = CPy_NoErrOccurred()
166-
L5:
149+
r2 = tmp_tuple[0]
150+
x = r2
151+
r3 = f(x)
152+
r4 = box(int, r3)
153+
r5 = PySet_Add(r1, r4)
154+
r6 = r5 >= 0 :: signed
155+
r7 = tmp_tuple[1]
156+
x = r7
157+
r8 = f(x)
158+
r9 = box(int, r8)
159+
r10 = PySet_Add(r1, r9)
160+
r11 = r10 >= 0 :: signed
161+
r12 = tmp_tuple[2]
162+
x = r12
163+
r13 = f(x)
164+
r14 = box(int, r13)
165+
r15 = PySet_Add(r1, r14)
166+
r16 = r15 >= 0 :: signed
167167
b = r1
168168
return 1
169169
def test3():
@@ -735,25 +735,15 @@ def not_precomputed() -> None:
735735

736736
[out]
737737
def precomputed():
738-
r0 :: set
739-
r1, r2 :: object
740-
r3 :: str
738+
r0 :: str
741739
_ :: object
742-
r4 :: bit
743740
L0:
744-
r0 = frozenset({'False', 'None', 'True'})
745-
r1 = PyObject_GetIter(r0)
746-
L1:
747-
r2 = PyIter_Next(r1)
748-
if is_error(r2) goto L4 else goto L2
749-
L2:
750-
r3 = cast(str, r2)
751-
_ = r3
752-
L3:
753-
goto L1
754-
L4:
755-
r4 = CPy_NoErrOccurred()
756-
L5:
741+
r0 = 'None'
742+
_ = r0
743+
r1 = 'True'
744+
_ = r1
745+
r2 = 'False'
746+
_ = r2
757747
return 1
758748
def precomputed2():
759749
r0 :: set

0 commit comments

Comments
 (0)