diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 20440d4a26f4..881c154c0b00 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -11,7 +11,6 @@ from mypy.nodes import ( ARG_POS, - BytesExpr, CallExpr, DictionaryComprehension, Expression, @@ -23,10 +22,8 @@ RefExpr, SetExpr, StarExpr, - StrExpr, TupleExpr, TypeAlias, - Var, ) from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types from mypyc.ir.ops import ( @@ -67,6 +64,7 @@ short_int_rprimitive, ) from mypyc.irbuild.builder import IRBuilder +from mypyc.irbuild.constant_fold import constant_fold_expr from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME from mypyc.irbuild.targets import AssignmentTarget, AssignmentTargetTuple from mypyc.primitives.dict_ops import ( @@ -1203,26 +1201,19 @@ def gen_cleanup(self) -> None: gen.gen_cleanup() -def get_expr_length(expr: Expression) -> int | None: - if isinstance(expr, (StrExpr, BytesExpr)): - return len(expr.value) +def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None: + folded = constant_fold_expr(builder, expr) + if isinstance(folded, (str, bytes)): + return len(folded) elif isinstance(expr, (ListExpr, TupleExpr)): # if there are no star expressions, or we know the length of them, # we know the length of the expression - stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)] + stars = [get_expr_length(builder, i) for i in expr.items if isinstance(i, StarExpr)] if None not in stars: other = sum(not isinstance(i, StarExpr) for i in expr.items) return other + sum(stars) # type: ignore [arg-type] elif isinstance(expr, StarExpr): - return get_expr_length(expr.expr) - elif ( - isinstance(expr, RefExpr) - and isinstance(expr.node, Var) - and expr.node.is_final - and isinstance(expr.node.final_value, str) - and expr.node.has_explicit_value - ): - return len(expr.node.final_value) + return get_expr_length(builder, expr.expr) # TODO: extend this, passing length of listcomp and genexp should have worthwhile # performance boost and can be (sometimes) figured out pretty easily. set and dict # comps *can* be done as well but will need special logic to consider the possibility @@ -1235,7 +1226,7 @@ def get_expr_length_value( ) -> Value: rtype = builder.node_type(expr) assert is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple), rtype - length = get_expr_length(expr) + length = get_expr_length(builder, expr) if length is None: # We cannot compute the length at compile time, so we will fetch it. return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t) diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index 3613c5f0101d..a91c94a1afa3 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -434,6 +434,54 @@ L4: a = r1 return 1 +[case testTupleBuiltFromConstantFolding] +from typing import Final + +c: Final = "c" + +def f2(val: str) -> str: + return val + "f2" + +def test() -> None: + # `"ab" + c` should constant fold to "abc" + a = tuple(f2(x) for x in "ab" + c) + +[out] +def f2(val): + val, r0, r1 :: str +L0: + r0 = 'f2' + r1 = PyUnicode_Concat(val, r0) + return r1 +def test(): + r0 :: str + r1 :: tuple + r2 :: native_int + r3 :: bit + r4, x, r5 :: str + r6 :: native_int + a :: tuple +L0: + r0 = 'abc' + r1 = PyTuple_New(3) + r2 = 0 + goto L2 +L1: + r3 = r2 < 3 :: signed + if r3 goto L2 else goto L4 :: bool +L2: + r4 = CPyStr_GetItemUnsafe(r0, r2) + x = r4 + r5 = f2(x) + CPySequenceTuple_SetItemUnsafe(r1, r2, r5) +L3: + r6 = r2 + 1 + r2 = r6 + goto L1 +L4: + a = r1 + return 1 + [case testTupleBuiltFromFinalStr] from typing import Final