diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 4582b2a7396d..7a5f207bb820 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -8,13 +8,17 @@ from typing import Final, Union from mypy.nodes import ( + CallExpr, ComplexExpr, Expression, FloatExpr, IntExpr, + ListExpr, + MemberExpr, NameExpr, OpExpr, StrExpr, + TupleExpr, UnaryExpr, Var, ) @@ -73,6 +77,22 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non value = constant_fold_expr(expr.expr, cur_mod_id) if value is not None: return constant_fold_unary_op(expr.op, value) + # --- partial str.join support in preparation for f-string constant folding --- + elif ( + isinstance(expr, CallExpr) + and isinstance(callee := expr.callee, MemberExpr) + and isinstance(folded_callee := constant_fold_expr(callee.expr, cur_mod_id), str) + and callee.name == "join" + and len(args := expr.args) == 1 + and isinstance(arg := args[0], (ListExpr, TupleExpr)) + ): + folded_items = [] + for item in arg.items: + val = constant_fold_expr(item, cur_mod_id) + if not isinstance(val, str): + return None + folded_items.append(val) + return folded_callee.join(folded_items) return None diff --git a/mypyc/irbuild/ast_helpers.py b/mypyc/irbuild/ast_helpers.py index 3b0f50514594..4a9bc5baddef 100644 --- a/mypyc/irbuild/ast_helpers.py +++ b/mypyc/irbuild/ast_helpers.py @@ -9,6 +9,7 @@ from mypy.nodes import ( LDEF, BytesExpr, + CallExpr, ComparisonExpr, Expression, FloatExpr, @@ -109,7 +110,7 @@ def is_borrow_friendly_expr(self: IRBuilder, expr: Expression) -> bool: # Literals are immortal and can always be borrowed return True if ( - isinstance(expr, (UnaryExpr, OpExpr, NameExpr, MemberExpr)) + isinstance(expr, (UnaryExpr, OpExpr, NameExpr, MemberExpr, CallExpr)) and constant_fold_expr(self, expr) is not None ): # Literal expressions are similar to literals diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index b1133f95b18e..33a73bcaf327 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -10,22 +10,28 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Final, Union +from typing import TYPE_CHECKING, Final, Union, overload +from mypy.checkexpr import try_getting_literal from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op from mypy.nodes import ( BytesExpr, + CallExpr, ComplexExpr, Expression, FloatExpr, IntExpr, + ListExpr, MemberExpr, NameExpr, OpExpr, StrExpr, + TupleExpr, UnaryExpr, Var, ) +from mypy.types import LiteralType, TupleType, get_proper_type +from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.util import bytes_from_str if TYPE_CHECKING: @@ -74,6 +80,48 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | value = constant_fold_expr(builder, expr.expr) if value is not None and not isinstance(value, bytes): return constant_fold_unary_op(expr.op, value) + # we can also constant fold some common methods of builtin types + elif isinstance(expr, CallExpr) and isinstance(callee := expr.callee, MemberExpr): + folded_callee = constant_fold_expr(builder, callee.expr) + + # builtins.str methods + if isinstance(folded_callee, str): + # str.join + if callee.name == "join" and len(args := expr.args) == 1: + arg = args[0] + if isinstance(arg, (ListExpr, TupleExpr)): + folded_items = constant_fold_container_expr(builder, arg) + if folded_items is not None and all( + isinstance(item, str) for item in folded_items + ): + return folded_callee.join(folded_items) # type: ignore [arg-type] + if expr_type := builder.types.get(arg): + proper_type = get_proper_type(expr_type) + if isinstance(proper_type, TupleType): + values: list[str] = [] + for item_type in map(try_getting_literal, proper_type.items): + if not ( + isinstance(item_type, LiteralType) + and isinstance(item_type.value, str) + ): + return None + values.append(item_type.value) + return folded_callee.join(values) + + # builtins.bytes methods + elif isinstance(folded_callee, bytes): + # bytes.join + if ( + callee.name == "join" + and len(args := expr.args) == 1 + # TODO extend this to work with rtuples comprised of known literal values + and isinstance(arg := args[0], (ListExpr, TupleExpr)) + ): + folded_items = constant_fold_container_expr(builder, arg) + if folded_items is not None and all( + isinstance(item, bytes) for item in folded_items + ): + return folded_callee.join(folded_items) # type: ignore [arg-type] return None @@ -95,3 +143,25 @@ def constant_fold_binary_op_extended( return left * right return None + + +@overload +def constant_fold_container_expr( + builder: IRBuilder, expr: ListExpr +) -> list[ConstantValue] | None: ... +@overload +def constant_fold_container_expr( + builder: IRBuilder, expr: TupleExpr +) -> tuple[ConstantValue, ...] | None: ... +def constant_fold_container_expr( + builder: IRBuilder, expr: ListExpr | TupleExpr +) -> list[ConstantValue] | tuple[ConstantValue, ...] | None: + folded_items = [constant_fold_expr(builder, item_expr) for item_expr in expr.items] + if None in folded_items: + return None + if isinstance(expr, ListExpr): + return folded_items # type: ignore [return-value] + elif isinstance(expr, TupleExpr): + return tuple(folded_items) # type: ignore [arg-type] + else: + raise NotImplementedError(type(expr), expr) diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 59ecc4ac2c5c..b9ed40c994d1 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -324,6 +324,10 @@ def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value: # A call to a NewType type is a no-op at runtime. return builder.accept(expr.args[0]) + folded = try_constant_fold(builder, expr) + if folded is not None: + return folded + if isinstance(callee, IndexExpr) and isinstance(callee.analyzed, TypeApplication): callee = callee.analyzed.expr # Unwrap type application diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 8cfefe03ae22..a253eeaaed66 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -217,3 +217,25 @@ L2: L3: keep_alive y return r2 + +[case testBytesJoinConstantFold] +from typing import Final + +# TODO: this is not currently supported but probably should be +constant: Final = b"constant" + +def fold_tuple() -> bytes: + return b" ".join((b"constant", b"folded")) +def fold_list() -> bytes: + return b" ".join([b"constant", b"folded"]) +[out] +def fold_tuple(): + r0 :: bytes +L0: + r0 = b'constant folded' + return r0 +def fold_list(): + r0 :: bytes +L0: + r0 = b'constant folded' + return r0 diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 3fa39819498d..1ff3caeb7e59 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -740,3 +740,32 @@ L2: L3: keep_alive x return r2 + +[case testStrJoinConstantFold] +from typing import Final + +constant: Final = "constant" +tuple_of_literals = "tuple", "of", "literals" + +def fold_tuple() -> str: + return " ".join((constant, "folded")) +def fold_tuple_var() -> str: + return " ".join(tuple_of_literals) +def fold_list() -> str: + return " ".join([constant, "folded"]) +[out] +def fold_tuple(): + r0 :: str +L0: + r0 = 'constant folded' + return r0 +def fold_tuple_var(): + r0 :: str +L0: + r0 = 'tuple of literals' + return r0 +def fold_list(): + r0 :: str +L0: + r0 = 'constant folded' + return r0 diff --git a/test-data/unit/check-final.test b/test-data/unit/check-final.test index e3fc4614fc06..8221288816e9 100644 --- a/test-data/unit/check-final.test +++ b/test-data/unit/check-final.test @@ -1336,3 +1336,32 @@ class S9(S2, S4): # E: Class "S9" has incompatible disjoint bases pass [builtins fixtures/tuple.pyi] + +[case testFinalJoin] +from typing import Final + + +hello: Final = "hello" +x: Final = " ".join((hello,"my","name","is","joe")) +y: Final = " ".join([hello,"joe","my","name","is","moe"]) + +reveal_type(x) # N: Revealed type is "Literal['hello my name is joe']?" +reveal_type(y) # N: Revealed type is "Literal['hello joe my name is moe']?" + +delimiter: Final = "," +headers: Final = delimiter.join(("name", "age")) +joe: Final = delimiter.join(["joe", "24"]) +jack: Final = delimiter.join(("jack", "77")) +jill: Final = delimiter.join(["jill", "30"]) +lines: Final = "\n".join((headers, joe, jack, jill)) + +reveal_type(headers) # N: Revealed type is "Literal['name,age']?" +reveal_type(lines) # N: Revealed type is "Literal['name,age\njoe,24\njack,77\njill,30']?" + +# TODO: implement me +constant_fold_not_yet_supported: Final = "you", "cant", "fold", "this", "yet" +z: Final = " ".join(constant_fold_not_yet_supported) + +reveal_type(z) # N: Revealed type is "builtins.str" +[builtins fixtures/f_string.pyi] +[out] diff --git a/test-data/unit/fixtures/f_string.pyi b/test-data/unit/fixtures/f_string.pyi index 328c666b7ece..83f36e04c731 100644 --- a/test-data/unit/fixtures/f_string.pyi +++ b/test-data/unit/fixtures/f_string.pyi @@ -32,7 +32,7 @@ class bool(int): pass class str: def __add__(self, s: str) -> str: pass def format(self, *args) -> str: pass - def join(self, l: List[str]) -> str: pass + def join(self, l: Iterable[str]) -> str: pass class dict: pass