diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 4582b2a7396d..e98041b69b05 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -11,9 +11,11 @@ ComplexExpr, Expression, FloatExpr, + IndexExpr, IntExpr, NameExpr, OpExpr, + SliceExpr, StrExpr, UnaryExpr, Var, @@ -73,6 +75,40 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non value = constant_fold_expr(expr.expr, cur_mod_id) if value is not None: return constant_fold_unary_op(expr.op, value) + elif isinstance(expr, IndexExpr): + base = constant_fold_expr(expr.base, cur_mod_id) + if base is not None: + index_expr = expr.index + if isinstance(index_expr, SliceExpr): + if index_expr.begin_index is None: + begin_index = None + else: + begin_index = constant_fold_expr(index_expr.begin_index, cur_mod_id) + if begin_index is None: + return None + if index_expr.end_index is None: + end_index = None + else: + end_index = constant_fold_expr(index_expr.end_index, cur_mod_id) + if end_index is None: + return None + if index_expr.stride is None: + stride = None + else: + stride = constant_fold_expr(index_expr.stride, cur_mod_id) + if stride is None: + return None + try: + return base[begin_index:end_index:stride] # type: ignore [index, misc] + except Exception: + return None + + index = constant_fold_expr(index_expr, cur_mod_id) + if index is not None: + try: + return base[index] # type: ignore [index] + except Exception: + return None return None diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index b1133f95b18e..c4c8ea21a85c 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -10,7 +10,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Final, Union +from typing import TYPE_CHECKING, Callable, Final, TypeVar, Union from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op from mypy.nodes import ( @@ -18,14 +18,17 @@ ComplexExpr, Expression, FloatExpr, + IndexExpr, IntExpr, MemberExpr, NameExpr, OpExpr, + SliceExpr, StrExpr, UnaryExpr, Var, ) +from mypyc.ir.ops import Value from mypyc.irbuild.util import bytes_from_str if TYPE_CHECKING: @@ -35,6 +38,8 @@ ConstantValue = Union[int, float, complex, str, bytes] CONST_TYPES: Final = (int, float, complex, str, bytes) +Expr = TypeVar("Expr", bound=Expression) + def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None: """Return the constant value of an expression for supported operations. @@ -74,6 +79,40 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | value = constant_fold_expr(builder, expr.expr) if value is not None and not isinstance(value, bytes): return constant_fold_unary_op(expr.op, value) + elif isinstance(expr, IndexExpr): + base = constant_fold_expr(builder, expr.base) + if base is not None: + index_expr = expr.index + if isinstance(index_expr, SliceExpr): + if index_expr.begin_index is None: + begin_index = None + else: + begin_index = constant_fold_expr(builder, index_expr.begin_index) + if begin_index is None: + return None + if index_expr.end_index is None: + end_index = None + else: + end_index = constant_fold_expr(builder, index_expr.end_index) + if end_index is None: + return None + if index_expr.stride is None: + stride = None + else: + stride = constant_fold_expr(builder, index_expr.stride) + if stride is None: + return None + try: + return base[begin_index:end_index:stride] # type: ignore [index, misc] + except Exception: + return None + + index = constant_fold_expr(builder, index_expr) + if index is not None: + try: + return base[index] # type: ignore [index] + except Exception: + return None return None @@ -95,3 +134,31 @@ def constant_fold_binary_op_extended( return left * right return None + + +def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None: + """Return the constant value of an expression if possible. + + Return None otherwise. + """ + value = constant_fold_expr(builder, expr) + if value is not None: + return builder.load_literal_value(value) + return None + + +def folding_candidate( + transform: Callable[[IRBuilder, Expr], Value], +) -> Callable[[IRBuilder, Expr], Value]: + """Mark a transform function as a candidate for constant folding. + + Candidate functions will attempt to short-circuit the transformation + by constant folding the expression and will only proceed to transform + the expression if folding is not possible. + """ + + def constant_fold_wrap(builder: IRBuilder, expr: Expr) -> Value: + folded = try_constant_fold(builder, expr) + return folded if folded is not None else transform(builder, expr) + + return constant_fold_wrap diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 59ecc4ac2c5c..8b4d6fcb1051 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -83,7 +83,7 @@ ) from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op -from mypyc.irbuild.constant_fold import constant_fold_expr +from mypyc.irbuild.constant_fold import constant_fold_expr, folding_candidate, try_constant_fold from mypyc.irbuild.for_helpers import ( comprehension_helper, raise_error_if_contains_unreachable_names, @@ -527,11 +527,8 @@ def translate_cast_expr(builder: IRBuilder, expr: CastExpr) -> Value: # Operators +@folding_candidate def transform_unary_expr(builder: IRBuilder, expr: UnaryExpr) -> Value: - folded = try_constant_fold(builder, expr) - if folded: - return folded - return builder.unary_op(builder.accept(expr.expr), expr.op, expr.line) @@ -582,6 +579,7 @@ def try_optimize_int_floor_divide(builder: IRBuilder, expr: OpExpr) -> OpExpr: return expr +@folding_candidate def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value: index = expr.index base_type = builder.node_type(expr.base) @@ -604,17 +602,6 @@ def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value: ) -def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None: - """Return the constant value of an expression if possible. - - Return None otherwise. - """ - value = constant_fold_expr(builder, expr) - if value is not None: - return builder.load_literal_value(value) - return None - - def try_gen_slice_op(builder: IRBuilder, base: Value, index: SliceExpr) -> Value | None: """Generate specialized slice op for some index expressions. diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index cd953c84c541..aca649be7d28 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -478,3 +478,34 @@ L0: r3 = (-1.5+2j) neg_2 = r3 return 1 + +[case testIndexExprConstantFolding] +from typing import Final + +long_string: Final = "long string" + +def pos_index() -> None: + a = long_string[5] +def neg_index() -> None: + a = long_string[-5] +def slice_index() -> None: + a = long_string[5:] +[out] +def pos_index(): + r0, a :: str +L0: + r0 = 's' + a = r0 + return 1 +def neg_index(): + r0, a :: str +L0: + r0 = 't' + a = r0 + return 1 +def slice_index(): + r0, a :: str +L0: + r0 = 'string' + a = r0 + return 1 diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 6a62db6ee3ee..22f3a5906f66 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -151,7 +151,7 @@ def test_unicode() -> None: assert ne("\U0001f4a9foo", "\U0001f4a8foo" + str()) [case testStringOps] -from typing import List, Optional, Tuple +from typing import Final, List, Optional, Tuple from testutil import assertRaises def do_split(s: str, sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]: @@ -226,6 +226,12 @@ def contains(s: str, o: str) -> bool: def getitem(s: str, index: int) -> str: return s[index] +final_string: Final = "abc" +final_int: Final = 1 + +def getitem_folded() -> str: + return final_string[final_int] + final_string[-1] + def find(s: str, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int: if start is not None: if end is not None: @@ -263,6 +269,7 @@ def test_getitem() -> None: getitem(s, 4) with assertRaises(IndexError, "string index out of range"): getitem(s, -4) + assert getitem_folded() == "bc" def test_find() -> None: s = "abcab"