diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 20440d4a26f4..b3a444d2ab05 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -1203,18 +1203,18 @@ def gen_cleanup(self) -> None: gen.gen_cleanup() -def get_expr_length(expr: Expression) -> int | None: +def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None: if isinstance(expr, (StrExpr, BytesExpr)): return len(expr.value) elif isinstance(expr, (ListExpr, TupleExpr)): # if there are no star expressions, or we know the length of them, # we know the length of the expression - stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)] + stars = [get_expr_length(builder, i) for i in expr.items if isinstance(i, StarExpr)] if None not in stars: other = sum(not isinstance(i, StarExpr) for i in expr.items) return other + sum(stars) # type: ignore [arg-type] elif isinstance(expr, StarExpr): - return get_expr_length(expr.expr) + return get_expr_length(builder, expr.expr) elif ( isinstance(expr, RefExpr) and isinstance(expr.node, Var) @@ -1227,6 +1227,14 @@ def get_expr_length(expr: Expression) -> int | None: # performance boost and can be (sometimes) figured out pretty easily. set and dict # comps *can* be done as well but will need special logic to consider the possibility # of key conflicts. Range, enumerate, zip are all simple logic. + + # we might still be able to get the length directly from the type + rtype = builder.node_type(expr) + if isinstance(rtype, RTuple): + return len(rtype.types) + proper_type = get_proper_type(builder.types[expr]) + if isinstance(proper_type, TupleType): + return len(proper_type.items) return None @@ -1235,7 +1243,7 @@ def get_expr_length_value( ) -> Value: rtype = builder.node_type(expr) assert is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple), rtype - length = get_expr_length(expr) + length = get_expr_length(builder, expr) if length is None: # We cannot compute the length at compile time, so we will fetch it. return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t) diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index 3613c5f0101d..0fdd8e87a154 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -694,51 +694,46 @@ L0: return r1 def test(): r0, source :: tuple[int, int, int] - r1 :: object - r2 :: native_int - r3 :: bit - r4, r5, r6 :: int - r7, r8, r9 :: object - r10, r11 :: tuple - r12 :: native_int - r13 :: bit + r1, r2, r3 :: int + r4, r5, r6 :: object + r7, r8 :: tuple + r9 :: native_int + r10 :: bit + r11 :: object + r12, x :: int + r13 :: bool r14 :: object - r15, x :: int - r16 :: bool - r17 :: object - r18 :: native_int + r15 :: native_int a :: tuple L0: r0 = (2, 4, 6) source = r0 - r1 = box(tuple[int, int, int], source) - r2 = PyObject_Size(r1) - r3 = r2 >= 0 :: signed - r4 = source[0] - r5 = source[1] - r6 = source[2] - r7 = box(int, r4) - r8 = box(int, r5) - r9 = box(int, r6) - r10 = PyTuple_Pack(3, r7, r8, r9) - r11 = PyTuple_New(r2) - r12 = 0 + r1 = source[0] + r2 = source[1] + r3 = source[2] + r4 = box(int, r1) + r5 = box(int, r2) + r6 = box(int, r3) + r7 = PyTuple_Pack(3, r4, r5, r6) + r8 = PyTuple_New(3) + r9 = 0 + goto L2 L1: - r13 = r12 < r2 :: signed - if r13 goto L2 else goto L4 :: bool + r10 = r9 < 3 :: signed + if r10 goto L2 else goto L4 :: bool L2: - r14 = CPySequenceTuple_GetItemUnsafe(r10, r12) - r15 = unbox(int, r14) - x = r15 - r16 = f(x) - r17 = box(bool, r16) - CPySequenceTuple_SetItemUnsafe(r11, r12, r17) + r11 = CPySequenceTuple_GetItemUnsafe(r7, r9) + r12 = unbox(int, r11) + x = r12 + r13 = f(x) + r14 = box(bool, r13) + CPySequenceTuple_SetItemUnsafe(r8, r9, r14) L3: - r18 = r12 + 1 - r12 = r18 + r15 = r9 + 1 + r9 = r15 goto L1 L4: - a = r11 + a = r8 return 1 [case testTupleBuiltFromFinalFixedLengthTuple] @@ -762,19 +757,16 @@ L0: def test(): r0 :: tuple[int, int, int] r1 :: bool - r2 :: object - r3 :: native_int - r4 :: bit - r5, r6, r7 :: int - r8, r9, r10 :: object - r11, r12 :: tuple - r13 :: native_int - r14 :: bit + r2, r3, r4 :: int + r5, r6, r7 :: object + r8, r9 :: tuple + r10 :: native_int + r11 :: bit + r12 :: object + r13, x :: int + r14 :: bool r15 :: object - r16, x :: int - r17 :: bool - r18 :: object - r19 :: native_int + r16 :: native_int a :: tuple L0: r0 = __main__.source :: static @@ -783,34 +775,32 @@ L1: r1 = raise NameError('value for final name "source" was not set') unreachable L2: - r2 = box(tuple[int, int, int], r0) - r3 = PyObject_Size(r2) - r4 = r3 >= 0 :: signed - r5 = r0[0] - r6 = r0[1] - r7 = r0[2] - r8 = box(int, r5) - r9 = box(int, r6) - r10 = box(int, r7) - r11 = PyTuple_Pack(3, r8, r9, r10) - r12 = PyTuple_New(r3) - r13 = 0 + r2 = r0[0] + r3 = r0[1] + r4 = r0[2] + r5 = box(int, r2) + r6 = box(int, r3) + r7 = box(int, r4) + r8 = PyTuple_Pack(3, r5, r6, r7) + r9 = PyTuple_New(3) + r10 = 0 + goto L4 L3: - r14 = r13 < r3 :: signed - if r14 goto L4 else goto L6 :: bool + r11 = r10 < 3 :: signed + if r11 goto L4 else goto L6 :: bool L4: - r15 = CPySequenceTuple_GetItemUnsafe(r11, r13) - r16 = unbox(int, r15) - x = r16 - r17 = f(x) - r18 = box(bool, r17) - CPySequenceTuple_SetItemUnsafe(r12, r13, r18) + r12 = CPySequenceTuple_GetItemUnsafe(r8, r10) + r13 = unbox(int, r12) + x = r13 + r14 = f(x) + r15 = box(bool, r14) + CPySequenceTuple_SetItemUnsafe(r9, r10, r15) L5: - r19 = r13 + 1 - r13 = r19 + r16 = r10 + 1 + r10 = r16 goto L3 L6: - a = r12 + a = r9 return 1 [case testTupleBuiltFromVariableLengthTuple]