Skip to content

Commit db9743b

Browse files
committed
Refactor QASM2 emission logic. Removed handling GateFunction, it is now handled in func.Function.
1 parent d3b41f5 commit db9743b

File tree

7 files changed

+55
-94
lines changed

7 files changed

+55
-94
lines changed

src/bloqade/cirq_utils/emit/base.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,9 @@ class __FuncEmit(MethodTable):
229229

230230
@impl(func.Function)
231231
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
232-
for block in stmt.body.blocks:
233-
frame.current_block = block
234-
for s in block.stmts:
235-
frame.current_stmt = s
236-
stmt_results = emit.frame_eval(frame, s)
237-
if isinstance(stmt_results, tuple):
238-
if len(stmt_results) != 0:
239-
frame.set_values(s.results, stmt_results)
240-
continue
232+
result = emit.frame_eval(frame, stmt)
233+
if isinstance(result, tuple):
234+
frame.set_values(stmt.results, result)
241235

242236
return (emit.circuit,)
243237

src/bloqade/qasm2/dialects/expr/_emit.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,6 @@
1313
@dialect.register(key="emit.qasm2.gate")
1414
class EmitExpr(interp.MethodTable):
1515

16-
@interp.impl(stmts.GateFunction)
17-
def emit_func(
18-
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
19-
):
20-
21-
args: list[ast.Node] = []
22-
cparams, qparams = [], []
23-
entry_args = stmt.body.blocks[0].args
24-
user_args = entry_args[1:] if len(entry_args) > 0 else []
25-
26-
for arg in user_args:
27-
assert arg.name is not None
28-
29-
args.append(ast.Name(id=arg.name))
30-
if arg.type.is_subseteq(QubitType):
31-
qparams.append(arg.name)
32-
else:
33-
cparams.append(arg.name)
34-
35-
frame.worklist.append(interp.Successor(stmt.body.blocks[0], *args))
36-
if len(entry_args) > 0:
37-
frame.set(entry_args[0], ast.Name(stmt.sym_name or "gate"))
38-
39-
while (succ := frame.worklist.pop()) is not None:
40-
frame.set_values(succ.block.args[1:], succ.block_args)
41-
block_header = emit.emit_block(frame, succ.block)
42-
frame.block_ref[succ.block] = block_header
43-
return (
44-
ast.Gate(
45-
name=stmt.sym_name,
46-
cparams=cparams,
47-
qparams=qparams,
48-
body=frame.body,
49-
),
50-
)
51-
5216
@interp.impl(stmts.ConstInt)
5317
@interp.impl(stmts.ConstFloat)
5418
def emit_const_int(

src/bloqade/qasm2/emit/base.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,17 @@ def run_method(
5353
def emit_block(self, frame: EmitQASM2Frame, block: ir.Block) -> ast.Node | None:
5454
for stmt in block.stmts:
5555
result = self.frame_eval(frame, stmt)
56+
print(stmt.results, result)
5657
if isinstance(result, tuple):
57-
if len(result) == 0:
58-
continue
59-
keys = getattr(stmt, "_results", None) or getattr(stmt, "results", None)
60-
if keys is None:
61-
continue
62-
frame.set_values(keys, result)
58+
frame.set_values(stmt.results, result)
6359
return None
6460

6561
A = TypeVar("A")
6662
B = TypeVar("B")
6763

64+
def eval_fallback(self, frame: EmitQASM2Frame, node: ir.Statement):
65+
return tuple(None for _ in range(len(node.results)))
66+
6867
@overload
6968
def assert_node(self, typ: type[A], node: ast.Node | None) -> A: ...
7069

src/bloqade/qasm2/emit/gate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import field, dataclass
22

33
from kirin import ir, types, interp
4+
from kirin.ir import method
45
from kirin.dialects import py, func, ilist
56
from kirin.ir.dialect import Dialect as Dialect
67
from typing_extensions import Self
@@ -87,6 +88,5 @@ def emit_err(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
8788
raise RuntimeError(f"illegal statement {stmt.name} for QASM2 gate routine")
8889

8990
@interp.impl(func.Return)
90-
@interp.impl(func.ConstantNone)
9191
def ignore(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
9292
return ()

src/bloqade/qasm2/emit/main.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from kirin.ir.dialect import Dialect as Dialect
77
from typing_extensions import Self
88

9+
from bloqade.types import QubitType
910
from bloqade.qasm2.parse import ast
1011
from bloqade.qasm2.dialects.uop import SingleQubitGate, TwoQubitCtrlGate
1112
from bloqade.qasm2.dialects.expr import GateFunction
@@ -105,27 +106,40 @@ def emit_func(
105106
with gate_emitter.new_frame(
106107
callable_node, has_parent_access=False
107108
) as gate_frame:
108-
gate_result = gate_emitter.frame_eval(gate_frame, callable_node)
109-
gate_obj = None
110-
if isinstance(gate_result, tuple) and len(gate_result) > 0:
111-
maybe = gate_result[0]
112-
if isinstance(maybe, ast.Gate):
113-
gate_obj = maybe
114-
115-
if gate_obj is None:
116-
name = emit.callables.get(
117-
callable_node
118-
) or emit.callables.add(callable_node)
119-
prefix = getattr(emit.callables, "prefix", "") or ""
120-
emit_name = (
121-
name[len(prefix) :]
122-
if prefix and name.startswith(prefix)
123-
else name
124-
)
125-
gate_obj = ast.Gate(
126-
name=emit_name, cparams=[], qparams=[], body=[]
109+
args: list[ast.Node] = []
110+
cparams, qparams = [], []
111+
entry_args = callable_node.body.blocks[0].args
112+
user_args = entry_args[1:] if len(entry_args) > 0 else []
113+
114+
for arg in user_args:
115+
assert arg.name is not None
116+
117+
args.append(ast.Name(id=arg.name))
118+
if arg.type.is_subseteq(QubitType):
119+
qparams.append(arg.name)
120+
else:
121+
cparams.append(arg.name)
122+
123+
# Map block arguments to AST names in the gate frame
124+
for arg in user_args:
125+
gate_frame.set(
126+
arg, ast.Name(id=getattr(arg, "name", "arg"))
127127
)
128128

129+
# Actually emit the gate body by interpreting all blocks
130+
for block in callable_node.body.blocks:
131+
gate_emitter.emit_block(gate_frame, block)
132+
133+
name = emit.callables.get(callable_node) or emit.callables.add(
134+
callable_node
135+
)
136+
gate_obj = ast.Gate(
137+
name=name,
138+
cparams=cparams,
139+
qparams=qparams,
140+
body=gate_frame.body,
141+
)
142+
129143
gate_defs.append(gate_obj)
130144

131145
if emit.dialects.data.intersection((parallel.dialect, glob.dialect)):

src/bloqade/qasm2/emit/target.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,8 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
124124
extra = []
125125
if self.qelib1:
126126
extra.append(ast.Include("qelib1.inc"))
127-
128127
if self.custom_gate:
129128
cg = CallGraph(entry)
130-
target_gate = EmitQASM2Gate(self.gate_target).initialize()
131129

132130
for _, fns in cg.defs.items():
133131
if len(fns) != 1:
@@ -140,20 +138,16 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
140138
fn = fn.similar()
141139
QASM2Fold(fn.dialects).fixpoint(fn)
142140

143-
if not self.allow_global:
144-
# rewrite global to parallel
145-
GlobalToParallel(dialects=fn.dialects)(fn)
141+
# if not self.allow_global:
142+
# # rewrite global to parallel
143+
# GlobalToParallel(dialects=fn.dialects)(fn)
146144

147-
if not self.allow_parallel:
148-
# rewrite parallel to uop
149-
ParallelToUOp(dialects=fn.dialects)(fn)
145+
# if not self.allow_parallel:
146+
# # rewrite parallel to uop
147+
# ParallelToUOp(dialects=fn.dialects)(fn)
150148

151149
Py2QASM(fn.dialects)(fn)
152150

153-
target_gate.run(fn)
154-
assert target_gate.output is not None, f"failed to emit {fn.sym_name}"
155-
extra.append(target_gate.output)
156-
157151
main_program.statements = extra + main_program.statements
158152
return main_program
159153

test/qasm2/emit/test_qasm2.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
from pathlib import Path
23
from contextlib import redirect_stdout
34

@@ -30,14 +31,9 @@ def main():
3031
ast = target.emit(main)
3132
filename = "t_qasm2.qasm"
3233
# Read the filename and store in target
33-
with open(Path(__file__).parent.resolve() / filename, "r") as io:
34-
target = io.read()
35-
# Output into filename in the current directory
36-
out_path = Path(__file__).parent.resolve() / filename
37-
with open(out_path, "w") as f:
38-
with redirect_stdout(f):
39-
qasm2.parse.pprint(ast)
40-
# Read the filename again and store in generated
41-
with open(Path(__file__).parent.resolve() / filename, "r") as io:
42-
generated = io.read()
43-
assert generated.strip() == target.strip()
34+
with open(Path(__file__).parent.resolve() / filename, "r") as txt:
35+
target = txt.read()
36+
buf = io.StringIO()
37+
with redirect_stdout(buf):
38+
qasm2.parse.pprint(ast)
39+
assert buf.getvalue().strip() == target.strip()

0 commit comments

Comments
 (0)