Skip to content

Commit ad7d232

Browse files
committed
Expose kwarg and fix insertion order
1 parent 264a9fb commit ad7d232

File tree

3 files changed

+48
-15
lines changed

3 files changed

+48
-15
lines changed

src/bloqade/qasm2/emit/target.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
allow_parallel: bool = False,
2828
allow_global: bool = False,
2929
custom_gate: bool = True,
30+
unroll_ifs: bool = True,
3031
) -> None:
3132
"""Initialize the QASM2 target.
3233
@@ -43,9 +44,14 @@ def __init__(
4344
qelib1 (bool):
4445
Include the `include "qelib1.inc"` line in the resulting QASM2 AST that's
4546
submitted to qBraid. Defaults to `True`.
47+
4648
custom_gate (bool):
4749
Include the custom gate definitions in the resulting QASM2 AST. Defaults to `True`. If `False`, all the qasm2.gate will be inlined.
4850
51+
unroll_ifs (bool):
52+
Unrolls if statements with multiple qasm2 statements in the body in order to produce valid qasm2 output, which only allows a single
53+
operation in an if body. Defaults to `True`.
54+
4955
5056
5157
"""
@@ -58,6 +64,7 @@ def __init__(
5864
self.custom_gate = custom_gate
5965
self.allow_parallel = allow_parallel
6066
self.allow_global = allow_global
67+
self.unroll_ifs = unroll_ifs
6168

6269
if allow_parallel:
6370
self.main_target = self.main_target.add(qasm2.dialects.parallel)
@@ -87,9 +94,11 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
8794

8895
# make a cloned instance of kernel
8996
entry = entry.similar()
90-
QASM2Fold(entry.dialects, inline_gate_subroutine=not self.custom_gate).fixpoint(
91-
entry
92-
)
97+
QASM2Fold(
98+
entry.dialects,
99+
inline_gate_subroutine=not self.custom_gate,
100+
unroll_ifs=self.unroll_ifs,
101+
).fixpoint(entry)
93102

94103
if not self.allow_global:
95104
# rewrite global to parallel

src/bloqade/qasm2/rewrite/split_ifs.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44

55
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
66
from ..dialects.core.stmts import Reset, Measure
7-
from ..dialects.expr.stmts import GateFunction
87

98
# TODO: unify with PR #248
109
AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset
1110

12-
DontLiftType = AllowedThenType | scf.Yield | func.Return
11+
DontLiftType = AllowedThenType | scf.Yield | func.Return | func.Invoke
1312

1413

1514
class LiftThenBody(RewriteRule):
@@ -21,16 +20,10 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2120

2221
then_stmts = node.then_body.stmts()
2322

24-
# TODO: should we leave QRegGet in?
25-
lift_stmts: list[ir.Statement] = []
26-
for stmt in then_stmts:
27-
if isinstance(stmt, DontLiftType):
28-
continue
23+
lift_stmts = [stmt for stmt in then_stmts if not isinstance(stmt, DontLiftType)]
2924

30-
if isinstance(stmt, func.Invoke) and isinstance(stmt.callee, GateFunction):
31-
continue
32-
33-
lift_stmts.append(stmt)
25+
if len(lift_stmts) == 0:
26+
return RewriteResult()
3427

3528
for stmt in lift_stmts:
3629
stmt.detach()
@@ -66,7 +59,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
6659
cond=node.cond, then_body=then_body, else_body=else_body
6760
)
6861

69-
new_if.insert_after(node)
62+
new_if.insert_before(node)
7063

7164
node.delete()
7265

test/qasm2/passes/test_unroll_if.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,34 @@ def main():
7272
ast = target.emit(main)
7373

7474
qasm2.parse.pprint(ast)
75+
76+
77+
def test_conditional_nested_kernel():
78+
@qasm2.main
79+
def nested(q: qasm2.QReg, c: qasm2.CReg):
80+
qasm2.h(q[0])
81+
82+
qasm2.measure(q, c)
83+
84+
qasm2.x(q[0])
85+
qasm2.x(q[1])
86+
87+
return q
88+
89+
@qasm2.main
90+
def main():
91+
q = qasm2.qreg(2)
92+
c = qasm2.creg(2)
93+
94+
qasm2.h(q[0])
95+
qasm2.measure(q, c)
96+
97+
if c[0] == 1:
98+
nested(q, c)
99+
100+
return c
101+
102+
target = QASM2(unroll_ifs=True)
103+
ast = target.emit(main)
104+
105+
qasm2.parse.pprint(ast)

0 commit comments

Comments
 (0)