|
| 1 | +from kirin import ir |
| 2 | +from kirin.dialects import scf, func |
| 3 | +from kirin.rewrite.abc import RewriteRule, RewriteResult |
| 4 | + |
| 5 | +from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate |
| 6 | +from ..dialects.core.stmts import Reset, Measure |
| 7 | +from ..dialects.expr.stmts import GateFunction |
| 8 | + |
| 9 | +# TODO: unify with PR #248 |
| 10 | +AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | GateFunction | Reset |
| 11 | + |
| 12 | +DontLiftType = AllowedThenType | scf.Yield | func.Return |
| 13 | + |
| 14 | + |
| 15 | +class LiftThenBody(RewriteRule): |
| 16 | + """Lifts anything that's not a UOP or a yield/return out of the then body""" |
| 17 | + |
| 18 | + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
| 19 | + if not isinstance(node, scf.IfElse): |
| 20 | + return RewriteResult() |
| 21 | + |
| 22 | + then_stmts = node.then_body.stmts() |
| 23 | + |
| 24 | + # TODO: should we leave QRegGet in? |
| 25 | + lift_stmts = [stmt for stmt in then_stmts if not isinstance(stmt, DontLiftType)] |
| 26 | + |
| 27 | + for stmt in lift_stmts: |
| 28 | + stmt.detach() |
| 29 | + stmt.insert_before(node) |
| 30 | + |
| 31 | + return RewriteResult(has_done_something=True) |
| 32 | + |
| 33 | + |
| 34 | +class SplitIfStmts(RewriteRule): |
| 35 | + """Splits the then body of an if-else statement into multiple if statements""" |
| 36 | + |
| 37 | + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
| 38 | + if not isinstance(node, scf.IfElse): |
| 39 | + return RewriteResult() |
| 40 | + |
| 41 | + *stmts, yield_or_return = node.then_body.stmts() |
| 42 | + |
| 43 | + if len(stmts) == 1: |
| 44 | + return RewriteResult() |
| 45 | + |
| 46 | + is_yield = isinstance(yield_or_return, scf.Yield) |
| 47 | + |
| 48 | + for stmt in stmts: |
| 49 | + stmt.detach() |
| 50 | + |
| 51 | + yield_or_return = scf.Yield() if is_yield else func.Return() |
| 52 | + |
| 53 | + then_block = ir.Block((stmt, yield_or_return), argtypes=(node.cond.type,)) |
| 54 | + then_body = ir.Region(then_block) |
| 55 | + else_body = node.else_body.clone() |
| 56 | + else_body.detach() |
| 57 | + new_if = scf.IfElse( |
| 58 | + cond=node.cond, then_body=then_body, else_body=else_body |
| 59 | + ) |
| 60 | + |
| 61 | + new_if.insert_after(node) |
| 62 | + |
| 63 | + node.delete() |
| 64 | + |
| 65 | + return RewriteResult(has_done_something=True) |
0 commit comments