diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 762b41866a05..5edee6cb4df4 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -11,17 +11,22 @@ from mypy.nodes import ( ARG_POS, + BytesExpr, CallExpr, DictionaryComprehension, Expression, GeneratorExpr, + ListExpr, Lvalue, MemberExpr, NameExpr, RefExpr, SetExpr, + StarExpr, + StrExpr, TupleExpr, TypeAlias, + Var, ) from mypyc.ir.ops import ( ERR_NEVER, @@ -152,6 +157,7 @@ def for_loop_helper_with_index( expr_reg: Value, body_insts: Callable[[Value], None], line: int, + length: Value, ) -> None: """Generate IR for a sequence iteration. @@ -173,7 +179,7 @@ def for_loop_helper_with_index( condition_block = BasicBlock() for_gen = ForSequence(builder, index, body_block, exit_block, line, False) - for_gen.init(expr_reg, target_type, reverse=False) + for_gen.init(expr_reg, target_type, reverse=False, length=length) builder.push_loop_stack(step_block, exit_block) @@ -227,7 +233,9 @@ def sequence_from_generator_preallocate_helper( rtype = builder.node_type(gen.sequences[0]) if is_sequence_rprimitive(rtype): sequence = builder.accept(gen.sequences[0]) - length = builder.builder.builtin_len(sequence, gen.line, use_pyssize_t=True) + length = get_expr_length_value( + builder, gen.sequences[0], sequence, gen.line, use_pyssize_t=True + ) target_op = empty_op_llbuilder(length, gen.line) def set_item(item_index: Value) -> None: @@ -235,7 +243,7 @@ def set_item(item_index: Value) -> None: builder.call_c(set_item_op, [target_op, item_index, e], gen.line) for_loop_helper_with_index( - builder, gen.indices[0], gen.sequences[0], sequence, set_item, gen.line + builder, gen.indices[0], gen.sequences[0], sequence, set_item, gen.line, length ) return target_op @@ -788,9 +796,13 @@ class ForSequence(ForGenerator): length_reg: Value | AssignmentTarget | None - def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None: + def init( + self, expr_reg: Value, target_type: RType, reverse: bool, length: Value | None = None + ) -> None: assert is_sequence_rprimitive(expr_reg.type), expr_reg builder = self.builder + # Record a Value indicating the length of the sequence, if known at compile time. + self.length = length self.reverse = reverse # Define target to contain the expression, along with the index that will be used # for the for-loop. If we are inside of a generator function, spill these into the @@ -798,7 +810,7 @@ def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None: self.expr_target = builder.maybe_spill(expr_reg) if is_immutable_rprimitive(expr_reg.type): # If the expression is an immutable type, we can load the length just once. - self.length_reg = builder.maybe_spill(self.load_len(self.expr_target)) + self.length_reg = builder.maybe_spill(self.length or self.load_len(self.expr_target)) else: # Otherwise, even if the length is known, we must recalculate the length # at every iteration for compatibility with python semantics. @@ -1166,3 +1178,43 @@ def gen_step(self) -> None: def gen_cleanup(self) -> None: for gen in self.gens: gen.gen_cleanup() + + +def get_expr_length(expr: Expression) -> int | None: + if isinstance(expr, (StrExpr, BytesExpr)): + return len(expr.value) + elif isinstance(expr, (ListExpr, TupleExpr)): + # if there are no star expressions, or we know the length of them, + # we know the length of the expression + stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)] + if None not in stars: + other = sum(not isinstance(i, StarExpr) for i in expr.items) + return other + sum(stars) # type: ignore [arg-type] + elif isinstance(expr, StarExpr): + return get_expr_length(expr.expr) + elif ( + isinstance(expr, RefExpr) + and isinstance(expr.node, Var) + and expr.node.is_final + and isinstance(expr.node.final_value, str) + and expr.node.has_explicit_value + ): + return len(expr.node.final_value) + # TODO: extend this, passing length of listcomp and genexp should have worthwhile + # performance boost and can be (sometimes) figured out pretty easily. set and dict + # comps *can* be done as well but will need special logic to consider the possibility + # of key conflicts. Range, enumerate, zip are all simple logic. + return None + + +def get_expr_length_value( + builder: IRBuilder, expr: Expression, expr_reg: Value, line: int, use_pyssize_t: bool +) -> Value: + rtype = builder.node_type(expr) + assert is_sequence_rprimitive(rtype), rtype + length = get_expr_length(expr) + if length is None: + # We cannot compute the length at compile time, so we will fetch it. + return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t) + # The expression result is known at compile time, so we can use a constant. + return Integer(length, c_pyssize_t_rprimitive if use_pyssize_t else short_int_rprimitive) diff --git a/mypyc/test-data/irbuild-generics.test b/mypyc/test-data/irbuild-generics.test index 96437a0079c9..9ec29182e89b 100644 --- a/mypyc/test-data/irbuild-generics.test +++ b/mypyc/test-data/irbuild-generics.test @@ -678,84 +678,83 @@ def inner_deco_obj.__call__(__mypyc_self__, args, kwargs): r0 :: __main__.deco_env r1 :: native_int r2 :: list - r3, r4 :: native_int - r5 :: bit - r6, x :: object - r7 :: native_int + r3 :: native_int + r4 :: bit + r5, x :: object + r6 :: native_int can_listcomp :: list - r8 :: dict - r9 :: short_int - r10 :: native_int - r11 :: object - r12 :: tuple[bool, short_int, object, object] - r13 :: short_int - r14 :: bool - r15, r16 :: object - r17, k :: str + r7 :: dict + r8 :: short_int + r9 :: native_int + r10 :: object + r11 :: tuple[bool, short_int, object, object] + r12 :: short_int + r13 :: bool + r14, r15 :: object + r16, k :: str v :: object - r18 :: i32 - r19, r20, r21 :: bit + r17 :: i32 + r18, r19, r20 :: bit can_dictcomp :: dict - r22, can_iter, r23, can_use_keys, r24, can_use_values :: list - r25 :: object - r26 :: dict - r27 :: object - r28 :: int + r21, can_iter, r22, can_use_keys, r23, can_use_values :: list + r24 :: object + r25 :: dict + r26 :: object + r27 :: int L0: r0 = __mypyc_self__.__mypyc_env__ r1 = var_object_size args r2 = PyList_New(r1) - r3 = var_object_size args - r4 = 0 + r3 = 0 L1: - r5 = r4 < r3 :: signed - if r5 goto L2 else goto L4 :: bool + r4 = r3 < r1 :: signed + if r4 goto L2 else goto L4 :: bool L2: - r6 = CPySequenceTuple_GetItemUnsafe(args, r4) - x = r6 - CPyList_SetItemUnsafe(r2, r4, x) + r5 = CPySequenceTuple_GetItemUnsafe(args, r3) + x = r5 + CPyList_SetItemUnsafe(r2, r3, x) L3: - r7 = r4 + 1 - r4 = r7 + r6 = r3 + 1 + r3 = r6 goto L1 L4: can_listcomp = r2 - r8 = PyDict_New() - r9 = 0 - r10 = PyDict_Size(kwargs) - r11 = CPyDict_GetItemsIter(kwargs) + r7 = PyDict_New() + r8 = 0 + r9 = PyDict_Size(kwargs) + r10 = CPyDict_GetItemsIter(kwargs) L5: - r12 = CPyDict_NextItem(r11, r9) - r13 = r12[1] - r9 = r13 - r14 = r12[0] - if r14 goto L6 else goto L8 :: bool + r11 = CPyDict_NextItem(r10, r8) + r12 = r11[1] + r8 = r12 + r13 = r11[0] + if r13 goto L6 else goto L8 :: bool L6: - r15 = r12[2] - r16 = r12[3] - r17 = cast(str, r15) - k = r17 - v = r16 - r18 = PyDict_SetItem(r8, k, v) - r19 = r18 >= 0 :: signed + r14 = r11[2] + r15 = r11[3] + r16 = cast(str, r14) + k = r16 + v = r15 + r17 = PyDict_SetItem(r7, k, v) + r18 = r17 >= 0 :: signed L7: - r20 = CPyDict_CheckSize(kwargs, r10) + r19 = CPyDict_CheckSize(kwargs, r9) goto L5 L8: - r21 = CPy_NoErrOccurred() + r20 = CPy_NoErrOccurred() L9: - can_dictcomp = r8 - r22 = PySequence_List(kwargs) - can_iter = r22 - r23 = CPyDict_Keys(kwargs) - can_use_keys = r23 - r24 = CPyDict_Values(kwargs) - can_use_values = r24 - r25 = r0.func - r26 = PyDict_Copy(kwargs) - r27 = PyObject_Call(r25, args, r26) - r28 = unbox(int, r27) - return r28 + can_dictcomp = r7 + r21 = PySequence_List(kwargs) + can_iter = r21 + r22 = CPyDict_Keys(kwargs) + can_use_keys = r22 + r23 = CPyDict_Values(kwargs) + can_use_values = r23 + r24 = r0.func + r25 = PyDict_Copy(kwargs) + r26 = PyObject_Call(r24, args, r25) + r27 = unbox(int, r26) + return r27 def deco(func): func :: object r0 :: __main__.deco_env diff --git a/mypyc/test-data/irbuild-lists.test b/mypyc/test-data/irbuild-lists.test index d83fb88390db..2f5b3b39319e 100644 --- a/mypyc/test-data/irbuild-lists.test +++ b/mypyc/test-data/irbuild-lists.test @@ -594,10 +594,8 @@ def test(): r3 :: list r4 :: native_int r5 :: bit - r6 :: native_int - r7 :: bit - r8, x, r9 :: str - r10 :: native_int + r6, x, r7 :: str + r8 :: native_int a :: list L0: r0 = 'abc' @@ -605,20 +603,18 @@ L0: r1 = CPyStr_Size_size_t(source) r2 = r1 >= 0 :: signed r3 = PyList_New(r1) - r4 = CPyStr_Size_size_t(source) - r5 = r4 >= 0 :: signed - r6 = 0 + r4 = 0 L1: - r7 = r6 < r4 :: signed - if r7 goto L2 else goto L4 :: bool + r5 = r4 < r1 :: signed + if r5 goto L2 else goto L4 :: bool L2: - r8 = CPyStr_GetItemUnsafe(source, r6) - x = r8 - r9 = f2(x) - CPyList_SetItemUnsafe(r3, r6, r9) + r6 = CPyStr_GetItemUnsafe(source, r4) + x = r6 + r7 = f2(x) + CPyList_SetItemUnsafe(r3, r4, r7) L3: - r10 = r6 + 1 - r6 = r10 + r8 = r4 + 1 + r4 = r8 goto L1 L4: a = r3 @@ -639,38 +635,30 @@ L0: return r1 def test(): r0 :: str - r1 :: native_int - r2 :: bit - r3 :: list - r4 :: native_int - r5 :: bit + r1 :: list + r2 :: native_int + r3 :: bit + r4, x, r5 :: str r6 :: native_int - r7 :: bit - r8, x, r9 :: str - r10 :: native_int a :: list L0: r0 = 'abc' - r1 = CPyStr_Size_size_t(r0) - r2 = r1 >= 0 :: signed - r3 = PyList_New(r1) - r4 = CPyStr_Size_size_t(r0) - r5 = r4 >= 0 :: signed - r6 = 0 + r1 = PyList_New(3) + r2 = 0 L1: - r7 = r6 < r4 :: signed - if r7 goto L2 else goto L4 :: bool + r3 = r2 < 3 :: signed + if r3 goto L2 else goto L4 :: bool L2: - r8 = CPyStr_GetItemUnsafe(r0, r6) - x = r8 - r9 = f2(x) - CPyList_SetItemUnsafe(r3, r6, r9) + r4 = CPyStr_GetItemUnsafe(r0, r2) + x = r4 + r5 = f2(x) + CPyList_SetItemUnsafe(r1, r2, r5) L3: - r10 = r6 + 1 - r6 = r10 + r6 = r2 + 1 + r2 = r6 goto L1 L4: - a = r3 + a = r1 return 1 [case testListBuiltFromFinalStr] @@ -692,38 +680,30 @@ L0: return r1 def test(): r0 :: str - r1 :: native_int - r2 :: bit - r3 :: list - r4 :: native_int - r5 :: bit + r1 :: list + r2 :: native_int + r3 :: bit + r4, x, r5 :: str r6 :: native_int - r7 :: bit - r8, x, r9 :: str - r10 :: native_int a :: list L0: r0 = 'abc' - r1 = CPyStr_Size_size_t(r0) - r2 = r1 >= 0 :: signed - r3 = PyList_New(r1) - r4 = CPyStr_Size_size_t(r0) - r5 = r4 >= 0 :: signed - r6 = 0 + r1 = PyList_New(3) + r2 = 0 L1: - r7 = r6 < r4 :: signed - if r7 goto L2 else goto L4 :: bool + r3 = r2 < 3 :: signed + if r3 goto L2 else goto L4 :: bool L2: - r8 = CPyStr_GetItemUnsafe(r0, r6) - x = r8 - r9 = f2(x) - CPyList_SetItemUnsafe(r3, r6, r9) + r4 = CPyStr_GetItemUnsafe(r0, r2) + x = r4 + r5 = f2(x) + CPyList_SetItemUnsafe(r1, r2, r5) L3: - r10 = r6 + 1 - r6 = r10 + r6 = r2 + 1 + r2 = r6 goto L1 L4: - a = r3 + a = r1 return 1 [case testListBuiltFromBytes_64bit] @@ -744,48 +724,47 @@ def test(): r0, source :: bytes r1 :: native_int r2 :: list - r3, r4 :: native_int - r5, r6, r7 :: bit - r8, r9, r10, r11 :: int - r12 :: object - r13, x, r14 :: int - r15 :: object - r16 :: native_int + r3 :: native_int + r4, r5, r6 :: bit + r7, r8, r9, r10 :: int + r11 :: object + r12, x, r13 :: int + r14 :: object + r15 :: native_int a :: list L0: r0 = b'abc' source = r0 r1 = var_object_size source r2 = PyList_New(r1) - r3 = var_object_size source - r4 = 0 + r3 = 0 L1: - r5 = r4 < r3 :: signed - if r5 goto L2 else goto L8 :: bool + r4 = r3 < r1 :: signed + if r4 goto L2 else goto L8 :: bool L2: - r6 = r4 <= 4611686018427387903 :: signed - if r6 goto L3 else goto L4 :: bool + r5 = r3 <= 4611686018427387903 :: signed + if r5 goto L3 else goto L4 :: bool L3: - r7 = r4 >= -4611686018427387904 :: signed - if r7 goto L5 else goto L4 :: bool + r6 = r3 >= -4611686018427387904 :: signed + if r6 goto L5 else goto L4 :: bool L4: - r8 = CPyTagged_FromInt64(r4) - r9 = r8 + r7 = CPyTagged_FromInt64(r3) + r8 = r7 goto L6 L5: - r10 = r4 << 1 - r9 = r10 + r9 = r3 << 1 + r8 = r9 L6: - r11 = CPyBytes_GetItem(source, r9) - r12 = box(int, r11) - r13 = unbox(int, r12) - x = r13 - r14 = f2(x) - r15 = box(int, r14) - CPyList_SetItemUnsafe(r2, r4, r15) + r10 = CPyBytes_GetItem(source, r8) + r11 = box(int, r10) + r12 = unbox(int, r11) + x = r12 + r13 = f2(x) + r14 = box(int, r13) + CPyList_SetItemUnsafe(r2, r3, r14) L7: - r16 = r4 + 1 - r4 = r16 + r15 = r3 + 1 + r3 = r15 goto L1 L8: a = r2 @@ -806,52 +785,49 @@ L0: return r0 def test(): r0 :: bytes - r1 :: native_int - r2 :: list - r3, r4 :: native_int - r5, r6, r7 :: bit - r8, r9, r10, r11 :: int - r12 :: object - r13, x, r14 :: int - r15 :: object - r16 :: native_int + r1 :: list + r2 :: native_int + r3, r4, r5 :: bit + r6, r7, r8, r9 :: int + r10 :: object + r11, x, r12 :: int + r13 :: object + r14 :: native_int a :: list L0: r0 = b'abc' - r1 = var_object_size r0 - r2 = PyList_New(r1) - r3 = var_object_size r0 - r4 = 0 + r1 = PyList_New(3) + r2 = 0 L1: - r5 = r4 < r3 :: signed - if r5 goto L2 else goto L8 :: bool + r3 = r2 < 3 :: signed + if r3 goto L2 else goto L8 :: bool L2: - r6 = r4 <= 4611686018427387903 :: signed - if r6 goto L3 else goto L4 :: bool + r4 = r2 <= 4611686018427387903 :: signed + if r4 goto L3 else goto L4 :: bool L3: - r7 = r4 >= -4611686018427387904 :: signed - if r7 goto L5 else goto L4 :: bool + r5 = r2 >= -4611686018427387904 :: signed + if r5 goto L5 else goto L4 :: bool L4: - r8 = CPyTagged_FromInt64(r4) - r9 = r8 + r6 = CPyTagged_FromInt64(r2) + r7 = r6 goto L6 L5: - r10 = r4 << 1 - r9 = r10 + r8 = r2 << 1 + r7 = r8 L6: - r11 = CPyBytes_GetItem(r0, r9) - r12 = box(int, r11) - r13 = unbox(int, r12) - x = r13 - r14 = f2(x) - r15 = box(int, r14) - CPyList_SetItemUnsafe(r2, r4, r15) + r9 = CPyBytes_GetItem(r0, r7) + r10 = box(int, r9) + r11 = unbox(int, r10) + x = r11 + r12 = f2(x) + r13 = box(int, r12) + CPyList_SetItemUnsafe(r1, r2, r13) L7: - r16 = r4 + 1 - r4 = r16 + r14 = r2 + 1 + r2 = r14 goto L1 L8: - a = r2 + a = r1 return 1 [case testListBuiltFromFinalBytes_64bit] @@ -876,13 +852,13 @@ def test(): r1 :: bool r2 :: native_int r3 :: list - r4, r5 :: native_int - r6, r7, r8 :: bit - r9, r10, r11, r12 :: int - r13 :: object - r14, x, r15 :: int - r16 :: object - r17 :: native_int + r4 :: native_int + r5, r6, r7 :: bit + r8, r9, r10, r11 :: int + r12 :: object + r13, x, r14 :: int + r15 :: object + r16 :: native_int a :: list L0: r0 = __main__.source :: static @@ -893,36 +869,102 @@ L1: L2: r2 = var_object_size r0 r3 = PyList_New(r2) - r4 = var_object_size r0 - r5 = 0 + r4 = 0 L3: - r6 = r5 < r4 :: signed - if r6 goto L4 else goto L10 :: bool + r5 = r4 < r2 :: signed + if r5 goto L4 else goto L10 :: bool L4: - r7 = r5 <= 4611686018427387903 :: signed - if r7 goto L5 else goto L6 :: bool + r6 = r4 <= 4611686018427387903 :: signed + if r6 goto L5 else goto L6 :: bool L5: - r8 = r5 >= -4611686018427387904 :: signed - if r8 goto L7 else goto L6 :: bool + r7 = r4 >= -4611686018427387904 :: signed + if r7 goto L7 else goto L6 :: bool L6: - r9 = CPyTagged_FromInt64(r5) - r10 = r9 + r8 = CPyTagged_FromInt64(r4) + r9 = r8 goto L8 L7: - r11 = r5 << 1 - r10 = r11 + r10 = r4 << 1 + r9 = r10 L8: - r12 = CPyBytes_GetItem(r0, r10) - r13 = box(int, r12) - r14 = unbox(int, r13) - x = r14 - r15 = f2(x) - r16 = box(int, r15) - CPyList_SetItemUnsafe(r3, r5, r16) + r11 = CPyBytes_GetItem(r0, r9) + r12 = box(int, r11) + r13 = unbox(int, r12) + x = r13 + r14 = f2(x) + r15 = box(int, r14) + CPyList_SetItemUnsafe(r3, r4, r15) L9: - r17 = r5 + 1 - r5 = r17 + r16 = r4 + 1 + r4 = r16 goto L3 L10: a = r3 return 1 + +[case testListBuiltFromStars] +from typing import Final + +abc: Final = "abc" + +def test() -> None: + a = [str(x) for x in [*abc, *"def", *b"ghi", ("j", "k"), *("l", "m", "n")]] + +[out] +def test(): + r0, r1 :: str + r2 :: bytes + r3, r4 :: str + r5 :: tuple[str, str] + r6, r7, r8 :: str + r9 :: tuple[str, str, str] + r10 :: list + r11, r12, r13, r14 :: object + r15 :: i32 + r16 :: bit + r17, r18 :: object + r19 :: list + r20, r21 :: native_int + r22 :: bit + r23, x :: object + r24 :: str + r25 :: native_int + a :: list +L0: + r0 = 'abc' + r1 = 'def' + r2 = b'ghi' + r3 = 'j' + r4 = 'k' + r5 = (r3, r4) + r6 = 'l' + r7 = 'm' + r8 = 'n' + r9 = (r6, r7, r8) + r10 = PyList_New(0) + r11 = CPyList_Extend(r10, r0) + r12 = CPyList_Extend(r10, r1) + r13 = CPyList_Extend(r10, r2) + r14 = box(tuple[str, str], r5) + r15 = PyList_Append(r10, r14) + r16 = r15 >= 0 :: signed + r17 = box(tuple[str, str, str], r9) + r18 = CPyList_Extend(r10, r17) + r19 = PyList_New(13) + r20 = 0 +L1: + r21 = var_object_size r10 + r22 = r20 < r21 :: signed + if r22 goto L2 else goto L4 :: bool +L2: + r23 = list_get_item_unsafe r10, r20 + x = r23 + r24 = PyObject_Str(x) + CPyList_SetItemUnsafe(r19, r20, r24) +L3: + r25 = r20 + 1 + r20 = r25 + goto L1 +L4: + a = r19 + return 1 diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index 00ea7f074a5d..081cc1b174c9 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -365,10 +365,8 @@ def test(): r3 :: tuple r4 :: native_int r5 :: bit - r6 :: native_int - r7 :: bit - r8, x, r9 :: str - r10 :: native_int + r6, x, r7 :: str + r8 :: native_int a :: tuple L0: r0 = 'abc' @@ -376,20 +374,18 @@ L0: r1 = CPyStr_Size_size_t(source) r2 = r1 >= 0 :: signed r3 = PyTuple_New(r1) - r4 = CPyStr_Size_size_t(source) - r5 = r4 >= 0 :: signed - r6 = 0 + r4 = 0 L1: - r7 = r6 < r4 :: signed - if r7 goto L2 else goto L4 :: bool + r5 = r4 < r1 :: signed + if r5 goto L2 else goto L4 :: bool L2: - r8 = CPyStr_GetItemUnsafe(source, r6) - x = r8 - r9 = f2(x) - CPySequenceTuple_SetItemUnsafe(r3, r6, r9) + r6 = CPyStr_GetItemUnsafe(source, r4) + x = r6 + r7 = f2(x) + CPySequenceTuple_SetItemUnsafe(r3, r4, r7) L3: - r10 = r6 + 1 - r6 = r10 + r8 = r4 + 1 + r4 = r8 goto L1 L4: a = r3 @@ -411,38 +407,30 @@ L0: return r1 def test(): r0 :: str - r1 :: native_int - r2 :: bit - r3 :: tuple - r4 :: native_int - r5 :: bit + r1 :: tuple + r2 :: native_int + r3 :: bit + r4, x, r5 :: str r6 :: native_int - r7 :: bit - r8, x, r9 :: str - r10 :: native_int a :: tuple L0: r0 = 'abc' - r1 = CPyStr_Size_size_t(r0) - r2 = r1 >= 0 :: signed - r3 = PyTuple_New(r1) - r4 = CPyStr_Size_size_t(r0) - r5 = r4 >= 0 :: signed - r6 = 0 + r1 = PyTuple_New(3) + r2 = 0 L1: - r7 = r6 < r4 :: signed - if r7 goto L2 else goto L4 :: bool + r3 = r2 < 3 :: signed + if r3 goto L2 else goto L4 :: bool L2: - r8 = CPyStr_GetItemUnsafe(r0, r6) - x = r8 - r9 = f2(x) - CPySequenceTuple_SetItemUnsafe(r3, r6, r9) + r4 = CPyStr_GetItemUnsafe(r0, r2) + x = r4 + r5 = f2(x) + CPySequenceTuple_SetItemUnsafe(r1, r2, r5) L3: - r10 = r6 + 1 - r6 = r10 + r6 = r2 + 1 + r2 = r6 goto L1 L4: - a = r3 + a = r1 return 1 [case testTupleBuiltFromFinalStr] @@ -464,38 +452,30 @@ L0: return r1 def test(): r0 :: str - r1 :: native_int - r2 :: bit - r3 :: tuple - r4 :: native_int - r5 :: bit + r1 :: tuple + r2 :: native_int + r3 :: bit + r4, x, r5 :: str r6 :: native_int - r7 :: bit - r8, x, r9 :: str - r10 :: native_int a :: tuple L0: r0 = 'abc' - r1 = CPyStr_Size_size_t(r0) - r2 = r1 >= 0 :: signed - r3 = PyTuple_New(r1) - r4 = CPyStr_Size_size_t(r0) - r5 = r4 >= 0 :: signed - r6 = 0 + r1 = PyTuple_New(3) + r2 = 0 L1: - r7 = r6 < r4 :: signed - if r7 goto L2 else goto L4 :: bool + r3 = r2 < 3 :: signed + if r3 goto L2 else goto L4 :: bool L2: - r8 = CPyStr_GetItemUnsafe(r0, r6) - x = r8 - r9 = f2(x) - CPySequenceTuple_SetItemUnsafe(r3, r6, r9) + r4 = CPyStr_GetItemUnsafe(r0, r2) + x = r4 + r5 = f2(x) + CPySequenceTuple_SetItemUnsafe(r1, r2, r5) L3: - r10 = r6 + 1 - r6 = r10 + r6 = r2 + 1 + r2 = r6 goto L1 L4: - a = r3 + a = r1 return 1 [case testTupleBuiltFromBytes_64bit] @@ -516,48 +496,47 @@ def test(): r0, source :: bytes r1 :: native_int r2 :: tuple - r3, r4 :: native_int - r5, r6, r7 :: bit - r8, r9, r10, r11 :: int - r12 :: object - r13, x, r14 :: int - r15 :: object - r16 :: native_int + r3 :: native_int + r4, r5, r6 :: bit + r7, r8, r9, r10 :: int + r11 :: object + r12, x, r13 :: int + r14 :: object + r15 :: native_int a :: tuple L0: r0 = b'abc' source = r0 r1 = var_object_size source r2 = PyTuple_New(r1) - r3 = var_object_size source - r4 = 0 + r3 = 0 L1: - r5 = r4 < r3 :: signed - if r5 goto L2 else goto L8 :: bool + r4 = r3 < r1 :: signed + if r4 goto L2 else goto L8 :: bool L2: - r6 = r4 <= 4611686018427387903 :: signed - if r6 goto L3 else goto L4 :: bool + r5 = r3 <= 4611686018427387903 :: signed + if r5 goto L3 else goto L4 :: bool L3: - r7 = r4 >= -4611686018427387904 :: signed - if r7 goto L5 else goto L4 :: bool + r6 = r3 >= -4611686018427387904 :: signed + if r6 goto L5 else goto L4 :: bool L4: - r8 = CPyTagged_FromInt64(r4) - r9 = r8 + r7 = CPyTagged_FromInt64(r3) + r8 = r7 goto L6 L5: - r10 = r4 << 1 - r9 = r10 + r9 = r3 << 1 + r8 = r9 L6: - r11 = CPyBytes_GetItem(source, r9) - r12 = box(int, r11) - r13 = unbox(int, r12) - x = r13 - r14 = f2(x) - r15 = box(int, r14) - CPySequenceTuple_SetItemUnsafe(r2, r4, r15) + r10 = CPyBytes_GetItem(source, r8) + r11 = box(int, r10) + r12 = unbox(int, r11) + x = r12 + r13 = f2(x) + r14 = box(int, r13) + CPySequenceTuple_SetItemUnsafe(r2, r3, r14) L7: - r16 = r4 + 1 - r4 = r16 + r15 = r3 + 1 + r3 = r15 goto L1 L8: a = r2 @@ -578,52 +557,49 @@ L0: return r0 def test(): r0 :: bytes - r1 :: native_int - r2 :: tuple - r3, r4 :: native_int - r5, r6, r7 :: bit - r8, r9, r10, r11 :: int - r12 :: object - r13, x, r14 :: int - r15 :: object - r16 :: native_int + r1 :: tuple + r2 :: native_int + r3, r4, r5 :: bit + r6, r7, r8, r9 :: int + r10 :: object + r11, x, r12 :: int + r13 :: object + r14 :: native_int a :: tuple L0: r0 = b'abc' - r1 = var_object_size r0 - r2 = PyTuple_New(r1) - r3 = var_object_size r0 - r4 = 0 + r1 = PyTuple_New(3) + r2 = 0 L1: - r5 = r4 < r3 :: signed - if r5 goto L2 else goto L8 :: bool + r3 = r2 < 3 :: signed + if r3 goto L2 else goto L8 :: bool L2: - r6 = r4 <= 4611686018427387903 :: signed - if r6 goto L3 else goto L4 :: bool + r4 = r2 <= 4611686018427387903 :: signed + if r4 goto L3 else goto L4 :: bool L3: - r7 = r4 >= -4611686018427387904 :: signed - if r7 goto L5 else goto L4 :: bool + r5 = r2 >= -4611686018427387904 :: signed + if r5 goto L5 else goto L4 :: bool L4: - r8 = CPyTagged_FromInt64(r4) - r9 = r8 + r6 = CPyTagged_FromInt64(r2) + r7 = r6 goto L6 L5: - r10 = r4 << 1 - r9 = r10 + r8 = r2 << 1 + r7 = r8 L6: - r11 = CPyBytes_GetItem(r0, r9) - r12 = box(int, r11) - r13 = unbox(int, r12) - x = r13 - r14 = f2(x) - r15 = box(int, r14) - CPySequenceTuple_SetItemUnsafe(r2, r4, r15) + r9 = CPyBytes_GetItem(r0, r7) + r10 = box(int, r9) + r11 = unbox(int, r10) + x = r11 + r12 = f2(x) + r13 = box(int, r12) + CPySequenceTuple_SetItemUnsafe(r1, r2, r13) L7: - r16 = r4 + 1 - r4 = r16 + r14 = r2 + 1 + r2 = r14 goto L1 L8: - a = r2 + a = r1 return 1 [case testTupleBuiltFromFinalBytes_64bit] @@ -648,13 +624,13 @@ def test(): r1 :: bool r2 :: native_int r3 :: tuple - r4, r5 :: native_int - r6, r7, r8 :: bit - r9, r10, r11, r12 :: int - r13 :: object - r14, x, r15 :: int - r16 :: object - r17 :: native_int + r4 :: native_int + r5, r6, r7 :: bit + r8, r9, r10, r11 :: int + r12 :: object + r13, x, r14 :: int + r15 :: object + r16 :: native_int a :: tuple L0: r0 = __main__.source :: static @@ -665,35 +641,34 @@ L1: L2: r2 = var_object_size r0 r3 = PyTuple_New(r2) - r4 = var_object_size r0 - r5 = 0 + r4 = 0 L3: - r6 = r5 < r4 :: signed - if r6 goto L4 else goto L10 :: bool + r5 = r4 < r2 :: signed + if r5 goto L4 else goto L10 :: bool L4: - r7 = r5 <= 4611686018427387903 :: signed - if r7 goto L5 else goto L6 :: bool + r6 = r4 <= 4611686018427387903 :: signed + if r6 goto L5 else goto L6 :: bool L5: - r8 = r5 >= -4611686018427387904 :: signed - if r8 goto L7 else goto L6 :: bool + r7 = r4 >= -4611686018427387904 :: signed + if r7 goto L7 else goto L6 :: bool L6: - r9 = CPyTagged_FromInt64(r5) - r10 = r9 + r8 = CPyTagged_FromInt64(r4) + r9 = r8 goto L8 L7: - r11 = r5 << 1 - r10 = r11 + r10 = r4 << 1 + r9 = r10 L8: - r12 = CPyBytes_GetItem(r0, r10) - r13 = box(int, r12) - r14 = unbox(int, r13) - x = r14 - r15 = f2(x) - r16 = box(int, r15) - CPySequenceTuple_SetItemUnsafe(r3, r5, r16) + r11 = CPyBytes_GetItem(r0, r9) + r12 = box(int, r11) + r13 = unbox(int, r12) + x = r13 + r14 = f2(x) + r15 = box(int, r14) + CPySequenceTuple_SetItemUnsafe(r3, r4, r15) L9: - r17 = r5 + 1 - r5 = r17 + r16 = r4 + 1 + r4 = r16 goto L3 L10: a = r3 @@ -825,36 +800,102 @@ def test(source): source :: tuple r0 :: native_int r1 :: tuple - r2, r3 :: native_int - r4 :: bit - r5 :: object - r6, x, r7 :: bool - r8 :: object - r9 :: native_int + r2 :: native_int + r3 :: bit + r4 :: object + r5, x, r6 :: bool + r7 :: object + r8 :: native_int a :: tuple L0: r0 = var_object_size source r1 = PyTuple_New(r0) - r2 = var_object_size source - r3 = 0 + r2 = 0 L1: - r4 = r3 < r2 :: signed - if r4 goto L2 else goto L4 :: bool + r3 = r2 < r0 :: signed + if r3 goto L2 else goto L4 :: bool L2: - r5 = CPySequenceTuple_GetItemUnsafe(source, r3) - r6 = unbox(bool, r5) - x = r6 - r7 = f(x) - r8 = box(bool, r7) - CPySequenceTuple_SetItemUnsafe(r1, r3, r8) + r4 = CPySequenceTuple_GetItemUnsafe(source, r2) + r5 = unbox(bool, r4) + x = r5 + r6 = f(x) + r7 = box(bool, r6) + CPySequenceTuple_SetItemUnsafe(r1, r2, r7) L3: - r9 = r3 + 1 - r3 = r9 + r8 = r2 + 1 + r2 = r8 goto L1 L4: a = r1 return 1 +[case testTupleBuiltFromStars] +from typing import Final + +abc: Final = "abc" + +def test() -> None: + a = tuple(str(x) for x in [*abc, *"def", *b"ghi", ("j", "k"), *("l", "m", "n")]) + +[out] +def test(): + r0, r1 :: str + r2 :: bytes + r3, r4 :: str + r5 :: tuple[str, str] + r6, r7, r8 :: str + r9 :: tuple[str, str, str] + r10 :: list + r11, r12, r13, r14 :: object + r15 :: i32 + r16 :: bit + r17, r18 :: object + r19 :: tuple + r20, r21 :: native_int + r22 :: bit + r23, x :: object + r24 :: str + r25 :: native_int + a :: tuple +L0: + r0 = 'abc' + r1 = 'def' + r2 = b'ghi' + r3 = 'j' + r4 = 'k' + r5 = (r3, r4) + r6 = 'l' + r7 = 'm' + r8 = 'n' + r9 = (r6, r7, r8) + r10 = PyList_New(0) + r11 = CPyList_Extend(r10, r0) + r12 = CPyList_Extend(r10, r1) + r13 = CPyList_Extend(r10, r2) + r14 = box(tuple[str, str], r5) + r15 = PyList_Append(r10, r14) + r16 = r15 >= 0 :: signed + r17 = box(tuple[str, str, str], r9) + r18 = CPyList_Extend(r10, r17) + r19 = PyTuple_New(13) + r20 = 0 +L1: + r21 = var_object_size r10 + r22 = r20 < r21 :: signed + if r22 goto L2 else goto L4 :: bool +L2: + r23 = list_get_item_unsafe r10, r20 + x = r23 + r24 = PyObject_Str(x) + CPySequenceTuple_SetItemUnsafe(r19, r20, r24) +L3: + r25 = r20 + 1 + r20 = r25 + goto L1 +L4: + a = r19 + return 1 + [case testTupleAdd] from typing import Tuple def f(a: Tuple[int, ...], b: Tuple[int, ...]) -> None: diff --git a/mypyc/test-data/run-lists.test b/mypyc/test-data/run-lists.test index 1569579c1156..40ca1b6e005f 100644 --- a/mypyc/test-data/run-lists.test +++ b/mypyc/test-data/run-lists.test @@ -479,6 +479,8 @@ def test_in_operator_various_cases() -> None: assert list_in_mixed(type) [case testListBuiltFromGenerator] +from typing import Final +abc: Final = "abc" def test_from_gen() -> None: source_a = ["a", "b", "c"] a = list(x + "f2" for x in source_a) @@ -498,6 +500,10 @@ def test_from_gen() -> None: source_str = "abcd" f = list("str:" + x for x in source_str) assert f == ["str:a", "str:b", "str:c", "str:d"] +def test_known_length() -> None: + # not built from generator but doesnt need its own test either + built = [str(x) for x in [*abc, *"def", *b"ghi", ("j", "k"), *("l", "m", "n")]] + assert built == ['a', 'b', 'c', 'd', 'e', 'f', '103', '104', '105', "('j', 'k')", 'l', 'm', 'n'] [case testNext] from typing import List diff --git a/mypyc/test-data/run-tuples.test b/mypyc/test-data/run-tuples.test index f5e1733d429b..e2e8358bb43e 100644 --- a/mypyc/test-data/run-tuples.test +++ b/mypyc/test-data/run-tuples.test @@ -270,6 +270,11 @@ def test_slicing() -> None: def f8(val: int) -> bool: return val % 2 == 0 +abc: Final = "abc" + +def known_length() -> tuple[str, ...]: + return tuple(str(x) for x in [*abc, *"def", *b"ghi", ("j", "k"), *("l", "m", "n")]) + def test_sequence_generator() -> None: source_list = [1, 2, 3] a = tuple(f8(x) for x in source_list) @@ -287,6 +292,8 @@ def test_sequence_generator() -> None: b = tuple('s:' + x for x in source_str) assert b == ('s:a', 's:b', 's:b', 's:c') + assert known_length() == ('a', 'b', 'c', 'd', 'e', 'f', '103', '104', '105', "('j', 'k')", 'l', 'm', 'n') + TUPLE: Final[Tuple[str, ...]] = ('x', 'y') def test_final_boxed_tuple() -> None: