diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 20440d4a26f4..8e708178d1df 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Callable, ClassVar +from typing import Callable, ClassVar, cast from mypy.nodes import ( ARG_POS, @@ -67,6 +67,7 @@ short_int_rprimitive, ) from mypyc.irbuild.builder import IRBuilder +from mypyc.irbuild.constant_fold import constant_fold_expr from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME from mypyc.irbuild.targets import AssignmentTarget, AssignmentTargetTuple from mypyc.primitives.dict_ops import ( @@ -1203,18 +1204,18 @@ def gen_cleanup(self) -> None: gen.gen_cleanup() -def get_expr_length(expr: Expression) -> int | None: +def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None: if isinstance(expr, (StrExpr, BytesExpr)): return len(expr.value) elif isinstance(expr, (ListExpr, TupleExpr)): # if there are no star expressions, or we know the length of them, # we know the length of the expression - stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)] + stars = [get_expr_length(builder, i) for i in expr.items if isinstance(i, StarExpr)] if None not in stars: other = sum(not isinstance(i, StarExpr) for i in expr.items) return other + sum(stars) # type: ignore [arg-type] elif isinstance(expr, StarExpr): - return get_expr_length(expr.expr) + return get_expr_length(builder, expr.expr) elif ( isinstance(expr, RefExpr) and isinstance(expr.node, Var) @@ -1223,10 +1224,42 @@ def get_expr_length(expr: Expression) -> int | None: and expr.node.has_explicit_value ): return len(expr.node.final_value) + elif ( + isinstance(expr, CallExpr) + and isinstance(callee := expr.callee, NameExpr) + and all(kind == ARG_POS for kind in expr.arg_kinds) + ): + fullname = callee.fullname + if ( + fullname + in ( + "builtins.list", + "builtins.tuple", + "builtins.enumerate", + "builtins.sorted", + "builtins.reversed", + ) + and len(expr.args) == 1 + ): + return get_expr_length(builder, expr.args[0]) + elif fullname == "builtins.map" and len(expr.args) == 2: + return get_expr_length(builder, expr.args[1]) + elif fullname == "builtins.zip" and expr.args: + arg_lengths = [get_expr_length(builder, arg) for arg in expr.args] + if all(arg is not None for arg in arg_lengths): + return min(arg_lengths) # type: ignore [type-var] + elif fullname == "builtins.range" and len(expr.args) <= 3: + folded_args = [constant_fold_expr(builder, arg) for arg in expr.args] + if all(isinstance(arg, int) for arg in folded_args): + try: + return len(range(*cast(list[int], folded_args))) + except ValueError: # prevent crash if invalid args + pass + # TODO: extend this, passing length of listcomp and genexp should have worthwhile # performance boost and can be (sometimes) figured out pretty easily. set and dict # comps *can* be done as well but will need special logic to consider the possibility - # of key conflicts. Range, enumerate, zip are all simple logic. + # of key conflicts. return None @@ -1235,7 +1268,7 @@ def get_expr_length_value( ) -> Value: rtype = builder.node_type(expr) assert is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple), rtype - length = get_expr_length(expr) + length = get_expr_length(builder, expr) if length is None: # We cannot compute the length at compile time, so we will fetch it. return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 22a6a5986cbd..0ed5720a591d 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -309,6 +309,11 @@ def __iter__(self) -> Iterator[int]: pass def __len__(self) -> int: pass def __next__(self) -> int: pass +class map(Iterator[_S]): + def __init__(self, func: Callable[[_T], _S], iterable: Iterable[_T]) -> None: pass + def __iter__(self) -> Self: pass + def __next__(self) -> _S: pass + class property: def __init__(self, fget: Optional[Callable[[Any], Any]] = ..., fset: Optional[Callable[[Any, Any], None]] = ..., diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index 3613c5f0101d..48929705884a 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -860,6 +860,126 @@ L4: a = r1 return 1 +[case testTupleBuiltFromLengthCheckable] +from typing import Tuple + +def f(val: bool) -> bool: + return not val + +def test() -> None: + # this tuple is created from a very complex genexp but we can still compute the length and preallocate the tuple + a = tuple( + x + for x + in zip( + map(str, range(5)), + enumerate(sorted(reversed(tuple("abcdefg")))) + ) + ) +[out] +def f(val): + val, r0 :: bool +L0: + r0 = val ^ 1 + return r0 +def test(): + r0 :: list + r1, r2, r3 :: object + r4 :: object[1] + r5 :: object_ptr + r6 :: object + r7 :: range + r8 :: object + r9 :: str + r10 :: object + r11 :: object[2] + r12 :: object_ptr + r13 :: object + r14 :: str + r15 :: tuple + r16 :: object + r17 :: str + r18 :: object + r19 :: object[1] + r20 :: object_ptr + r21 :: object + r22 :: list + r23 :: object + r24 :: str + r25 :: object + r26 :: object[1] + r27 :: object_ptr + r28, r29 :: object + r30 :: str + r31 :: object + r32 :: object[2] + r33 :: object_ptr + r34, r35, r36 :: object + r37, x :: tuple[str, tuple[int, str]] + r38 :: object + r39 :: i32 + r40, r41 :: bit + r42, a :: tuple +L0: + r0 = PyList_New(0) + r1 = load_address PyUnicode_Type + r2 = load_address PyRange_Type + r3 = object 5 + r4 = [r3] + r5 = load_address r4 + r6 = PyObject_Vectorcall(r2, r5, 1, 0) + keep_alive r3 + r7 = cast(range, r6) + r8 = builtins :: module + r9 = 'map' + r10 = CPyObject_GetAttr(r8, r9) + r11 = [r1, r7] + r12 = load_address r11 + r13 = PyObject_Vectorcall(r10, r12, 2, 0) + keep_alive r1, r7 + r14 = 'abcdefg' + r15 = PySequence_Tuple(r14) + r16 = builtins :: module + r17 = 'reversed' + r18 = CPyObject_GetAttr(r16, r17) + r19 = [r15] + r20 = load_address r19 + r21 = PyObject_Vectorcall(r18, r20, 1, 0) + keep_alive r15 + r22 = CPySequence_Sort(r21) + r23 = builtins :: module + r24 = 'enumerate' + r25 = CPyObject_GetAttr(r23, r24) + r26 = [r22] + r27 = load_address r26 + r28 = PyObject_Vectorcall(r25, r27, 1, 0) + keep_alive r22 + r29 = builtins :: module + r30 = 'zip' + r31 = CPyObject_GetAttr(r29, r30) + r32 = [r13, r28] + r33 = load_address r32 + r34 = PyObject_Vectorcall(r31, r33, 2, 0) + keep_alive r13, r28 + r35 = PyObject_GetIter(r34) +L1: + r36 = PyIter_Next(r35) + if is_error(r36) goto L4 else goto L2 +L2: + r37 = unbox(tuple[str, tuple[int, str]], r36) + x = r37 + r38 = box(tuple[str, tuple[int, str]], x) + r39 = PyList_Append(r0, r38) + r40 = r39 >= 0 :: signed +L3: + goto L1 +L4: + r41 = CPy_NoErrOccurred() +L5: + r42 = PyList_AsTuple(r0) + a = r42 + return 1 + [case testTupleBuiltFromStars] from typing import Final