diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index db986f3fd9a7..615c42daa697 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -16,6 +16,7 @@ DictionaryComprehension, Expression, GeneratorExpr, + ListComprehension, ListExpr, Lvalue, MemberExpr, @@ -1202,13 +1203,22 @@ def get_expr_length(expr: Expression) -> int | None: and isinstance(expr.node, Var) and expr.node.is_final and isinstance(expr.node.final_value, str) - and expr.node.has_explicit_value ): return len(expr.node.final_value) - # TODO: extend this, passing length of listcomp and genexp should have worthwhile - # performance boost and can be (sometimes) figured out pretty easily. set and dict - # comps *can* be done as well but will need special logic to consider the possibility - # of key conflicts. Range, enumerate, zip are all simple logic. + elif isinstance(expr, ListComprehension): + return get_expr_length(expr.generator) + elif isinstance(expr, GeneratorExpr) and not expr.condlists: + sequence_lengths = [get_expr_length(seq) for seq in expr.sequences] + if None not in sequence_lengths: + if len(sequence_lengths) == 1: + return sequence_lengths[0] + product = sequence_lengths[0] + for l in sequence_lengths[1:]: + product *= l # type: ignore [operator] + return product + # TODO: extend this, set and dict comps can be done as well but will + # need special logic to consider the possibility of key conflicts. + # Range, enumerate, zip are all simple logic. return None diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index 7507b6255740..743977779ff4 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -344,6 +344,122 @@ L4: a = r6 return 1 +[case testTupleBuiltFromListComprehension] +def f(val: int) -> bool: + return val % 2 == 0 + +def test() -> None: + a = tuple(f(x) for x in [a * b for a in [1, 2, 3] for b in [1, 2, 3]]) +[out] +def f(val): + val, r0 :: int + r1 :: bit +L0: + r0 = CPyTagged_Remainder(val, 4) + r1 = int_eq r0, 0 + return r1 +def test(): + r0, r1 :: list + r2, r3, r4 :: object + r5 :: ptr + r6, r7 :: native_int + r8 :: bit + r9 :: object + r10, a :: int + r11 :: list + r12, r13, r14 :: object + r15 :: ptr + r16, r17 :: native_int + r18 :: bit + r19 :: object + r20, b, r21 :: int + r22 :: object + r23 :: i32 + r24 :: bit + r25, r26, r27 :: native_int + r28 :: tuple + r29, r30 :: native_int + r31 :: bit + r32 :: object + r33, x :: int + r34 :: bool + r35 :: object + r36 :: native_int + a_2 :: tuple +L0: + r0 = PyList_New(0) + r1 = PyList_New(3) + r2 = object 1 + r3 = object 2 + r4 = object 3 + r5 = list_items r1 + buf_init_item r5, 0, r2 + buf_init_item r5, 1, r3 + buf_init_item r5, 2, r4 + keep_alive r1 + r6 = 0 +L1: + r7 = var_object_size r1 + r8 = r6 < r7 :: signed + if r8 goto L2 else goto L8 :: bool +L2: + r9 = list_get_item_unsafe r1, r6 + r10 = unbox(int, r9) + a = r10 + r11 = PyList_New(3) + r12 = object 1 + r13 = object 2 + r14 = object 3 + r15 = list_items r11 + buf_init_item r15, 0, r12 + buf_init_item r15, 1, r13 + buf_init_item r15, 2, r14 + keep_alive r11 + r16 = 0 +L3: + r17 = var_object_size r11 + r18 = r16 < r17 :: signed + if r18 goto L4 else goto L6 :: bool +L4: + r19 = list_get_item_unsafe r11, r16 + r20 = unbox(int, r19) + b = r20 + r21 = CPyTagged_Multiply(a, b) + r22 = box(int, r21) + r23 = PyList_Append(r0, r22) + r24 = r23 >= 0 :: signed +L5: + r25 = r16 + 1 + r16 = r25 + goto L3 +L6: +L7: + r26 = r6 + 1 + r6 = r26 + goto L1 +L8: + r27 = var_object_size r0 + r28 = PyTuple_New(r27) + r29 = 0 +L9: + r30 = var_object_size r0 + r31 = r29 < r30 :: signed + if r31 goto L10 else goto L12 :: bool +L10: + r32 = list_get_item_unsafe r0, r29 + r33 = unbox(int, r32) + x = r33 + r34 = f(x) + r35 = box(bool, r34) + CPySequenceTuple_SetItemUnsafe(r28, r29, r35) +L11: + r36 = r29 + 1 + r29 = r36 + goto L9 +L12: + a_2 = r28 + return 1 + [case testTupleBuiltFromStr] def f2(val: str) -> str: return val + "f2"