diff --git a/mypyc/irbuild/format_str_tokenizer.py b/mypyc/irbuild/format_str_tokenizer.py index 5a35900006d2..1284c99f5c74 100644 --- a/mypyc/irbuild/format_str_tokenizer.py +++ b/mypyc/irbuild/format_str_tokenizer.py @@ -23,6 +23,7 @@ is_str_rprimitive, ) from mypyc.irbuild.builder import IRBuilder +from mypyc.irbuild.constant_fold import constant_fold_expr from mypyc.primitives.bytes_ops import bytes_build_op from mypyc.primitives.int_ops import int_to_str_op from mypyc.primitives.str_ops import str_build_op, str_op @@ -143,16 +144,19 @@ def convert_format_expr_to_str( for x, format_op in zip(exprs, format_ops): node_type = builder.node_type(x) if format_op == FormatOp.STR: - if is_str_rprimitive(node_type) or isinstance( - x, StrExpr - ): # NOTE: why does mypyc think our fake StrExprs are not str rprimitives? + if is_str_rprimitive(node_type) or isinstance(x, StrExpr): + # NOTE: why does mypyc think our fake StrExprs are not str rprimitives? var_str = builder.accept(x) + elif (folded := constant_fold_expr(builder, x)) is not None: + var_str = builder.accept(StrExpr(str(folded))) elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type): var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line) else: var_str = builder.primitive_op(str_op, [builder.accept(x)], line) elif format_op == FormatOp.INT: - if is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type): + if isinstance(folded := constant_fold_expr(builder, x), int): + var_str = builder.accept(StrExpr(str(folded))) + elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type): var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line) else: return None diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 056f120c7bac..471c81cc666f 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -638,11 +638,11 @@ floating: Final = 3.14 boolean: Final = True def test(x: str) -> str: - return f"{string}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}" + return f"{string[:3]}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}" def test2(x: str) -> str: - return f"{string}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}{x}" + return f"{string[-3:]}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}{x}" def test3(x: str) -> str: - return f"{x}{string}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}{x}" + return f"{x}{string[:]}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}{x}" [out] def test(x):