Skip to content

Commit 0b46162

Browse files
committed
Implement pass that splits multiline ifs in qasm2
1 parent e33e6af commit 0b46162

File tree

5 files changed

+144
-0
lines changed

5 files changed

+144
-0
lines changed

src/bloqade/qasm2/passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .py2qasm import Py2QASM as Py2QASM
44
from .qasm2py import QASM2Py as QASM2Py
55
from .parallel import UOpToParallel as UOpToParallel
6+
from .unroll_if import UnrollIfs as UnrollIfs

src/bloqade/qasm2/passes/fold.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@
2323

2424
from bloqade.qasm2.dialects import expr
2525

26+
from .unroll_if import UnrollIfs
27+
2628

2729
@dataclass
2830
class QASM2Fold(Pass):
2931
"""Fold pass for qasm2.extended"""
3032

3133
constprop: const.Propagate = field(init=False)
3234
inline_gate_subroutine: bool = True
35+
unroll_ifs: bool = True
3336

3437
def __post_init__(self):
3538
self.constprop = const.Propagate(self.dialects)
@@ -61,6 +64,9 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
6164
.join(result)
6265
)
6366

67+
if self.unroll_ifs:
68+
UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69+
6470
# run typeinfer again after unroll etc. because we now insert
6571
# a lot of new nodes, which might have more precise types
6672
self.typeinfer.unsafe_run(mt)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from kirin import ir
2+
from kirin.passes import Pass
3+
from kirin.rewrite import (
4+
Walk,
5+
Chain,
6+
Fixpoint,
7+
ConstantFold,
8+
CommonSubexpressionElimination,
9+
)
10+
11+
from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12+
13+
14+
class UnrollIfs(Pass):
15+
"""This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
16+
17+
def unsafe_run(self, mt: ir.Method):
18+
result = Walk(LiftThenBody()).rewrite(mt.code)
19+
result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
20+
result = (
21+
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
22+
.rewrite(mt.code)
23+
.join(result)
24+
)
25+
return result
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from kirin import ir
2+
from kirin.dialects import scf, func
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
5+
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
6+
from ..dialects.core.stmts import Reset, Measure
7+
from ..dialects.expr.stmts import GateFunction
8+
9+
# TODO: unify with PR #248
10+
AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | GateFunction | Reset
11+
12+
DontLiftType = AllowedThenType | scf.Yield | func.Return
13+
14+
15+
class LiftThenBody(RewriteRule):
16+
"""Lifts anything that's not a UOP or a yield/return out of the then body"""
17+
18+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
19+
if not isinstance(node, scf.IfElse):
20+
return RewriteResult()
21+
22+
then_stmts = node.then_body.stmts()
23+
24+
# TODO: should we leave QRegGet in?
25+
lift_stmts = [stmt for stmt in then_stmts if not isinstance(stmt, DontLiftType)]
26+
27+
for stmt in lift_stmts:
28+
stmt.detach()
29+
stmt.insert_before(node)
30+
31+
return RewriteResult(has_done_something=True)
32+
33+
34+
class SplitIfStmts(RewriteRule):
35+
"""Splits the then body of an if-else statement into multiple if statements"""
36+
37+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
38+
if not isinstance(node, scf.IfElse):
39+
return RewriteResult()
40+
41+
*stmts, yield_or_return = node.then_body.stmts()
42+
43+
if len(stmts) == 1:
44+
return RewriteResult()
45+
46+
is_yield = isinstance(yield_or_return, scf.Yield)
47+
48+
for stmt in stmts:
49+
stmt.detach()
50+
51+
yield_or_return = scf.Yield() if is_yield else func.Return()
52+
53+
then_block = ir.Block((stmt, yield_or_return), argtypes=(node.cond.type,))
54+
then_body = ir.Region(then_block)
55+
else_body = node.else_body.clone()
56+
else_body.detach()
57+
new_if = scf.IfElse(
58+
cond=node.cond, then_body=then_body, else_body=else_body
59+
)
60+
61+
new_if.insert_after(node)
62+
63+
node.delete()
64+
65+
return RewriteResult(has_done_something=True)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from bloqade import qasm2
2+
from bloqade.qasm2.emit import QASM2
3+
4+
5+
def test_unrolling_ifs():
6+
@qasm2.main
7+
def main():
8+
q = qasm2.qreg(2)
9+
c = qasm2.creg(2)
10+
11+
qasm2.h(q[0])
12+
qasm2.measure(q[0], c[0])
13+
14+
if c[0] == 1:
15+
qasm2.x(q[0])
16+
qasm2.x(q[1])
17+
18+
return q
19+
20+
main.print()
21+
22+
target = QASM2()
23+
ast = target.emit(main)
24+
25+
qasm2.parse.pprint(ast)
26+
27+
@qasm2.main
28+
def main_unrolled():
29+
q = qasm2.qreg(2)
30+
c = qasm2.creg(2)
31+
32+
qasm2.h(q[0])
33+
qasm2.measure(q[0], c[0])
34+
35+
if c[0] == 1:
36+
qasm2.x(q[0])
37+
if c[0] == 1:
38+
qasm2.x(q[1])
39+
40+
return q
41+
42+
main_unrolled.print()
43+
44+
target = QASM2()
45+
ast_unrolled = target.emit(main_unrolled)
46+
47+
qasm2.parse.pprint(ast_unrolled)

0 commit comments

Comments
 (0)