Skip to content
Open
16 changes: 12 additions & 4 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,18 +1203,18 @@ def gen_cleanup(self) -> None:
gen.gen_cleanup()


def get_expr_length(expr: Expression) -> int | None:
def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
if isinstance(expr, (StrExpr, BytesExpr)):
return len(expr.value)
elif isinstance(expr, (ListExpr, TupleExpr)):
# if there are no star expressions, or we know the length of them,
# we know the length of the expression
stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)]
stars = [get_expr_length(builder, i) for i in expr.items if isinstance(i, StarExpr)]
if None not in stars:
other = sum(not isinstance(i, StarExpr) for i in expr.items)
return other + sum(stars) # type: ignore [arg-type]
elif isinstance(expr, StarExpr):
return get_expr_length(expr.expr)
return get_expr_length(builder, expr.expr)
elif (
isinstance(expr, RefExpr)
and isinstance(expr.node, Var)
Expand All @@ -1227,6 +1227,14 @@ def get_expr_length(expr: Expression) -> int | None:
# performance boost and can be (sometimes) figured out pretty easily. set and dict
# comps *can* be done as well but will need special logic to consider the possibility
# of key conflicts. Range, enumerate, zip are all simple logic.

# we might still be able to get the length directly from the type
rtype = builder.node_type(expr)
if isinstance(rtype, RTuple):
return len(rtype.types)
proper_type = get_proper_type(builder.types[expr])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this or will the check above it catch every case?

if isinstance(proper_type, TupleType):
return len(proper_type.items)
return None


Expand All @@ -1235,7 +1243,7 @@ def get_expr_length_value(
) -> Value:
rtype = builder.node_type(expr)
assert is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple), rtype
length = get_expr_length(expr)
length = get_expr_length(builder, expr)
if length is None:
# We cannot compute the length at compile time, so we will fetch it.
return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)
Expand Down
130 changes: 60 additions & 70 deletions mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -694,51 +694,46 @@ L0:
return r1
def test():
r0, source :: tuple[int, int, int]
r1 :: object
r2 :: native_int
r3 :: bit
r4, r5, r6 :: int
r7, r8, r9 :: object
r10, r11 :: tuple
r12 :: native_int
r13 :: bit
r1, r2, r3 :: int
r4, r5, r6 :: object
r7, r8 :: tuple
r9 :: native_int
r10 :: bit
r11 :: object
r12, x :: int
r13 :: bool
r14 :: object
r15, x :: int
r16 :: bool
r17 :: object
r18 :: native_int
r15 :: native_int
a :: tuple
L0:
r0 = (2, 4, 6)
source = r0
r1 = box(tuple[int, int, int], source)
r2 = PyObject_Size(r1)
r3 = r2 >= 0 :: signed
r4 = source[0]
r5 = source[1]
r6 = source[2]
r7 = box(int, r4)
r8 = box(int, r5)
r9 = box(int, r6)
r10 = PyTuple_Pack(3, r7, r8, r9)
r11 = PyTuple_New(r2)
r12 = 0
r1 = source[0]
r2 = source[1]
r3 = source[2]
r4 = box(int, r1)
r5 = box(int, r2)
r6 = box(int, r3)
r7 = PyTuple_Pack(3, r4, r5, r6)
r8 = PyTuple_New(3)
r9 = 0
goto L2
L1:
r13 = r12 < r2 :: signed
if r13 goto L2 else goto L4 :: bool
r10 = r9 < 3 :: signed
if r10 goto L2 else goto L4 :: bool
L2:
r14 = CPySequenceTuple_GetItemUnsafe(r10, r12)
r15 = unbox(int, r14)
x = r15
r16 = f(x)
r17 = box(bool, r16)
CPySequenceTuple_SetItemUnsafe(r11, r12, r17)
r11 = CPySequenceTuple_GetItemUnsafe(r7, r9)
r12 = unbox(int, r11)
x = r12
r13 = f(x)
r14 = box(bool, r13)
CPySequenceTuple_SetItemUnsafe(r8, r9, r14)
L3:
r18 = r12 + 1
r12 = r18
r15 = r9 + 1
r9 = r15
goto L1
L4:
a = r11
a = r8
return 1

[case testTupleBuiltFromFinalFixedLengthTuple]
Expand All @@ -762,19 +757,16 @@ L0:
def test():
r0 :: tuple[int, int, int]
r1 :: bool
r2 :: object
r3 :: native_int
r4 :: bit
r5, r6, r7 :: int
r8, r9, r10 :: object
r11, r12 :: tuple
r13 :: native_int
r14 :: bit
r2, r3, r4 :: int
r5, r6, r7 :: object
r8, r9 :: tuple
r10 :: native_int
r11 :: bit
r12 :: object
r13, x :: int
r14 :: bool
r15 :: object
r16, x :: int
r17 :: bool
r18 :: object
r19 :: native_int
r16 :: native_int
a :: tuple
L0:
r0 = __main__.source :: static
Expand All @@ -783,34 +775,32 @@ L1:
r1 = raise NameError('value for final name "source" was not set')
unreachable
L2:
r2 = box(tuple[int, int, int], r0)
r3 = PyObject_Size(r2)
r4 = r3 >= 0 :: signed
r5 = r0[0]
r6 = r0[1]
r7 = r0[2]
r8 = box(int, r5)
r9 = box(int, r6)
r10 = box(int, r7)
r11 = PyTuple_Pack(3, r8, r9, r10)
r12 = PyTuple_New(r3)
r13 = 0
r2 = r0[0]
r3 = r0[1]
r4 = r0[2]
r5 = box(int, r2)
r6 = box(int, r3)
r7 = box(int, r4)
r8 = PyTuple_Pack(3, r5, r6, r7)
r9 = PyTuple_New(3)
r10 = 0
goto L4
L3:
r14 = r13 < r3 :: signed
if r14 goto L4 else goto L6 :: bool
r11 = r10 < 3 :: signed
if r11 goto L4 else goto L6 :: bool
L4:
r15 = CPySequenceTuple_GetItemUnsafe(r11, r13)
r16 = unbox(int, r15)
x = r16
r17 = f(x)
r18 = box(bool, r17)
CPySequenceTuple_SetItemUnsafe(r12, r13, r18)
r12 = CPySequenceTuple_GetItemUnsafe(r8, r10)
r13 = unbox(int, r12)
x = r13
r14 = f(x)
r15 = box(bool, r14)
CPySequenceTuple_SetItemUnsafe(r9, r10, r15)
L5:
r19 = r13 + 1
r13 = r19
r16 = r10 + 1
r10 = r16
goto L3
L6:
a = r12
a = r9
return 1

[case testTupleBuiltFromVariableLengthTuple]
Expand Down