|
113 | 113 | from mypyc.primitives.str_ops import str_slice_op |
114 | 114 | from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op |
115 | 115 |
|
| 116 | +TransformFunc = Callable[[IRBuilder, Expression], Value | None] |
| 117 | + |
| 118 | + |
116 | 119 | # Name and attribute references |
117 | 120 |
|
118 | 121 |
|
@@ -527,11 +530,8 @@ def translate_cast_expr(builder: IRBuilder, expr: CastExpr) -> Value: |
527 | 530 | # Operators |
528 | 531 |
|
529 | 532 |
|
| 533 | +@folding_candidate |
530 | 534 | def transform_unary_expr(builder: IRBuilder, expr: UnaryExpr) -> Value: |
531 | | - folded = try_constant_fold(builder, expr) |
532 | | - if folded: |
533 | | - return folded |
534 | | - |
535 | 535 | return builder.unary_op(builder.accept(expr.expr), expr.op, expr.line) |
536 | 536 |
|
537 | 537 |
|
@@ -582,6 +582,7 @@ def try_optimize_int_floor_divide(builder: IRBuilder, expr: OpExpr) -> OpExpr: |
582 | 582 | return expr |
583 | 583 |
|
584 | 584 |
|
| 585 | +@folding_candidate |
585 | 586 | def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value: |
586 | 587 | index = expr.index |
587 | 588 | base_type = builder.node_type(expr.base) |
@@ -615,6 +616,19 @@ def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None: |
615 | 616 | return None |
616 | 617 |
|
617 | 618 |
|
| 619 | +def folding_candidate(transform: TransformFunc) -> TransformFunc: |
| 620 | + """Mark a transform function as a candidate for constant folding. |
| 621 | +
|
| 622 | + Candidate functions will attempt to short-circuit the transformation |
| 623 | + by constant folding the expression and will only proceed to transform |
| 624 | + the expression if folding is not possible. |
| 625 | + """ |
| 626 | + def constant_fold_wrap(builder: IRBuilder, expr: Expression) -> Value | None: |
| 627 | + folded = try_constant_fold(builder, expr) |
| 628 | + return folded if folded is not None else transform(builder, expr) |
| 629 | + return constant_fold_wrap |
| 630 | + |
| 631 | + |
618 | 632 | def try_gen_slice_op(builder: IRBuilder, base: Value, index: SliceExpr) -> Value | None: |
619 | 633 | """Generate specialized slice op for some index expressions. |
620 | 634 |
|
|
0 commit comments