Skip to content

Commit 8025c39

Browse files
committed
Clean up & fix qasm2 lowering of if stmt (#319)
This fixes the bug in #249. However, instead of actually fixing the code that was there I basically just rewrote the whole thing. It was easy to simplify since the if statements in qasm2 are much more restrictive than the previous code assumed (no else, no assigments, etc.).
1 parent bc62c86 commit 8025c39

File tree

2 files changed

+36
-71
lines changed

2 files changed

+36
-71
lines changed

src/bloqade/qasm2/parse/lowering.py

Lines changed: 9 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import field, dataclass
33

44
from kirin import ir, types, lowering
5-
from kirin.dialects import cf, func, ilist
5+
from kirin.dialects import cf, scf, func, ilist
66

77
from bloqade.qasm2.types import CRegType, QRegType, QubitType
88
from bloqade.qasm2.dialects import uop, core, expr, glob, noise, parallel
@@ -178,92 +178,30 @@ def visit_UGate(self, state: lowering.State[ast.Node], node: ast.UGate):
178178
def visit_Reset(self, state: lowering.State[ast.Node], node: ast.Reset):
179179
state.current_frame.push(core.Reset(qarg=state.lower(node.qarg).expect_one()))
180180

181-
# TODO: clean this up? copied from cf dialect with a small modification
182181
def visit_IfStmt(self, state: lowering.State[ast.Node], node: ast.IfStmt):
183182
cond_stmt = core.CRegEq(
184183
lhs=state.lower(node.cond.lhs).expect_one(),
185184
rhs=state.lower(node.cond.rhs).expect_one(),
186185
)
187186
cond = state.current_frame.push(cond_stmt).result
188187
frame = state.current_frame
189-
before_block = frame.curr_block
190188

191-
with state.frame(node.body, region=frame.curr_region) as if_frame:
189+
with state.frame(node.body) as if_frame:
192190
true_cond = if_frame.entr_block.args.append_from(types.Bool, cond.name)
193191
if cond.name:
194192
if_frame.defs[cond.name] = true_cond
195193

194+
# NOTE: pass in definitions from outer scope (usually just for the qreg)
195+
if_frame.defs.update(frame.defs)
196+
196197
if_frame.exhaust()
197-
self.branch_next_if_not_terminated(if_frame)
198198

199-
with state.frame([], region=frame.curr_region) as else_frame:
200-
true_cond = else_frame.entr_block.args.append_from(types.Bool, cond.name)
201-
if cond.name:
202-
else_frame.defs[cond.name] = true_cond
203-
else_frame.exhaust()
204-
self.branch_next_if_not_terminated(else_frame)
205-
206-
with state.frame(frame.stream.split(), region=frame.curr_region) as after_frame:
207-
after_frame.defs.update(frame.defs)
208-
phi: set[str] = set()
209-
for name in if_frame.defs.keys():
210-
if frame.get(name):
211-
phi.add(name)
212-
elif name in else_frame.defs:
213-
phi.add(name)
214-
215-
for name in else_frame.defs.keys():
216-
if frame.get(name): # not defined in if_frame
217-
phi.add(name)
218-
219-
for name in phi:
220-
after_frame.defs[name] = after_frame.entr_block.args.append_from(
221-
types.Any, name
222-
)
199+
# NOTE: qasm2 can never yield anything from if
200+
if_frame.push(scf.Yield())
223201

224-
after_frame.exhaust()
225-
self.branch_next_if_not_terminated(after_frame)
226-
after_frame.next_block.stmts.append(
227-
cf.Branch(arguments=(), successor=frame.next_block)
228-
)
202+
then_body = if_frame.curr_region
229203

230-
if_args = []
231-
for name in phi:
232-
if value := if_frame.get(name):
233-
if_args.append(value)
234-
else:
235-
raise lowering.BuildError(f"undefined variable {name} in if branch")
236-
237-
else_args = []
238-
for name in phi:
239-
if value := else_frame.get(name):
240-
else_args.append(value)
241-
else:
242-
raise lowering.BuildError(f"undefined variable {name} in else branch")
243-
244-
if_frame.next_block.stmts.append(
245-
cf.Branch(
246-
arguments=tuple(if_args),
247-
successor=after_frame.entr_block,
248-
)
249-
)
250-
else_frame.next_block.stmts.append(
251-
cf.Branch(
252-
arguments=tuple(else_args),
253-
successor=after_frame.entr_block,
254-
)
255-
)
256-
before_block.stmts.append(
257-
cf.ConditionalBranch(
258-
cond=cond,
259-
then_arguments=(cond,),
260-
then_successor=if_frame.entr_block,
261-
else_arguments=(cond,),
262-
else_successor=else_frame.entr_block,
263-
)
264-
)
265-
frame.defs.update(after_frame.defs)
266-
frame.jump_next_block()
204+
state.current_frame.push(scf.IfElse(cond, then_body=then_body))
267205

268206
def branch_next_if_not_terminated(self, frame: lowering.Frame):
269207
"""Branch to the next block if the current block is not terminated.

test/qasm2/test_lowering.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,30 @@ def test_gate_with_params():
136136
assert ket[1] == ket[2] == 0
137137
assert math.isclose(abs(ket[0]) ** 2, 0.5, abs_tol=1e-6)
138138
assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6)
139+
140+
141+
def test_if_lowering():
142+
143+
qasm2_prog = textwrap.dedent(
144+
"""
145+
OPENQASM 2.0;
146+
include "qelib1.inc";
147+
qreg q[1];
148+
creg c[1];
149+
if(c == 1) x q[0];
150+
"""
151+
)
152+
153+
main = qasm2.loads(qasm2_prog)
154+
155+
main.print()
156+
157+
@qasm2.main
158+
def main2():
159+
q = qasm2.qreg(1)
160+
c = qasm2.creg(1)
161+
162+
if c == 1:
163+
qasm2.x(q[0])
164+
165+
main2.print()

0 commit comments

Comments
 (0)