Skip to content

Commit 676ed3a

Browse files
david-plweinbe58
andauthored
Implement pass that splits multiline ifs in qasm2 (#252)
Closes #189 Also, it will probably break the tests in #248 --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent 75b62f0 commit 676ed3a

File tree

7 files changed

+216
-5
lines changed

7 files changed

+216
-5
lines changed

src/bloqade/qasm2/emit/target.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
allow_parallel: bool = False,
2828
allow_global: bool = False,
2929
custom_gate: bool = True,
30+
unroll_ifs: bool = True,
3031
) -> None:
3132
"""Initialize the QASM2 target.
3233
@@ -43,9 +44,14 @@ def __init__(
4344
qelib1 (bool):
4445
Include the `include "qelib1.inc"` line in the resulting QASM2 AST that's
4546
submitted to qBraid. Defaults to `True`.
47+
4648
custom_gate (bool):
4749
Include the custom gate definitions in the resulting QASM2 AST. Defaults to `True`. If `False`, all the qasm2.gate will be inlined.
4850
51+
unroll_ifs (bool):
52+
Unrolls if statements with multiple qasm2 statements in the body in order to produce valid qasm2 output, which only allows a single
53+
operation in an if body. Defaults to `True`.
54+
4955
5056
5157
"""
@@ -58,6 +64,7 @@ def __init__(
5864
self.custom_gate = custom_gate
5965
self.allow_parallel = allow_parallel
6066
self.allow_global = allow_global
67+
self.unroll_ifs = unroll_ifs
6168

6269
if allow_parallel:
6370
self.main_target = self.main_target.add(qasm2.dialects.parallel)
@@ -87,9 +94,11 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
8794

8895
# make a cloned instance of kernel
8996
entry = entry.similar()
90-
QASM2Fold(entry.dialects, inline_gate_subroutine=not self.custom_gate).fixpoint(
91-
entry
92-
)
97+
QASM2Fold(
98+
entry.dialects,
99+
inline_gate_subroutine=not self.custom_gate,
100+
unroll_ifs=self.unroll_ifs,
101+
).fixpoint(entry)
93102

94103
if not self.allow_global:
95104
# rewrite global to parallel

src/bloqade/qasm2/passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .py2qasm import Py2QASM as Py2QASM
44
from .qasm2py import QASM2Py as QASM2Py
55
from .parallel import UOpToParallel as UOpToParallel
6+
from .unroll_if import UnrollIfs as UnrollIfs

src/bloqade/qasm2/passes/fold.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@
2323

2424
from bloqade.qasm2.dialects import expr
2525

26+
from .unroll_if import UnrollIfs
27+
2628

2729
@dataclass
2830
class QASM2Fold(Pass):
2931
"""Fold pass for qasm2.extended"""
3032

3133
constprop: const.Propagate = field(init=False)
3234
inline_gate_subroutine: bool = True
35+
unroll_ifs: bool = True
3336

3437
def __post_init__(self):
3538
self.constprop = const.Propagate(self.dialects)
@@ -61,6 +64,9 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
6164
.join(result)
6265
)
6366

67+
if self.unroll_ifs:
68+
UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69+
6470
# run typeinfer again after unroll etc. because we now insert
6571
# a lot of new nodes, which might have more precise types
6672
self.typeinfer.unsafe_run(mt)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from kirin import ir
2+
from kirin.passes import Pass
3+
from kirin.rewrite import (
4+
Walk,
5+
Chain,
6+
Fixpoint,
7+
ConstantFold,
8+
CommonSubexpressionElimination,
9+
)
10+
11+
from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12+
13+
14+
class UnrollIfs(Pass):
15+
"""This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
16+
17+
def unsafe_run(self, mt: ir.Method):
18+
result = Walk(LiftThenBody()).rewrite(mt.code)
19+
result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
20+
result = (
21+
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
22+
.rewrite(mt.code)
23+
.join(result)
24+
)
25+
return result
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from kirin import ir
2+
from kirin.dialects import scf, func
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
5+
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
6+
from ..dialects.core.stmts import Reset, Measure
7+
8+
# TODO: unify with PR #248
9+
AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset
10+
11+
DontLiftType = AllowedThenType | scf.Yield | func.Return | func.Invoke
12+
13+
14+
class LiftThenBody(RewriteRule):
15+
"""Lifts anything that's not a UOP or a yield/return out of the then body"""
16+
17+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18+
if not isinstance(node, scf.IfElse):
19+
return RewriteResult()
20+
21+
then_stmts = node.then_body.stmts()
22+
23+
lift_stmts = [stmt for stmt in then_stmts if not isinstance(stmt, DontLiftType)]
24+
25+
if len(lift_stmts) == 0:
26+
return RewriteResult()
27+
28+
for stmt in lift_stmts:
29+
stmt.detach()
30+
stmt.insert_before(node)
31+
32+
return RewriteResult(has_done_something=True)
33+
34+
35+
class SplitIfStmts(RewriteRule):
36+
"""Splits the then body of an if-else statement into multiple if statements"""
37+
38+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
39+
if not isinstance(node, scf.IfElse):
40+
return RewriteResult()
41+
42+
*stmts, yield_or_return = node.then_body.stmts()
43+
44+
if len(stmts) == 1:
45+
return RewriteResult()
46+
47+
is_yield = isinstance(yield_or_return, scf.Yield)
48+
49+
for stmt in stmts:
50+
stmt.detach()
51+
52+
yield_or_return = scf.Yield() if is_yield else func.Return()
53+
54+
then_block = ir.Block((stmt, yield_or_return), argtypes=(node.cond.type,))
55+
then_body = ir.Region(then_block)
56+
else_body = node.else_body.clone()
57+
else_body.detach()
58+
new_if = scf.IfElse(
59+
cond=node.cond, then_body=then_body, else_body=else_body
60+
)
61+
62+
new_if.insert_before(node)
63+
64+
node.delete()
65+
66+
return RewriteResult(has_done_something=True)

test/qasm2/emit/test_qasm2_emit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,7 @@ def multiline_then():
257257

258258
return q
259259

260-
target = qasm2.emit.QASM2()
261-
260+
target = qasm2.emit.QASM2(unroll_ifs=False)
262261
with pytest.raises(InterpreterError):
263262
target.emit(multiline_then)
264263

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from bloqade import qasm2
2+
from bloqade.qasm2.emit import QASM2
3+
4+
5+
def test_unrolling_ifs():
6+
@qasm2.main
7+
def main():
8+
q = qasm2.qreg(2)
9+
c = qasm2.creg(2)
10+
11+
qasm2.h(q[0])
12+
qasm2.measure(q[0], c[0])
13+
14+
if c[0] == 1:
15+
qasm2.x(q[0])
16+
qasm2.x(q[1])
17+
18+
return q
19+
20+
main.print()
21+
22+
target = QASM2()
23+
ast = target.emit(main)
24+
25+
qasm2.parse.pprint(ast)
26+
27+
@qasm2.main
28+
def main_unrolled():
29+
q = qasm2.qreg(2)
30+
c = qasm2.creg(2)
31+
32+
qasm2.h(q[0])
33+
qasm2.measure(q[0], c[0])
34+
35+
if c[0] == 1:
36+
qasm2.x(q[0])
37+
if c[0] == 1:
38+
qasm2.x(q[1])
39+
40+
return q
41+
42+
main_unrolled.print()
43+
44+
target = QASM2()
45+
ast_unrolled = target.emit(main_unrolled)
46+
47+
qasm2.parse.pprint(ast_unrolled)
48+
49+
50+
def test_nested_kernels():
51+
@qasm2.main
52+
def nested(q: qasm2.QReg, c: qasm2.CReg):
53+
qasm2.h(q[0])
54+
55+
qasm2.measure(q, c)
56+
if c[0] == 1:
57+
qasm2.x(q[0])
58+
qasm2.x(q[1])
59+
60+
return q
61+
62+
@qasm2.main
63+
def main():
64+
q = qasm2.qreg(2)
65+
c = qasm2.creg(2)
66+
67+
nested(q, c)
68+
69+
return c
70+
71+
target = QASM2()
72+
ast = target.emit(main)
73+
74+
qasm2.parse.pprint(ast)
75+
76+
77+
def test_conditional_nested_kernel():
78+
@qasm2.main
79+
def nested(q: qasm2.QReg, c: qasm2.CReg):
80+
qasm2.h(q[0])
81+
82+
qasm2.measure(q, c)
83+
84+
qasm2.x(q[0])
85+
qasm2.x(q[1])
86+
87+
return q
88+
89+
@qasm2.main
90+
def main():
91+
q = qasm2.qreg(2)
92+
c = qasm2.creg(2)
93+
94+
qasm2.h(q[0])
95+
qasm2.measure(q, c)
96+
97+
if c[0] == 1:
98+
nested(q, c)
99+
100+
return c
101+
102+
target = QASM2(unroll_ifs=True)
103+
ast = target.emit(main)
104+
105+
qasm2.parse.pprint(ast)

0 commit comments

Comments
 (0)