Skip to content

Commit 377b31e

Browse files
refactor
1 parent fe2f670 commit 377b31e

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

mypyc/irbuild/expression.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@
113113
from mypyc.primitives.str_ops import str_slice_op
114114
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
115115

116+
TransformFunc = Callable[[IRBuilder, Expression], Value | None]
117+
118+
116119
# Name and attribute references
117120

118121

@@ -527,11 +530,8 @@ def translate_cast_expr(builder: IRBuilder, expr: CastExpr) -> Value:
527530
# Operators
528531

529532

533+
@folding_candidate
530534
def transform_unary_expr(builder: IRBuilder, expr: UnaryExpr) -> Value:
531-
folded = try_constant_fold(builder, expr)
532-
if folded:
533-
return folded
534-
535535
return builder.unary_op(builder.accept(expr.expr), expr.op, expr.line)
536536

537537

@@ -582,6 +582,7 @@ def try_optimize_int_floor_divide(builder: IRBuilder, expr: OpExpr) -> OpExpr:
582582
return expr
583583

584584

585+
@folding_candidate
585586
def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value:
586587
index = expr.index
587588
base_type = builder.node_type(expr.base)
@@ -615,6 +616,19 @@ def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None:
615616
return None
616617

617618

619+
def folding_candidate(transform: TransformFunc) -> TransformFunc:
620+
"""Mark a transform function as a candidate for constant folding.
621+
622+
Candidate functions will attempt to short-circuit the transformation
623+
by constant folding the expression and will only proceed to transform
624+
the expression if folding is not possible.
625+
"""
626+
def constant_fold_wrap(builder: IRBuilder, expr: Expression) -> Value | None:
627+
folded = try_constant_fold(builder, expr)
628+
return folded if folded is not None else transform(builder, expr)
629+
return constant_fold_wrap
630+
631+
618632
def try_gen_slice_op(builder: IRBuilder, base: Value, index: SliceExpr) -> Value | None:
619633
"""Generate specialized slice op for some index expressions.
620634

0 commit comments

Comments
 (0)