diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 4582b2a7396d..315ba1f313b8 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -5,16 +5,21 @@ from __future__ import annotations -from typing import Final, Union +from typing import Any, Callable, Final, Union from mypy.nodes import ( + ArgKind, + CallExpr, ComplexExpr, Expression, FloatExpr, IntExpr, + ListExpr, + MemberExpr, NameExpr, OpExpr, StrExpr, + TupleExpr, UnaryExpr, Var, ) @@ -73,6 +78,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non value = constant_fold_expr(expr.expr, cur_mod_id) if value is not None: return constant_fold_unary_op(expr.op, value) + elif isinstance(expr, CallExpr): + return constant_fold_call_expr(expr, cur_mod_id) return None @@ -185,3 +192,90 @@ def constant_fold_unary_op(op: str, value: ConstantValue) -> int | float | None: elif op == "+" and isinstance(value, (int, float)): return value return None + + +foldable_builtins: dict[str, Callable[..., Any]] = { + "builtins.str": str, + "builtins.int": int, + "builtins.bool": bool, + "builtins.float": float, + "builtins.complex": complex, + "builtins.repr": repr, + "builtins.len": len, + "builtins.hasattr": hasattr, + "builtins.hex": hex, + "builtins.hash": hash, + "builtins.min": min, + "builtins.max": max, + "builtins.oct": oct, + "builtins.pow": pow, + "builtins.round": round, + "builtins.abs": abs, + "builtins.ascii": ascii, + "builtins.bin": bin, + "builtins.chr": chr, +} + + +def constant_fold_call_expr( + expr: CallExpr, + cur_mod_id: str, + foldable_builtins: dict[str, Callable[..., Any]] = foldable_builtins, +) -> ConstantValue | None: + folded_args: list[ConstantValue] + + callee = expr.callee + if isinstance(callee, NameExpr): + func = foldable_builtins.get(callee.fullname) + if func is None: + return None + + folded_args = [] + for arg in expr.args: + val = constant_fold_expr(arg, cur_mod_id) + if val is None: + return None + folded_args.append(val) + + call_args: list[ConstantValue] = [] + call_kwargs: dict[str, ConstantValue] = {} + try: + for folded_arg, arg_kind, arg_name in zip(folded_args, expr.arg_kinds, expr.arg_names): + if arg_kind == ArgKind.ARG_POS: + call_args.append(folded_arg) + elif arg_kind == ArgKind.ARG_NAMED: + call_kwargs[arg_name] = folded_arg # type: ignore [index] + elif arg_kind == ArgKind.ARG_STAR: + call_args.extend(folded_arg) # type: ignore [arg-type] + elif arg_kind == ArgKind.ARG_STAR2: + call_kwargs.update(folded_arg) # type: ignore [arg-type] + return func(*call_args, **call_kwargs) # type: ignore [no-any-return] + except Exception: + return None + # --- f-string requires partial support for both str.join and str.format --- + elif isinstance(callee, MemberExpr) and isinstance( + folded_callee := constant_fold_expr(callee.expr, cur_mod_id), str + ): + # --- partial str.join constant folding --- + if ( + callee.name == "join" + and len(args := expr.args) == 1 + and isinstance(arg := args[0], (ListExpr, TupleExpr)) + ): + folded_strings: list[str] = [] + for item in arg.items: + val = constant_fold_expr(item, cur_mod_id) + if not isinstance(val, str): + return None + folded_strings.append(val) + return folded_callee.join(folded_strings) + # --- str.format constant folding --- + elif callee.name == "format": + folded_args = [] + for arg in expr.args: + arg_val = constant_fold_expr(arg, cur_mod_id) + if arg_val is None: + return None + folded_args.append(arg_val) + return folded_callee.format(*folded_args) + return None diff --git a/mypyc/test-data/exceptions-freq.test b/mypyc/test-data/exceptions-freq.test index b0e4cd6d35f7..fbfbd80ad923 100644 --- a/mypyc/test-data/exceptions-freq.test +++ b/mypyc/test-data/exceptions-freq.test @@ -99,11 +99,22 @@ hot blocks: [0, 1] [case testRareBranch_freq] from typing import Final -x: Final = str() +def setter() -> str: + # we need this helper to ensure `x` cannot be constant folded + return "" + +x: Final = setter() def f() -> str: return x [out] +def setter(): + r0 :: str +L0: + r0 = '' + inc_ref r0 + return r0 +hot blocks: [0] def f(): r0 :: str r1 :: bool @@ -113,7 +124,7 @@ L0: if is_error(r0) goto L1 else goto L3 L1: r1 = raise NameError('value for final name "x" was not set') - if not r1 goto L4 (error at f:6) else goto L2 :: bool + if not r1 goto L4 (error at f:10) else goto L2 :: bool L2: unreachable L3: diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 612f3266fd79..8fa352aa2715 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -3093,18 +3093,8 @@ def f() -> int: return x - 1 [out] def f(): - r0 :: int - r1 :: bool - r2 :: int L0: - r0 = __main__.x :: static - if is_error(r0) goto L1 else goto L2 -L1: - r1 = raise NameError('value for final name "x" was not set') - unreachable -L2: - r2 = CPyTagged_Subtract(r0, 2) - return r2 + return 0 [case testFinalRestrictedTypeVar] from typing import TypeVar diff --git a/test-data/unit/check-final.test b/test-data/unit/check-final.test index e3fc4614fc06..62e51286216d 100644 --- a/test-data/unit/check-final.test +++ b/test-data/unit/check-final.test @@ -11,7 +11,7 @@ y: Final[float] = int() z: Final[int] = int() bad: Final[str] = int() # E: Incompatible types in assignment (expression has type "int", variable has type "str") -reveal_type(x) # N: Revealed type is "builtins.int" +reveal_type(x) # N: Revealed type is "Literal[0]?" reveal_type(y) # N: Revealed type is "builtins.float" reveal_type(z) # N: Revealed type is "builtins.int" [out] @@ -26,10 +26,10 @@ class C: bad: Final[str] = int() # E: Incompatible types in assignment (expression has type "int", variable has type "str") class D(C): pass -reveal_type(D.x) # N: Revealed type is "builtins.int" +reveal_type(D.x) # N: Revealed type is "Literal[0]?" reveal_type(D.y) # N: Revealed type is "builtins.float" reveal_type(D.z) # N: Revealed type is "builtins.int" -reveal_type(D().x) # N: Revealed type is "builtins.int" +reveal_type(D().x) # N: Revealed type is "Literal[0]?" reveal_type(D().y) # N: Revealed type is "builtins.float" reveal_type(D().z) # N: Revealed type is "builtins.int" [out] diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 63278d6c4547..6730aac5620e 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -2174,8 +2174,8 @@ class A: [out] [case testMultipassAndTopLevelVariable] -y = x # E: Cannot determine type of "x" # E: Name "x" is used before definition -y() +y = x # E: Name "x" is used before definition +y() # E: "int" not callable x = 1+int() [out]